diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index ff915e046946..dc3aa102be78 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -7,24 +7,25 @@ on:
env:
DIFFUSERS_IS_CI: yes
- HF_HUB_ENABLE_HF_TRANSFER: 1
+ HF_XET_HIGH_PERFORMANCE: 1
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
+ BASE_PATH: benchmark_outputs
jobs:
- torch_pipelines_cuda_benchmark_tests:
+ torch_models_cuda_benchmark_tests:
env:
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_BENCHMARK }}
- name: Torch Core Pipelines CUDA Benchmarking Tests
+ name: Torch Core Models CUDA Benchmarking Tests
strategy:
fail-fast: false
max-parallel: 1
runs-on:
- group: aws-g6-4xlarge-plus
+ group: aws-g6e-4xlarge
container:
- image: diffusers/diffusers-pytorch-compile-cuda
- options: --shm-size "16gb" --ipc host --gpus 0
+ image: diffusers/diffusers-pytorch-cuda
+ options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
@@ -35,27 +36,46 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- python -m uv pip install pandas peft
- python -m uv pip uninstall transformers && python -m uv pip install transformers==4.48.0
+ apt update
+ apt install -y libpq-dev postgresql-client
+ uv pip install -e ".[quality]"
+ uv pip install -r benchmarks/requirements.txt
- name: Environment
run: |
python utils/print_env.py
- name: Diffusers Benchmarking
env:
- HF_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }}
- BASE_PATH: benchmark_outputs
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
- export TOTAL_GPU_MEMORY=$(python -c "import torch; print(torch.cuda.get_device_properties(0).total_memory / (1024**3))")
- cd benchmarks && mkdir ${BASE_PATH} && python run_all.py && python push_results.py
+ cd benchmarks && python run_all.py
+
+ - name: Push results to the Hub
+ env:
+ HF_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }}
+ run: |
+ cd benchmarks && python push_results.py
+ mkdir $BASE_PATH && cp *.csv $BASE_PATH
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: benchmark_test_reports
- path: benchmarks/benchmark_outputs
+ path: benchmarks/${{ env.BASE_PATH }}
+
+ # TODO: enable this once the connection problem has been resolved.
+ - name: Update benchmarking results to DB
+ env:
+ PGDATABASE: metrics
+ PGHOST: ${{ secrets.DIFFUSERS_BENCHMARKS_PGHOST }}
+ PGUSER: transformers_benchmarks
+ PGPASSWORD: ${{ secrets.DIFFUSERS_BENCHMARKS_PGPASSWORD }}
+ BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
+ run: |
+ git config --global --add safe.directory /__w/diffusers/diffusers
+ commit_id=$GITHUB_SHA
+ commit_msg=$(git show -s --format=%s "$commit_id" | cut -c1-70)
+ cd benchmarks && python populate_into_db.py "$BRANCH_NAME" "$commit_id" "$commit_msg"
- name: Report success status
if: ${{ success() }}
diff --git a/.github/workflows/build_docker_images.yml b/.github/workflows/build_docker_images.yml
index 340d8a19e17a..b1af44736730 100644
--- a/.github/workflows/build_docker_images.yml
+++ b/.github/workflows/build_docker_images.yml
@@ -38,15 +38,43 @@ jobs:
token: ${{ secrets.GITHUB_TOKEN }}
- name: Build Changed Docker Images
+ env:
+ CHANGED_FILES: ${{ steps.file_changes.outputs.all }}
run: |
- CHANGED_FILES="${{ steps.file_changes.outputs.all }}"
+ echo "$CHANGED_FILES"
+ ALLOWED_IMAGES=(
+ diffusers-pytorch-cpu
+ diffusers-pytorch-cuda
+ diffusers-pytorch-xformers-cuda
+ diffusers-pytorch-minimum-cuda
+ diffusers-doc-builder
+ )
+
+ declare -A IMAGES_TO_BUILD=()
+
for FILE in $CHANGED_FILES; do
- if [[ "$FILE" == docker/*Dockerfile ]]; then
- DOCKER_PATH="${FILE%/Dockerfile}"
- DOCKER_TAG=$(basename "$DOCKER_PATH")
- echo "Building Docker image for $DOCKER_TAG"
- docker build -t "$DOCKER_TAG" "$DOCKER_PATH"
+ # skip anything that isn't still on disk
+ if [[ ! -e "$FILE" ]]; then
+ echo "Skipping removed file $FILE"
+ continue
fi
+
+ for IMAGE in "${ALLOWED_IMAGES[@]}"; do
+ if [[ "$FILE" == docker/${IMAGE}/* ]]; then
+ IMAGES_TO_BUILD["$IMAGE"]=1
+ fi
+ done
+ done
+
+ if [[ ${#IMAGES_TO_BUILD[@]} -eq 0 ]]; then
+ echo "No relevant Docker changes detected."
+ exit 0
+ fi
+
+ for IMAGE in "${!IMAGES_TO_BUILD[@]}"; do
+ DOCKER_PATH="docker/${IMAGE}"
+ echo "Building Docker image for $IMAGE"
+ docker build -t "$IMAGE" "$DOCKER_PATH"
done
if: steps.file_changes.outputs.all != ''
@@ -65,13 +93,8 @@ jobs:
image-name:
- diffusers-pytorch-cpu
- diffusers-pytorch-cuda
- - diffusers-pytorch-compile-cuda
- diffusers-pytorch-xformers-cuda
- diffusers-pytorch-minimum-cuda
- - diffusers-flax-cpu
- - diffusers-flax-tpu
- - diffusers-onnxruntime-cpu
- - diffusers-onnxruntime-cuda
- diffusers-doc-builder
steps:
diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml
index 52e075733163..f47645c1f659 100644
--- a/.github/workflows/build_pr_documentation.yml
+++ b/.github/workflows/build_pr_documentation.yml
@@ -12,7 +12,33 @@ concurrency:
cancel-in-progress: true
jobs:
+ check-links:
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: '3.10'
+
+ - name: Install uv
+ run: |
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+ echo "$HOME/.cargo/bin" >> $GITHUB_PATH
+
+ - name: Install doc-builder
+ run: |
+ uv pip install --system git+https://github.com/huggingface/doc-builder.git@main
+
+ - name: Check documentation links
+ run: |
+ uv run doc-builder check-links docs/source/en
+
build:
+ needs: check-links
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
with:
commit_sha: ${{ github.event.pull_request.head.sha }}
diff --git a/.github/workflows/mirror_community_pipeline.yml b/.github/workflows/mirror_community_pipeline.yml
index f6eff1bbd8f0..ab4ded973047 100644
--- a/.github/workflows/mirror_community_pipeline.yml
+++ b/.github/workflows/mirror_community_pipeline.yml
@@ -74,19 +74,19 @@ jobs:
python-version: "3.10"
- name: Install dependencies
run: |
- python -m pip install --upgrade pip
+ pip install --upgrade pip
pip install --upgrade huggingface_hub
# Check secret is set
- name: whoami
- run: huggingface-cli whoami
+ run: hf auth whoami
env:
HF_TOKEN: ${{ secrets.HF_TOKEN_MIRROR_COMMUNITY_PIPELINES }}
# Push to HF! (under subfolder based on checkout ref)
# https://huggingface.co/datasets/diffusers/community-pipelines-mirror
- name: Mirror community pipeline to HF
- run: huggingface-cli upload diffusers/community-pipelines-mirror ./examples/community ${PATH_IN_REPO} --repo-type dataset
+ run: hf upload diffusers/community-pipelines-mirror ./examples/community ${PATH_IN_REPO} --repo-type dataset
env:
PATH_IN_REPO: ${{ env.PATH_IN_REPO }}
HF_TOKEN: ${{ secrets.HF_TOKEN_MIRROR_COMMUNITY_PIPELINES }}
diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml
index 2b39eea2fe5d..8b7e57e91297 100644
--- a/.github/workflows/nightly_tests.yml
+++ b/.github/workflows/nightly_tests.yml
@@ -7,14 +7,15 @@ on:
env:
DIFFUSERS_IS_CI: yes
- HF_HUB_ENABLE_HF_TRANSFER: 1
+ HF_XET_HIGH_PERFORMANCE: 1
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
PYTEST_TIMEOUT: 600
RUN_SLOW: yes
RUN_NIGHTLY: yes
- PIPELINE_USAGE_CUTOFF: 5000
+ PIPELINE_USAGE_CUTOFF: 0
SLACK_API_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
+ CONSOLIDATED_REPORT_PATH: consolidated_test_report.md
jobs:
setup_torch_cuda_pipeline_matrix:
@@ -60,7 +61,7 @@ jobs:
group: aws-g4dn-2xlarge
container:
image: diffusers/diffusers-pytorch-cuda
- options: --shm-size "16gb" --ipc host --gpus 0
+ options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
@@ -70,10 +71,11 @@ jobs:
run: nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- python -m uv pip install pytest-reportlog
+ uv pip install -e ".[quality]"
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
+ uv pip install pytest-reportlog
- name: Environment
run: |
python utils/print_env.py
@@ -83,8 +85,8 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx" \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
--report-log=tests_pipeline_${{ matrix.module }}_cuda.log \
tests/pipelines/${{ matrix.module }}
@@ -99,11 +101,6 @@ jobs:
with:
name: pipeline_${{ matrix.module }}_test_reports
path: reports
- - name: Generate Report and Notify Channel
- if: always()
- run: |
- pip install slack_sdk tabulate
- python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
run_nightly_tests_for_other_torch_modules:
name: Nightly Torch CUDA Tests
@@ -111,7 +108,7 @@ jobs:
group: aws-g4dn-2xlarge
container:
image: diffusers/diffusers-pytorch-cuda
- options: --shm-size "16gb" --ipc host --gpus 0
+ options: --shm-size "16gb" --ipc host --gpus all
defaults:
run:
shell: bash
@@ -128,11 +125,12 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- python -m uv pip install peft@git+https://github.com/huggingface/peft.git
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- python -m uv pip install pytest-reportlog
+ uv pip install -e ".[quality]"
+ uv pip install peft@git+https://github.com/huggingface/peft.git
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
+ uv pip install pytest-reportlog
- name: Environment
run: python utils/print_env.py
@@ -143,8 +141,8 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx" \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -k "not Flax and not Onnx" \
--make-reports=tests_torch_${{ matrix.module }}_cuda \
--report-log=tests_torch_${{ matrix.module }}_cuda.log \
tests/${{ matrix.module }}
@@ -156,8 +154,8 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v --make-reports=examples_torch_cuda \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ --make-reports=examples_torch_cuda \
--report-log=examples_torch_cuda.log \
examples/
@@ -174,11 +172,49 @@ jobs:
name: torch_${{ matrix.module }}_cuda_test_reports
path: reports
- - name: Generate Report and Notify Channel
- if: always()
+ run_torch_compile_tests:
+ name: PyTorch Compile CUDA tests
+
+ runs-on:
+ group: aws-g4dn-2xlarge
+
+ container:
+ image: diffusers/diffusers-pytorch-cuda
+ options: --gpus all --shm-size "16gb" --ipc host
+
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: NVIDIA-SMI
run: |
- pip install slack_sdk tabulate
- python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
+ nvidia-smi
+ - name: Install dependencies
+ run: |
+ uv pip install -e ".[quality,training]"
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
+ - name: Environment
+ run: |
+ python utils/print_env.py
+ - name: Run torch compile tests on GPU
+ env:
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
+ RUN_COMPILE: yes
+ run: |
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: cat reports/tests_torch_compile_cuda_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: torch_compile_test_reports
+ path: reports
run_big_gpu_torch_tests:
name: Torch tests on big GPU
@@ -189,7 +225,7 @@ jobs:
group: aws-g6e-xlarge-plus
container:
image: diffusers/diffusers-pytorch-cuda
- options: --shm-size "16gb" --ipc host --gpus 0
+ options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
@@ -199,11 +235,12 @@ jobs:
run: nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- python -m uv pip install peft@git+https://github.com/huggingface/peft.git
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- python -m uv pip install pytest-reportlog
+ uv pip install -e ".[quality]"
+ uv pip install peft@git+https://github.com/huggingface/peft.git
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
+ uv pip install pytest-reportlog
- name: Environment
run: |
python utils/print_env.py
@@ -214,8 +251,8 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
BIG_GPU_MEMORY: 40
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -m "big_gpu_with_torch_cuda" \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -m "big_accelerator" \
--make-reports=tests_big_gpu_torch_cuda \
--report-log=tests_big_gpu_torch_cuda.log \
tests/
@@ -230,19 +267,14 @@ jobs:
with:
name: torch_cuda_big_gpu_test_reports
path: reports
- - name: Generate Report and Notify Channel
- if: always()
- run: |
- pip install slack_sdk tabulate
- python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
-
+
torch_minimum_version_cuda_tests:
name: Torch Minimum Version CUDA Tests
runs-on:
group: aws-g4dn-2xlarge
container:
image: diffusers/diffusers-pytorch-minimum-cuda
- options: --shm-size "16gb" --ipc host --gpus 0
+ options: --shm-size "16gb" --ipc host --gpus all
defaults:
run:
shell: bash
@@ -254,10 +286,11 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- python -m uv pip install peft@git+https://github.com/huggingface/peft.git
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ uv pip install -e ".[quality]"
+ uv pip install peft@git+https://github.com/huggingface/peft.git
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -269,8 +302,8 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx" \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -k "not Flax and not Onnx" \
--make-reports=tests_torch_minimum_version_cuda \
tests/models/test_modeling_common.py \
tests/pipelines/test_pipelines_common.py \
@@ -292,143 +325,34 @@ jobs:
with:
name: torch_minimum_version_cuda_test_reports
path: reports
-
- run_flax_tpu_tests:
- name: Nightly Flax TPU Tests
- runs-on:
- group: gcp-ct5lp-hightpu-8t
- if: github.event_name == 'schedule'
-
- container:
- image: diffusers/diffusers-flax-tpu
- options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
- defaults:
- run:
- shell: bash
- steps:
- - name: Checkout diffusers
- uses: actions/checkout@v3
- with:
- fetch-depth: 2
-
- - name: Install dependencies
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- python -m uv pip install pytest-reportlog
-
- - name: Environment
- run: python utils/print_env.py
-
- - name: Run nightly Flax TPU tests
- env:
- HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
- run: |
- python -m pytest -n 0 \
- -s -v -k "Flax" \
- --make-reports=tests_flax_tpu \
- --report-log=tests_flax_tpu.log \
- tests/
-
- - name: Failure short reports
- if: ${{ failure() }}
- run: |
- cat reports/tests_flax_tpu_stats.txt
- cat reports/tests_flax_tpu_failures_short.txt
-
- - name: Test suite reports artifacts
- if: ${{ always() }}
- uses: actions/upload-artifact@v4
- with:
- name: flax_tpu_test_reports
- path: reports
-
- - name: Generate Report and Notify Channel
- if: always()
- run: |
- pip install slack_sdk tabulate
- python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
-
- run_nightly_onnx_tests:
- name: Nightly ONNXRuntime CUDA tests on Ubuntu
- runs-on:
- group: aws-g4dn-2xlarge
- container:
- image: diffusers/diffusers-onnxruntime-cuda
- options: --gpus 0 --shm-size "16gb" --ipc host
-
- steps:
- - name: Checkout diffusers
- uses: actions/checkout@v3
- with:
- fetch-depth: 2
-
- - name: NVIDIA-SMI
- run: nvidia-smi
-
- - name: Install dependencies
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- python -m uv pip install pytest-reportlog
- - name: Environment
- run: python utils/print_env.py
-
- - name: Run Nightly ONNXRuntime CUDA tests
- env:
- HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
- run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "Onnx" \
- --make-reports=tests_onnx_cuda \
- --report-log=tests_onnx_cuda.log \
- tests/
-
- - name: Failure short reports
- if: ${{ failure() }}
- run: |
- cat reports/tests_onnx_cuda_stats.txt
- cat reports/tests_onnx_cuda_failures_short.txt
-
- - name: Test suite reports artifacts
- if: ${{ always() }}
- uses: actions/upload-artifact@v4
- with:
- name: tests_onnx_cuda_reports
- path: reports
-
- - name: Generate Report and Notify Channel
- if: always()
- run: |
- pip install slack_sdk tabulate
- python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
run_nightly_quantization_tests:
name: Torch quantization nightly tests
strategy:
fail-fast: false
max-parallel: 2
- matrix:
+ matrix:
config:
- backend: "bitsandbytes"
test_location: "bnb"
additional_deps: ["peft"]
- backend: "gguf"
test_location: "gguf"
- additional_deps: []
+ additional_deps: ["peft", "kernels"]
- backend: "torchao"
test_location: "torchao"
additional_deps: []
- backend: "optimum_quanto"
test_location: "quanto"
additional_deps: []
+ - backend: "nvidia_modelopt"
+ test_location: "modelopt"
+ additional_deps: []
runs-on:
group: aws-g6e-xlarge-plus
container:
image: diffusers/diffusers-pytorch-cuda
- options: --shm-size "20gb" --ipc host --gpus 0
+ options: --shm-size "20gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
@@ -438,13 +362,14 @@ jobs:
run: nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- python -m uv pip install -U ${{ matrix.config.backend }}
+ uv pip install -e ".[quality]"
+ uv pip install -U ${{ matrix.config.backend }}
if [ "${{ join(matrix.config.additional_deps, ' ') }}" != "" ]; then
- python -m uv pip install ${{ join(matrix.config.additional_deps, ' ') }}
+ uv pip install ${{ join(matrix.config.additional_deps, ' ') }}
fi
- python -m uv pip install pytest-reportlog
+ uv pip install pytest-reportlog
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -455,7 +380,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
BIG_GPU_MEMORY: 40
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_${{ matrix.config.backend }}_torch_cuda \
--report-log=tests_${{ matrix.config.backend }}_torch_cuda.log \
tests/quantization/${{ matrix.config.test_location }}
@@ -470,11 +395,115 @@ jobs:
with:
name: torch_cuda_${{ matrix.config.backend }}_reports
path: reports
- - name: Generate Report and Notify Channel
- if: always()
+
+ run_nightly_pipeline_level_quantization_tests:
+ name: Torch quantization nightly tests
+ strategy:
+ fail-fast: false
+ max-parallel: 2
+ runs-on:
+ group: aws-g6e-xlarge-plus
+ container:
+ image: diffusers/diffusers-pytorch-cuda
+ options: --shm-size "20gb" --ipc host --gpus all
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+ - name: NVIDIA-SMI
+ run: nvidia-smi
+ - name: Install dependencies
+ run: |
+ uv pip install -e ".[quality]"
+ uv pip install -U bitsandbytes optimum_quanto
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
+ uv pip install pytest-reportlog
+ - name: Environment
+ run: |
+ python utils/print_env.py
+ - name: Pipeline-level quantization tests on GPU
+ env:
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
+ # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
+ CUBLAS_WORKSPACE_CONFIG: :16:8
+ BIG_GPU_MEMORY: 40
+ run: |
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ --make-reports=tests_pipeline_level_quant_torch_cuda \
+ --report-log=tests_pipeline_level_quant_torch_cuda.log \
+ tests/quantization/test_pipeline_level_quantization.py
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: |
+ cat reports/tests_pipeline_level_quant_torch_cuda_stats.txt
+ cat reports/tests_pipeline_level_quant_torch_cuda_failures_short.txt
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: torch_cuda_pipeline_level_quant_reports
+ path: reports
+
+ generate_consolidated_report:
+ name: Generate Consolidated Test Report
+ needs: [
+ run_nightly_tests_for_torch_pipelines,
+ run_nightly_tests_for_other_torch_modules,
+ run_torch_compile_tests,
+ run_big_gpu_torch_tests,
+ run_nightly_quantization_tests,
+ run_nightly_pipeline_level_quantization_tests,
+ # run_nightly_onnx_tests,
+ torch_minimum_version_cuda_tests,
+ # run_flax_tpu_tests
+ ]
+ if: always()
+ runs-on:
+ group: aws-general-8-plus
+ container:
+ image: diffusers/diffusers-pytorch-cpu
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Create reports directory
+ run: mkdir -p combined_reports
+
+ - name: Download all test reports
+ uses: actions/download-artifact@v4
+ with:
+ path: artifacts
+
+ - name: Prepare reports
+ run: |
+ # Move all report files to a single directory for processing
+ find artifacts -name "*.txt" -exec cp {} combined_reports/ \;
+
+ - name: Install dependencies
run: |
+ pip install -e .[test]
pip install slack_sdk tabulate
- python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
+
+ - name: Generate consolidated report
+ run: |
+ python utils/consolidated_test_report.py \
+ --reports_dir combined_reports \
+ --output_file $CONSOLIDATED_REPORT_PATH \
+ --slack_channel_name diffusers-ci-nightly
+
+ - name: Show consolidated report
+ run: |
+ cat $CONSOLIDATED_REPORT_PATH >> $GITHUB_STEP_SUMMARY
+
+ - name: Upload consolidated report
+ uses: actions/upload-artifact@v4
+ with:
+ name: consolidated_test_report
+ path: ${{ env.CONSOLIDATED_REPORT_PATH }}
# M1 runner currently not well supported
# TODO: (Dhruv) add these back when we setup better testing for Apple Silicon
@@ -501,11 +530,11 @@ jobs:
# - name: Install dependencies
# shell: arch -arch arm64 bash {0}
# run: |
-# ${CONDA_RUN} python -m pip install --upgrade pip uv
-# ${CONDA_RUN} python -m uv pip install -e [quality,test]
-# ${CONDA_RUN} python -m uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
-# ${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate
-# ${CONDA_RUN} python -m uv pip install pytest-reportlog
+# ${CONDA_RUN} pip install --upgrade pip uv
+# ${CONDA_RUN} uv pip install -e ".[quality]"
+# ${CONDA_RUN} uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
+# ${CONDA_RUN} uv pip install accelerate@git+https://github.com/huggingface/accelerate
+# ${CONDA_RUN} uv pip install pytest-reportlog
# - name: Environment
# shell: arch -arch arm64 bash {0}
# run: |
@@ -516,7 +545,7 @@ jobs:
# HF_HOME: /System/Volumes/Data/mnt/cache
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# run: |
-# ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps \
+# ${CONDA_RUN} pytest -n 1 --make-reports=tests_torch_mps \
# --report-log=tests_torch_mps.log \
# tests/
# - name: Failure short reports
@@ -557,11 +586,11 @@ jobs:
# - name: Install dependencies
# shell: arch -arch arm64 bash {0}
# run: |
-# ${CONDA_RUN} python -m pip install --upgrade pip uv
-# ${CONDA_RUN} python -m uv pip install -e [quality,test]
-# ${CONDA_RUN} python -m uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
-# ${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate
-# ${CONDA_RUN} python -m uv pip install pytest-reportlog
+# ${CONDA_RUN} pip install --upgrade pip uv
+# ${CONDA_RUN} uv pip install -e ".[quality]"
+# ${CONDA_RUN} uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
+# ${CONDA_RUN} uv pip install accelerate@git+https://github.com/huggingface/accelerate
+# ${CONDA_RUN} uv pip install pytest-reportlog
# - name: Environment
# shell: arch -arch arm64 bash {0}
# run: |
@@ -572,7 +601,7 @@ jobs:
# HF_HOME: /System/Volumes/Data/mnt/cache
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# run: |
-# ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps \
+# ${CONDA_RUN} pytest -n 1 --make-reports=tests_torch_mps \
# --report-log=tests_torch_mps.log \
# tests/
# - name: Failure short reports
diff --git a/.github/workflows/pr_dependency_test.yml b/.github/workflows/pr_dependency_test.yml
index d9350c09ac42..b914d1076190 100644
--- a/.github/workflows/pr_dependency_test.yml
+++ b/.github/workflows/pr_dependency_test.yml
@@ -25,11 +25,8 @@ jobs:
python-version: "3.8"
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pip install --upgrade pip uv
- python -m uv pip install -e .
- python -m uv pip install pytest
+ pip install -e .
+ pip install pytest
- name: Check for soft dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- pytest tests/others/test_dependencies.py
+ pytest tests/others/test_dependencies.py
diff --git a/.github/workflows/pr_flax_dependency_test.yml b/.github/workflows/pr_flax_dependency_test.yml
deleted file mode 100644
index e091b5f2d7b3..000000000000
--- a/.github/workflows/pr_flax_dependency_test.yml
+++ /dev/null
@@ -1,38 +0,0 @@
-name: Run Flax dependency tests
-
-on:
- pull_request:
- branches:
- - main
- paths:
- - "src/diffusers/**.py"
- push:
- branches:
- - main
-
-concurrency:
- group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
- cancel-in-progress: true
-
-jobs:
- check_flax_dependencies:
- runs-on: ubuntu-22.04
- steps:
- - uses: actions/checkout@v3
- - name: Set up Python
- uses: actions/setup-python@v4
- with:
- python-version: "3.8"
- - name: Install dependencies
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pip install --upgrade pip uv
- python -m uv pip install -e .
- python -m uv pip install "jax[cpu]>=0.2.16,!=0.3.2"
- python -m uv pip install "flax>=0.4.1"
- python -m uv pip install "jaxlib>=0.1.65"
- python -m uv pip install pytest
- - name: Check for soft dependencies
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- pytest tests/others/test_dependencies.py
diff --git a/.github/workflows/pr_modular_tests.yml b/.github/workflows/pr_modular_tests.yml
new file mode 100644
index 000000000000..13c228621f5c
--- /dev/null
+++ b/.github/workflows/pr_modular_tests.yml
@@ -0,0 +1,139 @@
+name: Fast PR tests for Modular
+
+on:
+ pull_request:
+ branches: [main]
+ paths:
+ - "src/diffusers/modular_pipelines/**.py"
+ - "src/diffusers/models/modeling_utils.py"
+ - "src/diffusers/models/model_loading_utils.py"
+ - "src/diffusers/pipelines/pipeline_utils.py"
+ - "src/diffusers/pipeline_loading_utils.py"
+ - "src/diffusers/loaders/lora_base.py"
+ - "src/diffusers/loaders/lora_pipeline.py"
+ - "src/diffusers/loaders/peft.py"
+ - "tests/modular_pipelines/**.py"
+ - ".github/**.yml"
+ - "utils/**.py"
+ - "setup.py"
+ push:
+ branches:
+ - ci-*
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+env:
+ DIFFUSERS_IS_CI: yes
+ HF_XET_HIGH_PERFORMANCE: 1
+ OMP_NUM_THREADS: 4
+ MKL_NUM_THREADS: 4
+ PYTEST_TIMEOUT: 60
+
+jobs:
+ check_code_quality:
+ runs-on: ubuntu-22.04
+ steps:
+ - uses: actions/checkout@v3
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.10"
+ - name: Install dependencies
+ run: |
+ pip install --upgrade pip
+ pip install .[quality]
+ - name: Check quality
+ run: make quality
+ - name: Check if failure
+ if: ${{ failure() }}
+ run: |
+ echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
+
+ check_repository_consistency:
+ needs: check_code_quality
+ runs-on: ubuntu-22.04
+ steps:
+ - uses: actions/checkout@v3
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.10"
+ - name: Install dependencies
+ run: |
+ pip install --upgrade pip
+ pip install .[quality]
+ - name: Check repo consistency
+ run: |
+ python utils/check_copies.py
+ python utils/check_dummies.py
+ python utils/check_support_list.py
+ make deps_table_check_updated
+ - name: Check if failure
+ if: ${{ failure() }}
+ run: |
+ echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
+
+ run_fast_tests:
+ needs: [check_code_quality, check_repository_consistency]
+ strategy:
+ fail-fast: false
+ matrix:
+ config:
+ - name: Fast PyTorch Modular Pipeline CPU tests
+ framework: pytorch_pipelines
+ runner: aws-highmemory-32-plus
+ image: diffusers/diffusers-pytorch-cpu
+ report: torch_cpu_modular_pipelines
+
+ name: ${{ matrix.config.name }}
+
+ runs-on:
+ group: ${{ matrix.config.runner }}
+
+ container:
+ image: ${{ matrix.config.image }}
+ options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
+
+ defaults:
+ run:
+ shell: bash
+
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Install dependencies
+ run: |
+ uv pip install -e ".[quality]"
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
+
+ - name: Environment
+ run: |
+ python utils/print_env.py
+
+ - name: Run fast PyTorch Pipeline CPU tests
+ if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
+ run: |
+ pytest -n 8 --max-worker-restart=0 --dist=loadfile \
+ -k "not Flax and not Onnx" \
+ --make-reports=tests_${{ matrix.config.report }} \
+ tests/modular_pipelines
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
+ path: reports
+
+
diff --git a/.github/workflows/pr_style_bot.yml b/.github/workflows/pr_style_bot.yml
index cf2439c4f2c4..c60004720783 100644
--- a/.github/workflows/pr_style_bot.yml
+++ b/.github/workflows/pr_style_bot.yml
@@ -13,39 +13,5 @@ jobs:
uses: huggingface/huggingface_hub/.github/workflows/style-bot-action.yml@main
with:
python_quality_dependencies: "[quality]"
- pre_commit_script_name: "Download and Compare files from the main branch"
- pre_commit_script: |
- echo "Downloading the files from the main branch"
-
- curl -o main_Makefile https://raw.githubusercontent.com/huggingface/diffusers/main/Makefile
- curl -o main_setup.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/setup.py
- curl -o main_check_doc_toc.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/utils/check_doc_toc.py
-
- echo "Compare the files and raise error if needed"
-
- diff_failed=0
- if ! diff -q main_Makefile Makefile; then
- echo "Error: The Makefile has changed. Please ensure it matches the main branch."
- diff_failed=1
- fi
-
- if ! diff -q main_setup.py setup.py; then
- echo "Error: The setup.py has changed. Please ensure it matches the main branch."
- diff_failed=1
- fi
-
- if ! diff -q main_check_doc_toc.py utils/check_doc_toc.py; then
- echo "Error: The utils/check_doc_toc.py has changed. Please ensure it matches the main branch."
- diff_failed=1
- fi
-
- if [ $diff_failed -eq 1 ]; then
- echo "❌ Error happened as we detected changes in the files that should not be changed ❌"
- exit 1
- fi
-
- echo "No changes in the files. Proceeding..."
- rm -rf main_Makefile main_setup.py main_check_doc_toc.py
- style_command: "make style && make quality"
secrets:
- bot_token: ${{ secrets.GITHUB_TOKEN }}
\ No newline at end of file
+ bot_token: ${{ secrets.HF_STYLE_BOT_ACTION }}
\ No newline at end of file
diff --git a/.github/workflows/pr_test_fetcher.yml b/.github/workflows/pr_test_fetcher.yml
index b032bb842786..83b2ab4edbf6 100644
--- a/.github/workflows/pr_test_fetcher.yml
+++ b/.github/workflows/pr_test_fetcher.yml
@@ -33,8 +33,7 @@ jobs:
fetch-depth: 0
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
+ uv pip install -e ".[quality]"
- name: Environment
run: |
python utils/print_env.py
@@ -90,19 +89,16 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pip install -e [quality,test]
- python -m pip install accelerate
+ uv pip install -e ".[quality]"
+ uv pip install accelerate
- name: Environment
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run all selected tests on CPU
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 2 --dist=loadfile -v --make-reports=${{ matrix.modules }}_tests_cpu ${{ fromJson(needs.setup_pr_tests.outputs.test_map)[matrix.modules] }}
+ pytest -n 2 --dist=loadfile -v --make-reports=${{ matrix.modules }}_tests_cpu ${{ fromJson(needs.setup_pr_tests.outputs.test_map)[matrix.modules] }}
- name: Failure short reports
if: ${{ failure() }}
@@ -148,19 +144,16 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pip install -e [quality,test]
+ pip install -e [quality]
- name: Environment
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run Hub tests for models, schedulers, and pipelines on a staging env
if: ${{ matrix.config.framework == 'hub_tests_pytorch' }}
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- HUGGINGFACE_CO_STAGING=true python -m pytest \
+ HUGGINGFACE_CO_STAGING=true pytest \
-m "is_staging_test" \
--make-reports=tests_${{ matrix.config.report }} \
tests
diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml
index 10d3cb3248d9..674e62ff443a 100644
--- a/.github/workflows/pr_tests.yml
+++ b/.github/workflows/pr_tests.yml
@@ -11,6 +11,7 @@ on:
- "tests/**.py"
- ".github/**.yml"
- "utils/**.py"
+ - "setup.py"
push:
branches:
- ci-*
@@ -21,7 +22,7 @@ concurrency:
env:
DIFFUSERS_IS_CI: yes
- HF_HUB_ENABLE_HF_TRANSFER: 1
+ HF_XET_HIGH_PERFORMANCE: 1
OMP_NUM_THREADS: 4
MKL_NUM_THREADS: 4
PYTEST_TIMEOUT: 60
@@ -37,7 +38,7 @@ jobs:
python-version: "3.8"
- name: Install dependencies
run: |
- python -m pip install --upgrade pip
+ pip install --upgrade pip
pip install .[quality]
- name: Check quality
run: make quality
@@ -57,7 +58,7 @@ jobs:
python-version: "3.8"
- name: Install dependencies
run: |
- python -m pip install --upgrade pip
+ pip install --upgrade pip
pip install .[quality]
- name: Check repo consistency
run: |
@@ -86,11 +87,6 @@ jobs:
runner: aws-general-8-plus
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu_models_schedulers
- - name: Fast Flax CPU tests
- framework: flax
- runner: aws-general-8-plus
- image: diffusers/diffusers-flax-cpu
- report: flax_cpu
- name: PyTorch Example CPU tests
framework: pytorch_examples
runner: aws-general-8-plus
@@ -118,49 +114,36 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
+ uv pip install -e ".[quality]"
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run fast PyTorch Pipeline CPU tests
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 8 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx" \
+ pytest -n 8 --max-worker-restart=0 --dist=loadfile \
+ -k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/pipelines
- name: Run fast PyTorch Model Scheduler CPU tests
if: ${{ matrix.config.framework == 'pytorch_models' }}
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx and not Dependency" \
+ pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+ -k "not Flax and not Onnx and not Dependency" \
--make-reports=tests_${{ matrix.config.report }} \
tests/models tests/schedulers tests/others
- - name: Run fast Flax TPU tests
- if: ${{ matrix.config.framework == 'flax' }}
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "Flax" \
- --make-reports=tests_${{ matrix.config.report }} \
- tests
-
- name: Run example PyTorch CPU tests
if: ${{ matrix.config.framework == 'pytorch_examples' }}
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install peft timm
- python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+ uv pip install ".[training]"
+ pytest -n 4 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_${{ matrix.config.report }} \
examples
@@ -208,19 +191,16 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
+ uv pip install -e ".[quality]"
- name: Environment
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run Hub tests for models, schedulers, and pipelines on a staging env
if: ${{ matrix.config.framework == 'hub_tests_pytorch' }}
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- HUGGINGFACE_CO_STAGING=true python -m pytest \
+ HUGGINGFACE_CO_STAGING=true pytest \
-m "is_staging_test" \
--make-reports=tests_${{ matrix.config.report }} \
tests
@@ -262,36 +242,34 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
+ uv pip install -e ".[quality]"
# TODO (sayakpaul, DN6): revisit `--no-deps`
- python -m pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
- python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
- python -m uv pip install -U tokenizers
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
+ uv pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
+ uv pip install -U tokenizers
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run fast PyTorch LoRA tests with PEFT
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
- -s -v \
+ pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+ \
--make-reports=tests_peft_main \
tests/lora/
- python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
- -s -v \
+ pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+ \
--make-reports=tests_models_lora_peft_main \
tests/models/ -k "lora"
- name: Failure short reports
if: ${{ failure() }}
run: |
- cat reports/tests_lora_failures_short.txt
- cat reports/tests_models_lora_failures_short.txt
+ cat reports/tests_peft_main_failures_short.txt
+ cat reports/tests_models_lora_peft_main_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
diff --git a/.github/workflows/pr_tests_gpu.yml b/.github/workflows/pr_tests_gpu.yml
index 87d51773888e..468979d379c1 100644
--- a/.github/workflows/pr_tests_gpu.yml
+++ b/.github/workflows/pr_tests_gpu.yml
@@ -1,4 +1,4 @@
-name: Fast GPU Tests on PR
+name: Fast GPU Tests on PR
on:
pull_request:
@@ -13,6 +13,7 @@ on:
- "src/diffusers/loaders/peft.py"
- "tests/pipelines/test_pipelines_common.py"
- "tests/models/test_modeling_common.py"
+ - "examples/**/*.py"
workflow_dispatch:
concurrency:
@@ -23,7 +24,7 @@ env:
DIFFUSERS_IS_CI: yes
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
- HF_HUB_ENABLE_HF_TRANSFER: 1
+ HF_XET_HIGH_PERFORMANCE: 1
PYTEST_TIMEOUT: 600
PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run
@@ -38,7 +39,7 @@ jobs:
python-version: "3.8"
- name: Install dependencies
run: |
- python -m pip install --upgrade pip
+ pip install --upgrade pip
pip install .[quality]
- name: Check quality
run: make quality
@@ -58,7 +59,7 @@ jobs:
python-version: "3.8"
- name: Install dependencies
run: |
- python -m pip install --upgrade pip
+ pip install --upgrade pip
pip install .[quality]
- name: Check repo consistency
run: |
@@ -70,7 +71,7 @@ jobs:
if: ${{ failure() }}
run: |
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
-
+
setup_torch_cuda_pipeline_matrix:
needs: [check_code_quality, check_repository_consistency]
name: Setup Torch Pipelines CUDA Slow Tests Matrix
@@ -87,8 +88,7 @@ jobs:
fetch-depth: 2
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
+ uv pip install -e ".[quality]"
- name: Environment
run: |
python utils/print_env.py
@@ -117,7 +117,7 @@ jobs:
group: aws-g4dn-2xlarge
container:
image: diffusers/diffusers-pytorch-cuda
- options: --shm-size "16gb" --ipc host --gpus 0
+ options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
@@ -129,10 +129,10 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
+ uv pip install -e ".[quality]"
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -150,18 +150,18 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
- if [ "${{ matrix.module }}" = "ip_adapters" ]; then
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx" \
+ if [ "${{ matrix.module }}" = "ip_adapters" ]; then
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
- else
+ else
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx and $pattern" \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -k "not Flax and not Onnx and $pattern" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
- fi
+ fi
- name: Failure short reports
if: ${{ failure() }}
@@ -182,13 +182,13 @@ jobs:
group: aws-g4dn-2xlarge
container:
image: diffusers/diffusers-pytorch-cuda
- options: --shm-size "16gb" --ipc host --gpus 0
+ options: --shm-size "16gb" --ipc host --gpus all
defaults:
run:
shell: bash
strategy:
fail-fast: false
- max-parallel: 2
+ max-parallel: 4
matrix:
module: [models, schedulers, lora, others]
steps:
@@ -199,11 +199,11 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- python -m uv pip install peft@git+https://github.com/huggingface/peft.git
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
+ uv pip install -e ".[quality]"
+ uv pip install peft@git+https://github.com/huggingface/peft.git
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -224,11 +224,11 @@ jobs:
run: |
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
if [ -z "$pattern" ]; then
- python -m pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \
- --make-reports=tests_torch_cuda_${{ matrix.module }}
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \
+ --make-reports=tests_torch_cuda_${{ matrix.module }}
else
- python -m pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \
- --make-reports=tests_torch_cuda_${{ matrix.module }}
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \
+ --make-reports=tests_torch_cuda_${{ matrix.module }}
fi
- name: Failure short reports
@@ -252,7 +252,7 @@ jobs:
container:
image: diffusers/diffusers-pytorch-cuda
- options: --gpus 0 --shm-size "16gb" --ipc host
+ options: --gpus all --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
@@ -264,22 +264,20 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
- python -m uv pip install -e [quality,test,training]
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
+ uv pip install -e ".[quality,training]"
- name: Environment
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run example tests on GPU
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install timm
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
+ uv pip install ".[training]"
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
- name: Failure short reports
if: ${{ failure() }}
diff --git a/.github/workflows/pr_torch_dependency_test.yml b/.github/workflows/pr_torch_dependency_test.yml
index c39d5eca2d9a..4b6160ff71e2 100644
--- a/.github/workflows/pr_torch_dependency_test.yml
+++ b/.github/workflows/pr_torch_dependency_test.yml
@@ -25,12 +25,8 @@ jobs:
python-version: "3.8"
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pip install --upgrade pip uv
- python -m uv pip install -e .
- python -m uv pip install torch torchvision torchaudio
- python -m uv pip install pytest
+ pip install -e .
+ pip install torch torchvision torchaudio pytest
- name: Check for soft dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- pytest tests/others/test_dependencies.py
+ pytest tests/others/test_dependencies.py
diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml
index abf825eaa7a0..7b1c441d3dc0 100644
--- a/.github/workflows/push_tests.yml
+++ b/.github/workflows/push_tests.yml
@@ -14,7 +14,7 @@ env:
DIFFUSERS_IS_CI: yes
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
- HF_HUB_ENABLE_HF_TRANSFER: 1
+ HF_XET_HIGH_PERFORMANCE: 1
PYTEST_TIMEOUT: 600
PIPELINE_USAGE_CUTOFF: 50000
@@ -34,8 +34,7 @@ jobs:
fetch-depth: 2
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
+ uv pip install -e ".[quality]"
- name: Environment
run: |
python utils/print_env.py
@@ -64,7 +63,7 @@ jobs:
group: aws-g4dn-2xlarge
container:
image: diffusers/diffusers-pytorch-cuda
- options: --shm-size "16gb" --ipc host --gpus 0
+ options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
@@ -75,9 +74,10 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ uv pip install -e ".[quality]"
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -87,8 +87,8 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx" \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
- name: Failure short reports
@@ -109,7 +109,7 @@ jobs:
group: aws-g4dn-2xlarge
container:
image: diffusers/diffusers-pytorch-cuda
- options: --shm-size "16gb" --ipc host --gpus 0
+ options: --shm-size "16gb" --ipc host --gpus all
defaults:
run:
shell: bash
@@ -126,10 +126,11 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- python -m uv pip install peft@git+https://github.com/huggingface/peft.git
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ uv pip install -e ".[quality]"
+ uv pip install peft@git+https://github.com/huggingface/peft.git
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -141,8 +142,8 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx" \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -k "not Flax and not Onnx" \
--make-reports=tests_torch_cuda_${{ matrix.module }} \
tests/${{ matrix.module }}
@@ -159,102 +160,6 @@ jobs:
name: torch_cuda_test_reports_${{ matrix.module }}
path: reports
- flax_tpu_tests:
- name: Flax TPU Tests
- runs-on:
- group: gcp-ct5lp-hightpu-8t
- container:
- image: diffusers/diffusers-flax-tpu
- options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
- defaults:
- run:
- shell: bash
- steps:
- - name: Checkout diffusers
- uses: actions/checkout@v3
- with:
- fetch-depth: 2
-
- - name: Install dependencies
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
-
- - name: Environment
- run: |
- python utils/print_env.py
-
- - name: Run Flax TPU tests
- env:
- HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
- run: |
- python -m pytest -n 0 \
- -s -v -k "Flax" \
- --make-reports=tests_flax_tpu \
- tests/
-
- - name: Failure short reports
- if: ${{ failure() }}
- run: |
- cat reports/tests_flax_tpu_stats.txt
- cat reports/tests_flax_tpu_failures_short.txt
-
- - name: Test suite reports artifacts
- if: ${{ always() }}
- uses: actions/upload-artifact@v4
- with:
- name: flax_tpu_test_reports
- path: reports
-
- onnx_cuda_tests:
- name: ONNX CUDA Tests
- runs-on:
- group: aws-g4dn-2xlarge
- container:
- image: diffusers/diffusers-onnxruntime-cuda
- options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --gpus 0
- defaults:
- run:
- shell: bash
- steps:
- - name: Checkout diffusers
- uses: actions/checkout@v3
- with:
- fetch-depth: 2
-
- - name: Install dependencies
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
-
- - name: Environment
- run: |
- python utils/print_env.py
-
- - name: Run ONNXRuntime CUDA tests
- env:
- HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
- run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "Onnx" \
- --make-reports=tests_onnx_cuda \
- tests/
-
- - name: Failure short reports
- if: ${{ failure() }}
- run: |
- cat reports/tests_onnx_cuda_stats.txt
- cat reports/tests_onnx_cuda_failures_short.txt
-
- - name: Test suite reports artifacts
- if: ${{ always() }}
- uses: actions/upload-artifact@v4
- with:
- name: onnx_cuda_test_reports
- path: reports
-
run_torch_compile_tests:
name: PyTorch Compile CUDA tests
@@ -262,8 +167,8 @@ jobs:
group: aws-g4dn-2xlarge
container:
- image: diffusers/diffusers-pytorch-compile-cuda
- options: --gpus 0 --shm-size "16gb" --ipc host
+ image: diffusers/diffusers-pytorch-cuda
+ options: --gpus all --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
@@ -276,8 +181,9 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test,training]
+ uv pip install -e ".[quality,training]"
+ #uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -286,7 +192,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
@@ -306,7 +212,7 @@ jobs:
container:
image: diffusers/diffusers-pytorch-xformers-cuda
- options: --gpus 0 --shm-size "16gb" --ipc host
+ options: --gpus all --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
@@ -319,8 +225,7 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test,training]
+ uv pip install -e ".[quality,training]"
- name: Environment
run: |
python utils/print_env.py
@@ -328,7 +233,7 @@ jobs:
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_xformers_cuda_failures_short.txt
@@ -348,7 +253,7 @@ jobs:
container:
image: diffusers/diffusers-pytorch-cuda
- options: --gpus 0 --shm-size "16gb" --ipc host
+ options: --gpus all --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
@@ -360,21 +265,18 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test,training]
+ uv pip install -e ".[quality,training]"
- name: Environment
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run example tests on GPU
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install timm
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
+ uv pip install ".[training]"
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
- name: Failure short reports
if: ${{ failure() }}
diff --git a/.github/workflows/push_tests_fast.yml b/.github/workflows/push_tests_fast.yml
index e8a73446de73..38cbffaa6315 100644
--- a/.github/workflows/push_tests_fast.yml
+++ b/.github/workflows/push_tests_fast.yml
@@ -18,7 +18,7 @@ env:
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
- HF_HUB_ENABLE_HF_TRANSFER: 1
+ HF_XET_HIGH_PERFORMANCE: 1
PYTEST_TIMEOUT: 600
RUN_SLOW: no
@@ -33,16 +33,6 @@ jobs:
runner: aws-general-8-plus
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu
- - name: Fast Flax CPU tests on Ubuntu
- framework: flax
- runner: aws-general-8-plus
- image: diffusers/diffusers-flax-cpu
- report: flax_cpu
- - name: Fast ONNXRuntime CPU tests on Ubuntu
- framework: onnxruntime
- runner: aws-general-8-plus
- image: diffusers/diffusers-onnxruntime-cpu
- report: onnx_cpu
- name: PyTorch Example CPU tests on Ubuntu
framework: pytorch_examples
runner: aws-general-8-plus
@@ -70,47 +60,25 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
+ uv pip install -e ".[quality]"
- name: Environment
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run fast PyTorch CPU tests
if: ${{ matrix.config.framework == 'pytorch' }}
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx" \
- --make-reports=tests_${{ matrix.config.report }} \
- tests/
-
- - name: Run fast Flax TPU tests
- if: ${{ matrix.config.framework == 'flax' }}
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "Flax" \
- --make-reports=tests_${{ matrix.config.report }} \
- tests/
-
- - name: Run fast ONNXRuntime CPU tests
- if: ${{ matrix.config.framework == 'onnxruntime' }}
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "Onnx" \
+ pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+ -k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
- name: Run example PyTorch CPU tests
if: ${{ matrix.config.framework == 'pytorch_examples' }}
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install peft timm
- python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+ uv pip install ".[training]"
+ pytest -n 4 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_${{ matrix.config.report }} \
examples
diff --git a/.github/workflows/push_tests_mps.yml b/.github/workflows/push_tests_mps.yml
index 5fd3b78be7df..2d6feb592815 100644
--- a/.github/workflows/push_tests_mps.yml
+++ b/.github/workflows/push_tests_mps.yml
@@ -1,19 +1,14 @@
name: Fast mps tests on main
on:
- push:
- branches:
- - main
- paths:
- - "src/diffusers/**.py"
- - "tests/**.py"
+ workflow_dispatch:
env:
DIFFUSERS_IS_CI: yes
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
- HF_HUB_ENABLE_HF_TRANSFER: 1
+ HF_XET_HIGH_PERFORMANCE: 1
PYTEST_TIMEOUT: 600
RUN_SLOW: no
@@ -62,7 +57,7 @@ jobs:
HF_HOME: /System/Volumes/Data/mnt/cache
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
- ${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/
+ ${CONDA_RUN} python -m pytest -n 0 --make-reports=tests_torch_mps tests/
- name: Failure short reports
if: ${{ failure() }}
diff --git a/.github/workflows/release_tests_fast.yml b/.github/workflows/release_tests_fast.yml
index 27bd9bd9bb42..efdd6ea2b651 100644
--- a/.github/workflows/release_tests_fast.yml
+++ b/.github/workflows/release_tests_fast.yml
@@ -32,8 +32,7 @@ jobs:
fetch-depth: 2
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
+ uv pip install -e ".[quality]"
- name: Environment
run: |
python utils/print_env.py
@@ -62,7 +61,7 @@ jobs:
group: aws-g4dn-2xlarge
container:
image: diffusers/diffusers-pytorch-cuda
- options: --shm-size "16gb" --ipc host --gpus 0
+ options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
@@ -73,9 +72,8 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ uv pip install -e ".[quality]"
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
python utils/print_env.py
@@ -85,8 +83,8 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx" \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
- name: Failure short reports
@@ -107,7 +105,7 @@ jobs:
group: aws-g4dn-2xlarge
container:
image: diffusers/diffusers-pytorch-cuda
- options: --shm-size "16gb" --ipc host --gpus 0
+ options: --shm-size "16gb" --ipc host --gpus all
defaults:
run:
shell: bash
@@ -124,10 +122,9 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- python -m uv pip install peft@git+https://github.com/huggingface/peft.git
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ uv pip install -e ".[quality]"
+ uv pip install peft@git+https://github.com/huggingface/peft.git
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
@@ -139,8 +136,8 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx" \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -k "not Flax and not Onnx" \
--make-reports=tests_torch_${{ matrix.module }}_cuda \
tests/${{ matrix.module }}
@@ -163,7 +160,7 @@ jobs:
group: aws-g4dn-2xlarge
container:
image: diffusers/diffusers-pytorch-minimum-cuda
- options: --shm-size "16gb" --ipc host --gpus 0
+ options: --shm-size "16gb" --ipc host --gpus all
defaults:
run:
shell: bash
@@ -175,10 +172,9 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- python -m uv pip install peft@git+https://github.com/huggingface/peft.git
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ uv pip install -e ".[quality]"
+ uv pip install peft@git+https://github.com/huggingface/peft.git
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
@@ -190,8 +186,8 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "not Flax and not Onnx" \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ -k "not Flax and not Onnx" \
--make-reports=tests_torch_minimum_cuda \
tests/models/test_modeling_common.py \
tests/pipelines/test_pipelines_common.py \
@@ -213,101 +209,6 @@ jobs:
with:
name: torch_minimum_version_cuda_test_reports
path: reports
-
- flax_tpu_tests:
- name: Flax TPU Tests
- runs-on: docker-tpu
- container:
- image: diffusers/diffusers-flax-tpu
- options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --privileged
- defaults:
- run:
- shell: bash
- steps:
- - name: Checkout diffusers
- uses: actions/checkout@v3
- with:
- fetch-depth: 2
-
- - name: Install dependencies
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
-
- - name: Environment
- run: |
- python utils/print_env.py
-
- - name: Run slow Flax TPU tests
- env:
- HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
- run: |
- python -m pytest -n 0 \
- -s -v -k "Flax" \
- --make-reports=tests_flax_tpu \
- tests/
-
- - name: Failure short reports
- if: ${{ failure() }}
- run: |
- cat reports/tests_flax_tpu_stats.txt
- cat reports/tests_flax_tpu_failures_short.txt
-
- - name: Test suite reports artifacts
- if: ${{ always() }}
- uses: actions/upload-artifact@v4
- with:
- name: flax_tpu_test_reports
- path: reports
-
- onnx_cuda_tests:
- name: ONNX CUDA Tests
- runs-on:
- group: aws-g4dn-2xlarge
- container:
- image: diffusers/diffusers-onnxruntime-cuda
- options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --gpus 0
- defaults:
- run:
- shell: bash
- steps:
- - name: Checkout diffusers
- uses: actions/checkout@v3
- with:
- fetch-depth: 2
-
- - name: Install dependencies
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
-
- - name: Environment
- run: |
- python utils/print_env.py
-
- - name: Run slow ONNXRuntime CUDA tests
- env:
- HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
- run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "Onnx" \
- --make-reports=tests_onnx_cuda \
- tests/
-
- - name: Failure short reports
- if: ${{ failure() }}
- run: |
- cat reports/tests_onnx_cuda_stats.txt
- cat reports/tests_onnx_cuda_failures_short.txt
-
- - name: Test suite reports artifacts
- if: ${{ always() }}
- uses: actions/upload-artifact@v4
- with:
- name: onnx_cuda_test_reports
- path: reports
run_torch_compile_tests:
name: PyTorch Compile CUDA tests
@@ -316,8 +217,8 @@ jobs:
group: aws-g4dn-2xlarge
container:
- image: diffusers/diffusers-pytorch-compile-cuda
- options: --gpus 0 --shm-size "16gb" --ipc host
+ image: diffusers/diffusers-pytorch-cuda
+ options: --gpus all --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
@@ -330,17 +231,16 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test,training]
+ uv pip install -e ".[quality,training]"
- name: Environment
run: |
python utils/print_env.py
- - name: Run example tests on GPU
+ - name: Run torch compile tests on GPU
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
@@ -360,7 +260,7 @@ jobs:
container:
image: diffusers/diffusers-pytorch-xformers-cuda
- options: --gpus 0 --shm-size "16gb" --ipc host
+ options: --gpus all --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
@@ -373,8 +273,7 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test,training]
+ uv pip install -e ".[quality,training]"
- name: Environment
run: |
python utils/print_env.py
@@ -382,7 +281,7 @@ jobs:
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_xformers_cuda_failures_short.txt
@@ -402,7 +301,7 @@ jobs:
container:
image: diffusers/diffusers-pytorch-cuda
- options: --gpus 0 --shm-size "16gb" --ipc host
+ options: --gpus all --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
@@ -416,21 +315,18 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test,training]
+ uv pip install -e ".[quality,training]"
- name: Environment
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run example tests on GPU
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install timm
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
+ uv pip install ".[training]"
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
- name: Failure short reports
if: ${{ failure() }}
diff --git a/.github/workflows/run_tests_from_a_pr.yml b/.github/workflows/run_tests_from_a_pr.yml
index 94fbb2d297c5..fa8c579dd768 100644
--- a/.github/workflows/run_tests_from_a_pr.yml
+++ b/.github/workflows/run_tests_from_a_pr.yml
@@ -30,7 +30,7 @@ jobs:
group: aws-g4dn-2xlarge
container:
image: ${{ github.event.inputs.docker_image }}
- options: --gpus 0 --privileged --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
+ options: --gpus all --privileged --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
steps:
- name: Validate test files input
@@ -63,9 +63,8 @@ jobs:
- name: Install pytest
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- python -m uv pip install peft
+ uv pip install -e ".[quality]"
+ uv pip install peft
- name: Run tests
env:
diff --git a/.github/workflows/ssh-runner.yml b/.github/workflows/ssh-runner.yml
index fd65598a53a7..917eb5b1b31a 100644
--- a/.github/workflows/ssh-runner.yml
+++ b/.github/workflows/ssh-runner.yml
@@ -31,7 +31,7 @@ jobs:
group: "${{ github.event.inputs.runner_type }}"
container:
image: ${{ github.event.inputs.docker_image }}
- options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface/diffusers:/mnt/cache/ --gpus 0 --privileged
+ options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface/diffusers:/mnt/cache/ --gpus all --privileged
steps:
- name: Checkout diffusers
diff --git a/.gitignore b/.gitignore
index 15617d5fdc74..a55026febd5a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -125,6 +125,9 @@ dmypy.json
.vs
.vscode
+# Cursor
+.cursor
+
# Pycharm
.idea
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 049d317599ad..ec18df882641 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,4 +1,4 @@
-
# Caching methods
-## Pyramid Attention Broadcast
+Cache methods speedup diffusion transformers by storing and reusing intermediate outputs of specific layers, such as attention and feedforward layers, instead of recalculating them at each inference step.
-[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
+## CacheMixin
-Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
-
-Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
-
-```python
-import torch
-from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
-
-pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
-pipe.to("cuda")
-
-# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
-# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
-# broadcast is active, leader to slower inference speeds. However, large intervals can lead to
-# poorer quality of generated videos.
-config = PyramidAttentionBroadcastConfig(
- spatial_attention_block_skip_range=2,
- spatial_attention_timestep_skip_range=(100, 800),
- current_timestep_callback=lambda: pipe.current_timestep,
-)
-pipe.transformer.enable_cache(config)
-```
-
-## Faster Cache
-
-[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong.
+[[autodoc]] CacheMixin
-FasterCache is a method that speeds up inference in diffusion transformers by:
-- Reusing attention states between successive inference steps, due to high similarity between them
-- Skipping unconditional branch prediction used in classifier-free guidance by revealing redundancies between unconditional and conditional branch outputs for the same timestep, and therefore approximating the unconditional branch output using the conditional branch output
+## PyramidAttentionBroadcastConfig
-```python
-import torch
-from diffusers import CogVideoXPipeline, FasterCacheConfig
+[[autodoc]] PyramidAttentionBroadcastConfig
-pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
-pipe.to("cuda")
+[[autodoc]] apply_pyramid_attention_broadcast
-config = FasterCacheConfig(
- spatial_attention_block_skip_range=2,
- spatial_attention_timestep_skip_range=(-1, 681),
- current_timestep_callback=lambda: pipe.current_timestep,
- attention_weight_callback=lambda _: 0.3,
- unconditional_batch_skip_range=5,
- unconditional_batch_timestep_skip_range=(-1, 781),
- tensor_format="BFCHW",
-)
-pipe.transformer.enable_cache(config)
-```
+## FasterCacheConfig
-### CacheMixin
+[[autodoc]] FasterCacheConfig
-[[autodoc]] CacheMixin
+[[autodoc]] apply_faster_cache
-### PyramidAttentionBroadcastConfig
+### FirstBlockCacheConfig
-[[autodoc]] PyramidAttentionBroadcastConfig
+[[autodoc]] FirstBlockCacheConfig
-[[autodoc]] apply_pyramid_attention_broadcast
+[[autodoc]] apply_first_block_cache
-### FasterCacheConfig
+### TaylorSeerCacheConfig
-[[autodoc]] FasterCacheConfig
+[[autodoc]] TaylorSeerCacheConfig
-[[autodoc]] apply_faster_cache
+[[autodoc]] apply_taylorseer_cache
diff --git a/docs/source/en/api/configuration.md b/docs/source/en/api/configuration.md
index 31d70232a95c..328e109e1e4c 100644
--- a/docs/source/en/api/configuration.md
+++ b/docs/source/en/api/configuration.md
@@ -1,4 +1,4 @@
-
+
+# AutoModel
+
+[`AutoModel`] automatically retrieves the correct model class from the checkpoint `config.json` file.
+
+## AutoModel
+
+[[autodoc]] AutoModel
+ - all
+ - from_pretrained
diff --git a/docs/source/en/api/models/autoencoder_dc.md b/docs/source/en/api/models/autoencoder_dc.md
index 6f86150eb744..fd53ec0ef66f 100644
--- a/docs/source/en/api/models/autoencoder_dc.md
+++ b/docs/source/en/api/models/autoencoder_dc.md
@@ -1,4 +1,4 @@
-
+
+# AutoencoderKLHunyuanVideo15
+
+The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanVideo1.5](https://github.com/Tencent/HunyuanVideo1-1.5) by Tencent.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import AutoencoderKLHunyuanVideo15
+
+vae = AutoencoderKLHunyuanVideo15.from_pretrained("hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v", subfolder="vae", torch_dtype=torch.float32)
+
+# make sure to enable tiling to avoid OOM
+vae.enable_tiling()
+```
+
+## AutoencoderKLHunyuanVideo15
+
+[[autodoc]] AutoencoderKLHunyuanVideo15
+ - decode
+ - encode
+ - all
+
+## DecoderOutput
+
+[[autodoc]] models.autoencoders.vae.DecoderOutput
diff --git a/docs/source/en/api/models/autoencoder_kl_hunyuanimage.md b/docs/source/en/api/models/autoencoder_kl_hunyuanimage.md
new file mode 100644
index 000000000000..60dd2b3ab155
--- /dev/null
+++ b/docs/source/en/api/models/autoencoder_kl_hunyuanimage.md
@@ -0,0 +1,32 @@
+
+
+# AutoencoderKLHunyuanImage
+
+The 2D variational autoencoder (VAE) model with KL loss used in [HunyuanImage2.1].
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import AutoencoderKLHunyuanImage
+
+vae = AutoencoderKLHunyuanImage.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Diffusers", subfolder="vae", torch_dtype=torch.bfloat16)
+```
+
+## AutoencoderKLHunyuanImage
+
+[[autodoc]] AutoencoderKLHunyuanImage
+ - decode
+ - all
+
+## DecoderOutput
+
+[[autodoc]] models.autoencoders.vae.DecoderOutput
diff --git a/docs/source/en/api/models/autoencoder_kl_hunyuanimage_refiner.md b/docs/source/en/api/models/autoencoder_kl_hunyuanimage_refiner.md
new file mode 100644
index 000000000000..5e1dd5e2a24a
--- /dev/null
+++ b/docs/source/en/api/models/autoencoder_kl_hunyuanimage_refiner.md
@@ -0,0 +1,32 @@
+
+
+# AutoencoderKLHunyuanImageRefiner
+
+The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanImage2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1) for its refiner pipeline.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import AutoencoderKLHunyuanImageRefiner
+
+vae = AutoencoderKLHunyuanImageRefiner.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers", subfolder="vae", torch_dtype=torch.bfloat16)
+```
+
+## AutoencoderKLHunyuanImageRefiner
+
+[[autodoc]] AutoencoderKLHunyuanImageRefiner
+ - decode
+ - all
+
+## DecoderOutput
+
+[[autodoc]] models.autoencoders.vae.DecoderOutput
diff --git a/docs/source/en/api/models/autoencoder_kl_wan.md b/docs/source/en/api/models/autoencoder_kl_wan.md
index 43165c8edf7a..341e5e9a8736 100644
--- a/docs/source/en/api/models/autoencoder_kl_wan.md
+++ b/docs/source/en/api/models/autoencoder_kl_wan.md
@@ -1,4 +1,4 @@
-
+
+# AutoencoderKLCosmos
+
+[Cosmos Tokenizers](https://github.com/NVIDIA/Cosmos-Tokenizer).
+
+Supported models:
+- [nvidia/Cosmos-1.0-Tokenizer-CV8x8x8](https://huggingface.co/nvidia/Cosmos-1.0-Tokenizer-CV8x8x8)
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import AutoencoderKLCosmos
+
+vae = AutoencoderKLCosmos.from_pretrained("nvidia/Cosmos-1.0-Tokenizer-CV8x8x8", subfolder="vae")
+```
+
+## AutoencoderKLCosmos
+
+[[autodoc]] AutoencoderKLCosmos
+ - decode
+ - encode
+ - all
+
+## AutoencoderKLOutput
+
+[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
+
+## DecoderOutput
+
+[[autodoc]] models.autoencoders.vae.DecoderOutput
diff --git a/docs/source/en/api/models/autoencoderkl_ltx_video.md b/docs/source/en/api/models/autoencoderkl_ltx_video.md
index fbdb11e29cdd..9c2384ca53a1 100644
--- a/docs/source/en/api/models/autoencoderkl_ltx_video.md
+++ b/docs/source/en/api/models/autoencoderkl_ltx_video.md
@@ -1,4 +1,4 @@
-
+
+# AutoencoderKLQwenImage
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import AutoencoderKLQwenImage
+
+vae = AutoencoderKLQwenImage.from_pretrained("Qwen/QwenImage-20B", subfolder="vae")
+```
+
+## AutoencoderKLQwenImage
+
+[[autodoc]] AutoencoderKLQwenImage
+ - decode
+ - encode
+ - all
+
+## AutoencoderKLOutput
+
+[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
+
+## DecoderOutput
+
+[[autodoc]] models.autoencoders.vae.DecoderOutput
diff --git a/docs/source/en/api/models/bria_transformer.md b/docs/source/en/api/models/bria_transformer.md
new file mode 100644
index 000000000000..9df7eeb6ffcd
--- /dev/null
+++ b/docs/source/en/api/models/bria_transformer.md
@@ -0,0 +1,19 @@
+
+
+# BriaTransformer2DModel
+
+A modified flux Transformer model from [Bria](https://huggingface.co/briaai/BRIA-3.2)
+
+## BriaTransformer2DModel
+
+[[autodoc]] BriaTransformer2DModel
diff --git a/docs/source/en/api/models/chroma_transformer.md b/docs/source/en/api/models/chroma_transformer.md
new file mode 100644
index 000000000000..1ef24cda3925
--- /dev/null
+++ b/docs/source/en/api/models/chroma_transformer.md
@@ -0,0 +1,19 @@
+
+
+# ChromaTransformer2DModel
+
+A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma1-HD)
+
+## ChromaTransformer2DModel
+
+[[autodoc]] ChromaTransformer2DModel
diff --git a/docs/source/en/api/models/chronoedit_transformer_3d.md b/docs/source/en/api/models/chronoedit_transformer_3d.md
new file mode 100644
index 000000000000..94982821795d
--- /dev/null
+++ b/docs/source/en/api/models/chronoedit_transformer_3d.md
@@ -0,0 +1,32 @@
+
+
+# ChronoEditTransformer3DModel
+
+A Diffusion Transformer model for 3D video-like data from [ChronoEdit: Towards Temporal Reasoning for Image Editing and World Simulation](https://huggingface.co/papers/2510.04290) from NVIDIA and University of Toronto, by Jay Zhangjie Wu, Xuanchi Ren, Tianchang Shen, Tianshi Cao, Kai He, Yifan Lu, Ruiyuan Gao, Enze Xie, Shiyi Lan, Jose M. Alvarez, Jun Gao, Sanja Fidler, Zian Wang, Huan Ling.
+
+> **TL;DR:** ChronoEdit reframes image editing as a video generation task, using input and edited images as start/end frames to leverage pretrained video models with temporal consistency. A temporal reasoning stage introduces reasoning tokens to ensure physically plausible edits and visualize the editing trajectory.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import ChronoEditTransformer3DModel
+
+transformer = ChronoEditTransformer3DModel.from_pretrained("nvidia/ChronoEdit-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## ChronoEditTransformer3DModel
+
+[[autodoc]] ChronoEditTransformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/models/cogvideox_transformer3d.md b/docs/source/en/api/models/cogvideox_transformer3d.md
index 30556ef7be3f..5d50e5dca651 100644
--- a/docs/source/en/api/models/cogvideox_transformer3d.md
+++ b/docs/source/en/api/models/cogvideox_transformer3d.md
@@ -1,4 +1,4 @@
-
# ConsisIDTransformer3DModel
-A Diffusion Transformer model for 3D data from [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) was introduced in [Identity-Preserving Text-to-Video Generation by Frequency Decomposition](https://arxiv.org/pdf/2411.17440) by Peking University & University of Rochester & etc.
+A Diffusion Transformer model for 3D data from [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) was introduced in [Identity-Preserving Text-to-Video Generation by Frequency Decomposition](https://huggingface.co/papers/2411.17440) by Peking University & University of Rochester & etc.
The model can be loaded with the following code snippet.
diff --git a/docs/source/en/api/models/consistency_decoder_vae.md b/docs/source/en/api/models/consistency_decoder_vae.md
index 94a64820ebb1..fe039df7f9bf 100644
--- a/docs/source/en/api/models/consistency_decoder_vae.md
+++ b/docs/source/en/api/models/consistency_decoder_vae.md
@@ -1,4 +1,4 @@
-
+
+# SanaControlNetModel
+
+The ControlNet model was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, Maneesh Agrawala. It provides a greater degree of control over text-to-image generation by conditioning the model on additional inputs such as edge maps, depth maps, segmentation maps, and keypoints for pose detection.
+
+The abstract from the paper is:
+
+*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
+
+This model was contributed by [ishan24](https://huggingface.co/ishan24). ❤️
+The original codebase can be found at [NVlabs/Sana](https://github.com/NVlabs/Sana), and you can find official ControlNet checkpoints on [Efficient-Large-Model's](https://huggingface.co/Efficient-Large-Model) Hub profile.
+
+## SanaControlNetModel
+[[autodoc]] SanaControlNetModel
+
+## SanaControlNetOutput
+[[autodoc]] models.controlnets.controlnet_sana.SanaControlNetOutput
+
diff --git a/docs/source/en/api/models/controlnet_sd3.md b/docs/source/en/api/models/controlnet_sd3.md
index 78564d238eea..f665dde3a007 100644
--- a/docs/source/en/api/models/controlnet_sd3.md
+++ b/docs/source/en/api/models/controlnet_sd3.md
@@ -1,4 +1,4 @@
-
# SparseControlNetModel
-SparseControlNetModel is an implementation of ControlNet for [AnimateDiff](https://arxiv.org/abs/2307.04725).
+SparseControlNetModel is an implementation of ControlNet for [AnimateDiff](https://huggingface.co/papers/2307.04725).
ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
-The SparseCtrl version of ControlNet was introduced in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion Models](https://arxiv.org/abs/2311.16933) for achieving controlled generation in text-to-video diffusion models by Yuwei Guo, Ceyuan Yang, Anyi Rao, Maneesh Agrawala, Dahua Lin, and Bo Dai.
+The SparseCtrl version of ControlNet was introduced in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion Models](https://huggingface.co/papers/2311.16933) for achieving controlled generation in text-to-video diffusion models by Yuwei Guo, Ceyuan Yang, Anyi Rao, Maneesh Agrawala, Dahua Lin, and Bo Dai.
The abstract from the paper is:
diff --git a/docs/source/en/api/models/controlnet_union.md b/docs/source/en/api/models/controlnet_union.md
index 9c0d86984549..466718269758 100644
--- a/docs/source/en/api/models/controlnet_union.md
+++ b/docs/source/en/api/models/controlnet_union.md
@@ -1,4 +1,4 @@
-
+
+# CosmosTransformer3DModel
+
+A Diffusion Transformer model for 3D video-like data was introduced in [Cosmos World Foundation Model Platform for Physical AI](https://huggingface.co/papers/2501.03575) by NVIDIA.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import CosmosTransformer3DModel
+
+transformer = CosmosTransformer3DModel.from_pretrained("nvidia/Cosmos-1.0-Diffusion-7B-Text2World", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## CosmosTransformer3DModel
+
+[[autodoc]] CosmosTransformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/models/dit_transformer2d.md b/docs/source/en/api/models/dit_transformer2d.md
index afac62d53cb4..640bd31feeef 100644
--- a/docs/source/en/api/models/dit_transformer2d.md
+++ b/docs/source/en/api/models/dit_transformer2d.md
@@ -1,4 +1,4 @@
-
+
+# Flux2Transformer2DModel
+
+A Transformer model for image-like data from [Flux2](https://hf.co/black-forest-labs/FLUX.2-dev).
+
+## Flux2Transformer2DModel
+
+[[autodoc]] Flux2Transformer2DModel
diff --git a/docs/source/en/api/models/flux_transformer.md b/docs/source/en/api/models/flux_transformer.md
index 381593f1bfe6..d1ccb1a242b3 100644
--- a/docs/source/en/api/models/flux_transformer.md
+++ b/docs/source/en/api/models/flux_transformer.md
@@ -1,4 +1,4 @@
-
+
+# HiDreamImageTransformer2DModel
+
+A Transformer model for image-like data from [HiDream-I1](https://huggingface.co/HiDream-ai).
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import HiDreamImageTransformer2DModel
+
+transformer = HiDreamImageTransformer2DModel.from_pretrained("HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## Loading GGUF quantized checkpoints for HiDream-I1
+
+GGUF checkpoints for the `HiDreamImageTransformer2DModel` can be loaded using `~FromOriginalModelMixin.from_single_file`
+
+```python
+import torch
+from diffusers import GGUFQuantizationConfig, HiDreamImageTransformer2DModel
+
+ckpt_path = "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf"
+transformer = HiDreamImageTransformer2DModel.from_single_file(
+ ckpt_path,
+ quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+ torch_dtype=torch.bfloat16
+)
+```
+
+## HiDreamImageTransformer2DModel
+
+[[autodoc]] HiDreamImageTransformer2DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/models/hunyuan_transformer2d.md b/docs/source/en/api/models/hunyuan_transformer2d.md
index fe137236d18e..4e2d38f3233a 100644
--- a/docs/source/en/api/models/hunyuan_transformer2d.md
+++ b/docs/source/en/api/models/hunyuan_transformer2d.md
@@ -1,4 +1,4 @@
-
+
+# HunyuanVideo15Transformer3DModel
+
+A Diffusion Transformer model for 3D video-like data used in [HunyuanVideo1.5](https://github.com/Tencent/HunyuanVideo1-1.5).
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import HunyuanVideo15Transformer3DModel
+
+transformer = HunyuanVideo15Transformer3DModel.from_pretrained("hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v" subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## HunyuanVideo15Transformer3DModel
+
+[[autodoc]] HunyuanVideo15Transformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/models/hunyuan_video_transformer_3d.md b/docs/source/en/api/models/hunyuan_video_transformer_3d.md
index 522d0eb0479d..77d30e5553bc 100644
--- a/docs/source/en/api/models/hunyuan_video_transformer_3d.md
+++ b/docs/source/en/api/models/hunyuan_video_transformer_3d.md
@@ -1,4 +1,4 @@
-
+
+# HunyuanImageTransformer2DModel
+
+A Diffusion Transformer model for [HunyuanImage2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1).
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import HunyuanImageTransformer2DModel
+
+transformer = HunyuanImageTransformer2DModel.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## HunyuanImageTransformer2DModel
+
+[[autodoc]] HunyuanImageTransformer2DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/models/latte_transformer3d.md b/docs/source/en/api/models/latte_transformer3d.md
index f87926aefc9f..6182f403ea48 100644
--- a/docs/source/en/api/models/latte_transformer3d.md
+++ b/docs/source/en/api/models/latte_transformer3d.md
@@ -1,4 +1,4 @@
-
+
+# OvisImageTransformer2DModel
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import OvisImageTransformer2DModel
+
+transformer = OvisImageTransformer2DModel.from_pretrained("AIDC-AI/Ovis-Image-7B", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## OvisImageTransformer2DModel
+
+[[autodoc]] OvisImageTransformer2DModel
diff --git a/docs/source/en/api/models/pixart_transformer2d.md b/docs/source/en/api/models/pixart_transformer2d.md
index 1d392f4e7c2c..a5a08b611334 100644
--- a/docs/source/en/api/models/pixart_transformer2d.md
+++ b/docs/source/en/api/models/pixart_transformer2d.md
@@ -1,4 +1,4 @@
-
+
+# QwenImageTransformer2DModel
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import QwenImageTransformer2DModel
+
+transformer = QwenImageTransformer2DModel.from_pretrained("Qwen/QwenImage-20B", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## QwenImageTransformer2DModel
+
+[[autodoc]] QwenImageTransformer2DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/models/sana_transformer2d.md b/docs/source/en/api/models/sana_transformer2d.md
index 269aefd7ff69..e3e5fde3a79e 100644
--- a/docs/source/en/api/models/sana_transformer2d.md
+++ b/docs/source/en/api/models/sana_transformer2d.md
@@ -1,4 +1,4 @@
-
+
+# SanaVideoTransformer3DModel
+
+A Diffusion Transformer model for 3D data (video) from [SANA-Video: Efficient Video Generation with Block Linear Diffusion Transformer](https://huggingface.co/papers/2509.24695) from NVIDIA and MIT HAN Lab, by Junsong Chen, Yuyang Zhao, Jincheng Yu, Ruihang Chu, Junyu Chen, Shuai Yang, Xianbang Wang, Yicheng Pan, Daquan Zhou, Huan Ling, Haozhe Liu, Hongwei Yi, Hao Zhang, Muyang Li, Yukang Chen, Han Cai, Sanja Fidler, Ping Luo, Song Han, Enze Xie.
+
+The abstract from the paper is:
+
+*We introduce SANA-Video, a small diffusion model that can efficiently generate videos up to 720x1280 resolution and minute-length duration. SANA-Video synthesizes high-resolution, high-quality and long videos with strong text-video alignment at a remarkably fast speed, deployable on RTX 5090 GPU. Two core designs ensure our efficient, effective and long video generation: (1) Linear DiT: We leverage linear attention as the core operation, which is more efficient than vanilla attention given the large number of tokens processed in video generation. (2) Constant-Memory KV cache for Block Linear Attention: we design block-wise autoregressive approach for long video generation by employing a constant-memory state, derived from the cumulative properties of linear attention. This KV cache provides the Linear DiT with global context at a fixed memory cost, eliminating the need for a traditional KV cache and enabling efficient, minute-long video generation. In addition, we explore effective data filters and model training strategies, narrowing the training cost to 12 days on 64 H100 GPUs, which is only 1% of the cost of MovieGen. Given its low cost, SANA-Video achieves competitive performance compared to modern state-of-the-art small diffusion models (e.g., Wan 2.1-1.3B and SkyReel-V2-1.3B) while being 16x faster in measured latency. Moreover, SANA-Video can be deployed on RTX 5090 GPUs with NVFP4 precision, accelerating the inference speed of generating a 5-second 720p video from 71s to 29s (2.4x speedup). In summary, SANA-Video enables low-cost, high-quality video generation.*
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import SanaVideoTransformer3DModel
+import torch
+
+transformer = SanaVideoTransformer3DModel.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## SanaVideoTransformer3DModel
+
+[[autodoc]] SanaVideoTransformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
+
diff --git a/docs/source/en/api/models/sd3_transformer2d.md b/docs/source/en/api/models/sd3_transformer2d.md
index feef87db3a63..f4fc4c65826c 100644
--- a/docs/source/en/api/models/sd3_transformer2d.md
+++ b/docs/source/en/api/models/sd3_transformer2d.md
@@ -1,4 +1,4 @@
-
+
+# SkyReelsV2Transformer3DModel
+
+A Diffusion Transformer model for 3D video-like data was introduced in [SkyReels-V2](https://github.com/SkyworkAI/SkyReels-V2) by the Skywork AI.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import SkyReelsV2Transformer3DModel
+
+transformer = SkyReelsV2Transformer3DModel.from_pretrained("Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## SkyReelsV2Transformer3DModel
+
+[[autodoc]] SkyReelsV2Transformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/models/stable_audio_transformer.md b/docs/source/en/api/models/stable_audio_transformer.md
index 396b96c8c710..50a936b43ef4 100644
--- a/docs/source/en/api/models/stable_audio_transformer.md
+++ b/docs/source/en/api/models/stable_audio_transformer.md
@@ -1,4 +1,4 @@
-
+
+# BriaFiboTransformer2DModel
+
+A modified flux Transformer model from [Bria](https://huggingface.co/briaai/FIBO)
+
+## BriaFiboTransformer2DModel
+
+[[autodoc]] BriaFiboTransformer2DModel
diff --git a/docs/source/en/api/models/transformer_temporal.md b/docs/source/en/api/models/transformer_temporal.md
index 02d075dea3f3..e89afeeffeb3 100644
--- a/docs/source/en/api/models/transformer_temporal.md
+++ b/docs/source/en/api/models/transformer_temporal.md
@@ -1,4 +1,4 @@
-
+
+# WanAnimateTransformer3DModel
+
+A Diffusion Transformer model for 3D video-like data was introduced in [Wan Animate](https://github.com/Wan-Video/Wan2.2) by the Alibaba Wan Team.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import WanAnimateTransformer3DModel
+
+transformer = WanAnimateTransformer3DModel.from_pretrained("Wan-AI/Wan2.2-Animate-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## WanAnimateTransformer3DModel
+
+[[autodoc]] WanAnimateTransformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/models/wan_transformer_3d.md b/docs/source/en/api/models/wan_transformer_3d.md
index 56015c4c07f1..c218166584c6 100644
--- a/docs/source/en/api/models/wan_transformer_3d.md
+++ b/docs/source/en/api/models/wan_transformer_3d.md
@@ -1,4 +1,4 @@
-
+
+# ZImageTransformer2DModel
+
+A Transformer model for image-like data from [Z-Image](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo).
+
+## ZImageTransformer2DModel
+
+[[autodoc]] ZImageTransformer2DModel
\ No newline at end of file
diff --git a/docs/source/en/api/modular_diffusers/guiders.md b/docs/source/en/api/modular_diffusers/guiders.md
new file mode 100644
index 000000000000..a24eb7220749
--- /dev/null
+++ b/docs/source/en/api/modular_diffusers/guiders.md
@@ -0,0 +1,39 @@
+# Guiders
+
+Guiders are components in Modular Diffusers that control how the diffusion process is guided during generation. They implement various guidance techniques to improve generation quality and control.
+
+## BaseGuidance
+
+[[autodoc]] diffusers.guiders.guider_utils.BaseGuidance
+
+## ClassifierFreeGuidance
+
+[[autodoc]] diffusers.guiders.classifier_free_guidance.ClassifierFreeGuidance
+
+## ClassifierFreeZeroStarGuidance
+
+[[autodoc]] diffusers.guiders.classifier_free_zero_star_guidance.ClassifierFreeZeroStarGuidance
+
+## SkipLayerGuidance
+
+[[autodoc]] diffusers.guiders.skip_layer_guidance.SkipLayerGuidance
+
+## SmoothedEnergyGuidance
+
+[[autodoc]] diffusers.guiders.smoothed_energy_guidance.SmoothedEnergyGuidance
+
+## PerturbedAttentionGuidance
+
+[[autodoc]] diffusers.guiders.perturbed_attention_guidance.PerturbedAttentionGuidance
+
+## AdaptiveProjectedGuidance
+
+[[autodoc]] diffusers.guiders.adaptive_projected_guidance.AdaptiveProjectedGuidance
+
+## AutoGuidance
+
+[[autodoc]] diffusers.guiders.auto_guidance.AutoGuidance
+
+## TangentialClassifierFreeGuidance
+
+[[autodoc]] diffusers.guiders.tangential_classifier_free_guidance.TangentialClassifierFreeGuidance
diff --git a/docs/source/en/api/modular_diffusers/pipeline.md b/docs/source/en/api/modular_diffusers/pipeline.md
new file mode 100644
index 000000000000..f60261ea6672
--- /dev/null
+++ b/docs/source/en/api/modular_diffusers/pipeline.md
@@ -0,0 +1,5 @@
+# Pipeline
+
+## ModularPipeline
+
+[[autodoc]] diffusers.modular_pipelines.modular_pipeline.ModularPipeline
diff --git a/docs/source/en/api/modular_diffusers/pipeline_blocks.md b/docs/source/en/api/modular_diffusers/pipeline_blocks.md
new file mode 100644
index 000000000000..8ad581e679ac
--- /dev/null
+++ b/docs/source/en/api/modular_diffusers/pipeline_blocks.md
@@ -0,0 +1,17 @@
+# Pipeline blocks
+
+## ModularPipelineBlocks
+
+[[autodoc]] diffusers.modular_pipelines.modular_pipeline.ModularPipelineBlocks
+
+## SequentialPipelineBlocks
+
+[[autodoc]] diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks
+
+## LoopSequentialPipelineBlocks
+
+[[autodoc]] diffusers.modular_pipelines.modular_pipeline.LoopSequentialPipelineBlocks
+
+## AutoPipelineBlocks
+
+[[autodoc]] diffusers.modular_pipelines.modular_pipeline.AutoPipelineBlocks
\ No newline at end of file
diff --git a/docs/source/en/api/modular_diffusers/pipeline_components.md b/docs/source/en/api/modular_diffusers/pipeline_components.md
new file mode 100644
index 000000000000..2d8e10aef6d8
--- /dev/null
+++ b/docs/source/en/api/modular_diffusers/pipeline_components.md
@@ -0,0 +1,17 @@
+# Components and configs
+
+## ComponentSpec
+
+[[autodoc]] diffusers.modular_pipelines.modular_pipeline.ComponentSpec
+
+## ConfigSpec
+
+[[autodoc]] diffusers.modular_pipelines.modular_pipeline.ConfigSpec
+
+## ComponentsManager
+
+[[autodoc]] diffusers.modular_pipelines.components_manager.ComponentsManager
+
+## InsertableDict
+
+[[autodoc]] diffusers.modular_pipelines.modular_pipeline_utils.InsertableDict
\ No newline at end of file
diff --git a/docs/source/en/api/modular_diffusers/pipeline_states.md b/docs/source/en/api/modular_diffusers/pipeline_states.md
new file mode 100644
index 000000000000..341d18ecb41c
--- /dev/null
+++ b/docs/source/en/api/modular_diffusers/pipeline_states.md
@@ -0,0 +1,9 @@
+# Pipeline states
+
+## PipelineState
+
+[[autodoc]] diffusers.modular_pipelines.modular_pipeline.PipelineState
+
+## BlockState
+
+[[autodoc]] diffusers.modular_pipelines.modular_pipeline.BlockState
\ No newline at end of file
diff --git a/docs/source/en/api/normalization.md b/docs/source/en/api/normalization.md
index 05ae92a28dc8..fa703b19871b 100644
--- a/docs/source/en/api/normalization.md
+++ b/docs/source/en/api/normalization.md
@@ -1,4 +1,4 @@
-
+
+# Parallelism
+
+Parallelism strategies help speed up diffusion transformers by distributing computations across multiple devices, allowing for faster inference/training times. Refer to the [Distributed inferece](../training/distributed_inference) guide to learn more.
+
+## ParallelConfig
+
+[[autodoc]] ParallelConfig
+
+## ContextParallelConfig
+
+[[autodoc]] ContextParallelConfig
+
+[[autodoc]] hooks.apply_context_parallel
diff --git a/docs/source/en/api/pipelines/allegro.md b/docs/source/en/api/pipelines/allegro.md
index 690f8096a0e4..a981fb1f94f7 100644
--- a/docs/source/en/api/pipelines/allegro.md
+++ b/docs/source/en/api/pipelines/allegro.md
@@ -1,4 +1,4 @@
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# aMUSEd
aMUSEd was introduced in [aMUSEd: An Open MUSE Reproduction](https://huggingface.co/papers/2401.01808) by Suraj Patil, William Berman, Robin Rombach, and Patrick von Platen.
-Amused is a lightweight text to image model based off of the [MUSE](https://arxiv.org/abs/2301.00704) architecture. Amused is particularly useful in applications that require a lightweight and fast model such as generating many images quickly at once.
+Amused is a lightweight text to image model based off of the [MUSE](https://huggingface.co/papers/2301.00704) architecture. Amused is particularly useful in applications that require a lightweight and fast model such as generating many images quickly at once.
Amused is a vqvae token based transformer that can generate an image in fewer forward passes than many diffusion models. In contrast with muse, it uses the smaller text encoder CLIP-L/14 instead of t5-xxl. Due to its small parameter count and few forward pass generation process, amused can generate many images quickly. This benefit is seen particularly at larger batch sizes.
diff --git a/docs/source/en/api/pipelines/animatediff.md b/docs/source/en/api/pipelines/animatediff.md
index ed5ced7dbbc7..f0188f3c36fb 100644
--- a/docs/source/en/api/pipelines/animatediff.md
+++ b/docs/source/en/api/pipelines/animatediff.md
@@ -1,4 +1,4 @@
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Attend-and-Excite
Attend-and-Excite for Stable Diffusion was proposed in [Attend-and-Excite: Attention-Based Semantic Guidance for Text-to-Image Diffusion Models](https://attendandexcite.github.io/Attend-and-Excite/) and provides textual attention control over image generation.
@@ -20,11 +23,8 @@ The abstract from the paper is:
You can find additional information about Attend-and-Excite on the [project page](https://attendandexcite.github.io/Attend-and-Excite/), the [original codebase](https://github.com/AttendAndExcite/Attend-and-Excite), or try it out in a [demo](https://huggingface.co/spaces/AttendAndExcite/Attend-and-Excite).
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## StableDiffusionAttendAndExcitePipeline
diff --git a/docs/source/en/api/pipelines/audioldm.md b/docs/source/en/api/pipelines/audioldm.md
index 02fe2c779eee..c8073a14ef0a 100644
--- a/docs/source/en/api/pipelines/audioldm.md
+++ b/docs/source/en/api/pipelines/audioldm.md
@@ -1,4 +1,4 @@
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# AudioLDM
AudioLDM was proposed in [AudioLDM: Text-to-Audio Generation with Latent Diffusion Models](https://huggingface.co/papers/2301.12503) by Haohe Liu et al. Inspired by [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview), AudioLDM
@@ -35,11 +38,8 @@ During inference:
* The _quality_ of the predicted audio sample can be controlled by the `num_inference_steps` argument; higher steps give higher quality audio at the expense of slower inference.
* The _length_ of the predicted audio sample can be controlled by varying the `audio_length_in_s` argument.
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## AudioLDMPipeline
[[autodoc]] AudioLDMPipeline
diff --git a/docs/source/en/api/pipelines/audioldm2.md b/docs/source/en/api/pipelines/audioldm2.md
index debd2c3433e4..45a9002ea070 100644
--- a/docs/source/en/api/pipelines/audioldm2.md
+++ b/docs/source/en/api/pipelines/audioldm2.md
@@ -1,4 +1,4 @@
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# BLIP-Diffusion
-BLIP-Diffusion was proposed in [BLIP-Diffusion: Pre-trained Subject Representation for Controllable Text-to-Image Generation and Editing](https://arxiv.org/abs/2305.14720). It enables zero-shot subject-driven generation and control-guided zero-shot generation.
+BLIP-Diffusion was proposed in [BLIP-Diffusion: Pre-trained Subject Representation for Controllable Text-to-Image Generation and Editing](https://huggingface.co/papers/2305.14720). It enables zero-shot subject-driven generation and control-guided zero-shot generation.
The abstract from the paper is:
@@ -23,11 +26,8 @@ The original codebase can be found at [salesforce/LAVIS](https://github.com/sale
`BlipDiffusionPipeline` and `BlipDiffusionControlNetPipeline` were contributed by [`ayushtues`](https://github.com/ayushtues/).
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## BlipDiffusionPipeline
diff --git a/docs/source/en/api/pipelines/bria_3_2.md b/docs/source/en/api/pipelines/bria_3_2.md
new file mode 100644
index 000000000000..059fa01f9f83
--- /dev/null
+++ b/docs/source/en/api/pipelines/bria_3_2.md
@@ -0,0 +1,44 @@
+
+
+# Bria 3.2
+
+Bria 3.2 is the next-generation commercial-ready text-to-image model. With just 4 billion parameters, it provides exceptional aesthetics and text rendering, evaluated to provide on par results to leading open-source models, and outperforming other licensed models.
+In addition to being built entirely on licensed data, 3.2 provides several advantages for enterprise and commercial use:
+
+- Efficient Compute - the model is X3 smaller than the equivalent models in the market (4B parameters vs 12B parameters other open source models)
+- Architecture Consistency: Same architecture as 3.1—ideal for users looking to upgrade without disruption.
+- Fine-tuning Speedup: 2x faster fine-tuning on L40S and A100.
+
+Original model checkpoints for Bria 3.2 can be found [here](https://huggingface.co/briaai/BRIA-3.2).
+Github repo for Bria 3.2 can be found [here](https://github.com/Bria-AI/BRIA-3.2).
+
+If you want to learn more about the Bria platform, and get free traril access, please visit [bria.ai](https://bria.ai).
+
+
+## Usage
+
+_As the model is gated, before using it with diffusers you first need to go to the [Bria 3.2 Hugging Face page](https://huggingface.co/briaai/BRIA-3.2), fill in the form and accept the gate. Once you are in, you need to login so that your system knows you’ve accepted the gate._
+
+Use the command below to log in:
+
+```bash
+hf auth login
+```
+
+
+## BriaPipeline
+
+[[autodoc]] BriaPipeline
+ - all
+ - __call__
+
diff --git a/docs/source/en/api/pipelines/bria_fibo.md b/docs/source/en/api/pipelines/bria_fibo.md
new file mode 100644
index 000000000000..96c6b0317e1b
--- /dev/null
+++ b/docs/source/en/api/pipelines/bria_fibo.md
@@ -0,0 +1,45 @@
+
+
+# Bria Fibo
+
+Text-to-image models have mastered imagination - but not control. FIBO changes that.
+
+FIBO is trained on structured JSON captions up to 1,000+ words and designed to understand and control different visual parameters such as lighting, composition, color, and camera settings, enabling precise and reproducible outputs.
+
+With only 8 billion parameters, FIBO provides a new level of image quality, prompt adherence and proffesional control.
+
+FIBO is trained exclusively on a structured prompt and will not work with freeform text prompts.
+you can use the [FIBO-VLM-prompt-to-JSON](https://huggingface.co/briaai/FIBO-VLM-prompt-to-JSON) model or the [FIBO-gemini-prompt-to-JSON](https://huggingface.co/briaai/FIBO-gemini-prompt-to-JSON) to convert your freeform text prompt to a structured JSON prompt.
+
+> [!NOTE]
+> Avoid using freeform text prompts directly with FIBO because it does not produce the best results.
+
+Refer to the Bria Fibo Hugging Face [page](https://huggingface.co/briaai/FIBO) to learn more.
+
+
+## Usage
+
+_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO), fill in the form and accept the gate. Once you are in, you need to login so that your system knows you’ve accepted the gate._
+
+Use the command below to log in:
+
+```bash
+hf auth login
+```
+
+
+## BriaFiboPipeline
+
+[[autodoc]] BriaFiboPipeline
+ - all
+ - __call__
\ No newline at end of file
diff --git a/docs/source/en/api/pipelines/chroma.md b/docs/source/en/api/pipelines/chroma.md
new file mode 100644
index 000000000000..cc52ffa09a6d
--- /dev/null
+++ b/docs/source/en/api/pipelines/chroma.md
@@ -0,0 +1,101 @@
+
+
+# Chroma
+
+
+
+
+
+
+Chroma is a text to image generation model based on Flux.
+
+Original model checkpoints for Chroma can be found here:
+* High-resolution finetune: [lodestones/Chroma1-HD](https://huggingface.co/lodestones/Chroma1-HD)
+* Base model: [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base)
+* Original repo with progress checkpoints: [lodestones/Chroma](https://huggingface.co/lodestones/Chroma) (loading this repo with `from_pretrained` will load a Diffusers-compatible version of the `unlocked-v37` checkpoint)
+
+> [!TIP]
+> Chroma can use all the same optimizations as Flux.
+
+## Inference
+
+```python
+import torch
+from diffusers import ChromaPipeline
+
+pipe = ChromaPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16)
+pipe.enable_model_cpu_offload()
+
+prompt = [
+ "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
+]
+negative_prompt = ["low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"]
+
+image = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ generator=torch.Generator("cpu").manual_seed(433),
+ num_inference_steps=40,
+ guidance_scale=3.0,
+ num_images_per_prompt=1,
+).images[0]
+image.save("chroma.png")
+```
+
+## Loading from a single file
+
+To use updated model checkpoints that are not in the Diffusers format, you can use the `ChromaTransformer2DModel` class to load the model from a single file in the original format. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.
+
+The following example demonstrates how to run Chroma from a single file.
+
+Then run the following example
+
+```python
+import torch
+from diffusers import ChromaTransformer2DModel, ChromaPipeline
+
+model_id = "lodestones/Chroma1-HD"
+dtype = torch.bfloat16
+
+transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors", torch_dtype=dtype)
+
+pipe = ChromaPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)
+pipe.enable_model_cpu_offload()
+
+prompt = [
+ "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
+]
+negative_prompt = ["low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"]
+
+image = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ generator=torch.Generator("cpu").manual_seed(433),
+ num_inference_steps=40,
+ guidance_scale=3.0,
+).images[0]
+
+image.save("chroma-single-file.png")
+```
+
+## ChromaPipeline
+
+[[autodoc]] ChromaPipeline
+ - all
+ - __call__
+
+## ChromaImg2ImgPipeline
+
+[[autodoc]] ChromaImg2ImgPipeline
+ - all
+ - __call__
diff --git a/docs/source/en/api/pipelines/chronoedit.md b/docs/source/en/api/pipelines/chronoedit.md
new file mode 100644
index 000000000000..48e70ab9e55e
--- /dev/null
+++ b/docs/source/en/api/pipelines/chronoedit.md
@@ -0,0 +1,156 @@
+
+
+
+
+# ChronoEdit
+
+[ChronoEdit: Towards Temporal Reasoning for Image Editing and World Simulation](https://huggingface.co/papers/2510.04290) from NVIDIA and University of Toronto, by Jay Zhangjie Wu, Xuanchi Ren, Tianchang Shen, Tianshi Cao, Kai He, Yifan Lu, Ruiyuan Gao, Enze Xie, Shiyi Lan, Jose M. Alvarez, Jun Gao, Sanja Fidler, Zian Wang, Huan Ling.
+
+> **TL;DR:** ChronoEdit reframes image editing as a video generation task, using input and edited images as start/end frames to leverage pretrained video models with temporal consistency. A temporal reasoning stage introduces reasoning tokens to ensure physically plausible edits and visualize the editing trajectory.
+
+*Recent advances in large generative models have greatly enhanced both image editing and in-context image generation, yet a critical gap remains in ensuring physical consistency, where edited objects must remain coherent. This capability is especially vital for world simulation related tasks. In this paper, we present ChronoEdit, a framework that reframes image editing as a video generation problem. First, ChronoEdit treats the input and edited images as the first and last frames of a video, allowing it to leverage large pretrained video generative models that capture not only object appearance but also the implicit physics of motion and interaction through learned temporal consistency. Second, ChronoEdit introduces a temporal reasoning stage that explicitly performs editing at inference time. Under this setting, target frame is jointly denoised with reasoning tokens to imagine a plausible editing trajectory that constrains the solution space to physically viable transformations. The reasoning tokens are then dropped after a few steps to avoid the high computational cost of rendering a full video. To validate ChronoEdit, we introduce PBench-Edit, a new benchmark of image-prompt pairs for contexts that require physical consistency, and demonstrate that ChronoEdit surpasses state-of-the-art baselines in both visual fidelity and physical plausibility. Project page for code and models: [this https URL](https://research.nvidia.com/labs/toronto-ai/chronoedit).*
+
+The ChronoEdit pipeline is developed by the ChronoEdit Team. The original code is available on [GitHub](https://github.com/nv-tlabs/ChronoEdit), and pretrained models can be found in the [nvidia/ChronoEdit](https://huggingface.co/collections/nvidia/chronoedit) collection on Hugging Face.
+
+
+### Image Editing
+
+```py
+import torch
+import numpy as np
+from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
+from diffusers.utils import export_to_video, load_image
+from transformers import CLIPVisionModel
+from PIL import Image
+
+model_id = "nvidia/ChronoEdit-14B-Diffusers"
+image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
+pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+image = load_image(
+ "https://huggingface.co/spaces/nvidia/ChronoEdit/resolve/main/examples/3.png"
+)
+max_area = 720 * 1280
+aspect_ratio = image.height / image.width
+mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+print("width", width, "height", height)
+image = image.resize((width, height))
+prompt = (
+ "The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cup’s liquid, with gentle steam rising around it to blend with the cozy atmosphere. "
+ "The mouse’s pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacup’s floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood."
+)
+
+output = pipe(
+ image=image,
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_frames=5,
+ num_inference_steps=50,
+ guidance_scale=5.0,
+ enable_temporal_reasoning=False,
+ num_temporal_reasoning_steps=0,
+).frames[0]
+Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
+```
+
+Optionally, enable **temporal reasoning** for improved physical consistency:
+```py
+output = pipe(
+ image=image,
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_frames=29,
+ num_inference_steps=50,
+ guidance_scale=5.0,
+ enable_temporal_reasoning=True,
+ num_temporal_reasoning_steps=50,
+).frames[0]
+export_to_video(output, "output.mp4", fps=16)
+Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
+```
+
+### Inference with 8-Step Distillation Lora
+
+```py
+import torch
+import numpy as np
+from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
+from diffusers.utils import export_to_video, load_image
+from transformers import CLIPVisionModel
+from PIL import Image
+
+model_id = "nvidia/ChronoEdit-14B-Diffusers"
+image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
+pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
+lora_path = hf_hub_download(repo_id=model_id, filename="lora/chronoedit_distill_lora.safetensors")
+pipe.load_lora_weights(lora_path)
+pipe.fuse_lora(lora_scale=1.0)
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=2.0)
+pipe.to("cuda")
+
+image = load_image(
+ "https://huggingface.co/spaces/nvidia/ChronoEdit/resolve/main/examples/3.png"
+)
+max_area = 720 * 1280
+aspect_ratio = image.height / image.width
+mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+print("width", width, "height", height)
+image = image.resize((width, height))
+prompt = (
+ "The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cup’s liquid, with gentle steam rising around it to blend with the cozy atmosphere. "
+ "The mouse’s pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacup’s floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood."
+)
+
+output = pipe(
+ image=image,
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_frames=5,
+ num_inference_steps=8,
+ guidance_scale=1.0,
+ enable_temporal_reasoning=False,
+ num_temporal_reasoning_steps=0,
+).frames[0]
+export_to_video(output, "output.mp4", fps=16)
+Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
+```
+
+## ChronoEditPipeline
+
+[[autodoc]] ChronoEditPipeline
+ - all
+ - __call__
+
+## ChronoEditPipelineOutput
+
+[[autodoc]] pipelines.chronoedit.pipeline_output.ChronoEditPipelineOutput
\ No newline at end of file
diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md
index 0de40f934548..ec673e0763c5 100644
--- a/docs/source/en/api/pipelines/cogvideox.md
+++ b/docs/source/en/api/pipelines/cogvideox.md
@@ -1,4 +1,4 @@
-
-# CogVideoX
-
-
-
+
-[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://arxiv.org/abs/2408.06072) from Tsinghua University & ZhipuAI, by Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, Yuanming Yang, Wenyi Hong, Xiaohan Zhang, Guanyu Feng, Da Yin, Xiaotao Gu, Yuxuan Zhang, Weihan Wang, Yean Cheng, Ting Liu, Bin Xu, Yuxiao Dong, Jie Tang.
-
-The abstract from the paper is:
-
-*We introduce CogVideoX, a large-scale diffusion transformer model designed for generating videos based on text prompts. To efficently model video data, we propose to levearge a 3D Variational Autoencoder (VAE) to compresses videos along both spatial and temporal dimensions. To improve the text-video alignment, we propose an expert transformer with the expert adaptive LayerNorm to facilitate the deep fusion between the two modalities. By employing a progressive training technique, CogVideoX is adept at producing coherent, long-duration videos characterized by significant motion. In addition, we develop an effectively text-video data processing pipeline that includes various data preprocessing strategies and a video captioning method. It significantly helps enhance the performance of CogVideoX, improving both generation quality and semantic alignment. Results show that CogVideoX demonstrates state-of-the-art performance across both multiple machine metrics and human evaluations. The model weight of CogVideoX-2B is publicly available at https://github.com/THUDM/CogVideo.*
+# CogVideoX
-
+[CogVideoX](https://huggingface.co/papers/2408.06072) is a large diffusion transformer model - available in 2B and 5B parameters - designed to generate longer and more consistent videos from text. This model uses a 3D causal variational autoencoder to more efficiently process video data by reducing sequence length (and associated training compute) and preventing flickering in generated videos. An "expert" transformer with adaptive LayerNorm improves alignment between text and video, and 3D full attention helps accurately capture motion and time in generated videos.
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+You can find all the original CogVideoX checkpoints under the [CogVideoX](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce) collection.
-
+> [!TIP]
+> Click on the CogVideoX models in the right sidebar for more examples of other video generation tasks.
-This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
+The example below demonstrates how to generate a video optimized for memory or inference speed.
-There are three official CogVideoX checkpoints for text-to-video and video-to-video.
+
+
-| checkpoints | recommended inference dtype |
-|:---:|:---:|
-| [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b) | torch.float16 |
-| [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b) | torch.bfloat16 |
-| [`THUDM/CogVideoX1.5-5b`](https://huggingface.co/THUDM/CogVideoX1.5-5b) | torch.bfloat16 |
+Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.
-There are two official CogVideoX checkpoints available for image-to-video.
+The quantized CogVideoX 5B model below requires ~16GB of VRAM.
-| checkpoints | recommended inference dtype |
-|:---:|:---:|
-| [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V) | torch.bfloat16 |
-| [`THUDM/CogVideoX-1.5-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-1.5-5b-I2V) | torch.bfloat16 |
+```py
+import torch
+from diffusers import CogVideoXPipeline, AutoModel
+from diffusers.quantizers import PipelineQuantizationConfig
+from diffusers.hooks import apply_group_offloading
+from diffusers.utils import export_to_video
-For the CogVideoX 1.5 series:
-- Text-to-video (T2V) works best at a resolution of 1360x768 because it was trained with that specific resolution.
-- Image-to-video (I2V) works for multiple resolutions. The width can vary from 768 to 1360, but the height must be 768. The height/width must be divisible by 16.
-- Both T2V and I2V models support generation with 81 and 161 frames and work best at this value. Exporting videos at 16 FPS is recommended.
+# quantize weights to int8 with torchao
+pipeline_quant_config = PipelineQuantizationConfig(
+ quant_backend="torchao",
+ quant_kwargs={"quant_type": "int8wo"},
+ components_to_quantize="transformer"
+)
-There are two official CogVideoX checkpoints that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team).
+# fp8 layerwise weight-casting
+transformer = AutoModel.from_pretrained(
+ "THUDM/CogVideoX-5b",
+ subfolder="transformer",
+ torch_dtype=torch.bfloat16
+)
+transformer.enable_layerwise_casting(
+ storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
+)
-| checkpoints | recommended inference dtype |
-|:---:|:---:|
-| [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose) | torch.bfloat16 |
-| [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose) | torch.bfloat16 |
+pipeline = CogVideoXPipeline.from_pretrained(
+ "THUDM/CogVideoX-5b",
+ transformer=transformer,
+ quantization_config=pipeline_quant_config,
+ torch_dtype=torch.bfloat16
+)
+pipeline.to("cuda")
+
+# model-offloading
+pipeline.enable_model_cpu_offload()
+
+prompt = """
+A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea.
+The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse.
+Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood,
+with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.
+"""
+
+video = pipeline(
+ prompt=prompt,
+ guidance_scale=6,
+ num_inference_steps=50
+).frames[0]
+export_to_video(video, "output.mp4", fps=8)
+```
-## Inference
+
+
-Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
+[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster.
-First, load the pipeline:
+The average inference time with torch.compile on a 80GB A100 is 76.27 seconds compared to 96.89 seconds for an uncompiled model.
-```python
+```py
import torch
-from diffusers import CogVideoXPipeline, CogVideoXImageToVideoPipeline
-from diffusers.utils import export_to_video,load_image
-pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b").to("cuda") # or "THUDM/CogVideoX-2b"
-```
-
-If you are using the image-to-video pipeline, load it as follows:
+from diffusers import CogVideoXPipeline
+from diffusers.utils import export_to_video
-```python
-pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V").to("cuda")
-```
+pipeline = CogVideoXPipeline.from_pretrained(
+ "THUDM/CogVideoX-2b",
+ torch_dtype=torch.float16
+).to("cuda")
-Then change the memory layout of the pipelines `transformer` component to `torch.channels_last`:
+# torch.compile
+pipeline.transformer.to(memory_format=torch.channels_last)
+pipeline.transformer = torch.compile(
+ pipeline.transformer, mode="max-autotune", fullgraph=True
+)
-```python
-pipe.transformer.to(memory_format=torch.channels_last)
+prompt = """
+A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea.
+The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse.
+Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood,
+with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting.
+"""
+
+video = pipeline(
+ prompt=prompt,
+ guidance_scale=6,
+ num_inference_steps=50
+).frames[0]
+export_to_video(video, "output.mp4", fps=8)
```
-Compile the components and run inference:
+
+
-```python
-pipe.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
+## Notes
-# CogVideoX works well with long and well-described prompts
-prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
-video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
-```
+- CogVideoX supports LoRAs with [`~loaders.CogVideoXLoraLoaderMixin.load_lora_weights`].
-The [T2V benchmark](https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f) results on an 80GB A100 machine are:
+
+ Show example code
-```
-Without torch.compile(): Average inference time: 96.89 seconds.
-With torch.compile(): Average inference time: 76.27 seconds.
-```
+ ```py
+ import torch
+ from diffusers import CogVideoXPipeline
+ from diffusers.hooks import apply_group_offloading
+ from diffusers.utils import export_to_video
-### Memory optimization
+ pipeline = CogVideoXPipeline.from_pretrained(
+ "THUDM/CogVideoX-5b",
+ torch_dtype=torch.bfloat16
+ )
+ pipeline.to("cuda")
-CogVideoX-2b requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script.
+ # load LoRA weights
+ pipeline.load_lora_weights("finetrainers/CogVideoX-1.5-crush-smol-v0", adapter_name="crush-lora")
+ pipeline.set_adapters("crush-lora", 0.9)
-- `pipe.enable_model_cpu_offload()`:
- - Without enabling cpu offloading, memory usage is `33 GB`
- - With enabling cpu offloading, memory usage is `19 GB`
-- `pipe.enable_sequential_cpu_offload()`:
- - Similar to `enable_model_cpu_offload` but can significantly reduce memory usage at the cost of slow inference
- - When enabled, memory usage is under `4 GB`
-- `pipe.vae.enable_tiling()`:
- - With enabling cpu offloading and tiling, memory usage is `11 GB`
-- `pipe.vae.enable_slicing()`
+ # model-offloading
+ pipeline.enable_model_cpu_offload()
-## Quantization
+ prompt = """
+ PIKA_CRUSH A large metal cylinder is seen pressing down on a pile of Oreo cookies, flattening them as if they were under a hydraulic press.
+ """
+ negative_prompt = "inconsistent motion, blurry motion, worse quality, degenerate outputs, deformed outputs"
-Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
+ video = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_frames=81,
+ height=480,
+ width=768,
+ num_inference_steps=50
+ ).frames[0]
+ export_to_video(video, "output.mp4", fps=16)
+ ```
-Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`CogVideoXPipeline`] for inference with bitsandbytes.
+
-```py
-import torch
-from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, CogVideoXTransformer3DModel, CogVideoXPipeline
-from diffusers.utils import export_to_video
-from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
+- The text-to-video (T2V) checkpoints work best with a resolution of 1360x768 because that was the resolution it was pretrained on.
-quant_config = BitsAndBytesConfig(load_in_8bit=True)
-text_encoder_8bit = T5EncoderModel.from_pretrained(
- "THUDM/CogVideoX-2b",
- subfolder="text_encoder",
- quantization_config=quant_config,
- torch_dtype=torch.float16,
-)
+- The image-to-video (I2V) checkpoints work with multiple resolutions. The width can vary from 768 to 1360, but the height must be 758. Both height and width must be divisible by 16.
-quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
-transformer_8bit = CogVideoXTransformer3DModel.from_pretrained(
- "THUDM/CogVideoX-2b",
- subfolder="transformer",
- quantization_config=quant_config,
- torch_dtype=torch.float16,
-)
+- Both T2V and I2V checkpoints work best with 81 and 161 frames. It is recommended to export the generated video at 16fps.
-pipeline = CogVideoXPipeline.from_pretrained(
- "THUDM/CogVideoX-2b",
- text_encoder=text_encoder_8bit,
- transformer=transformer_8bit,
- torch_dtype=torch.float16,
- device_map="balanced",
-)
-
-prompt = "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting."
-video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
-export_to_video(video, "ship.mp4", fps=8)
-```
+- Refer to the table below to view memory usage when various memory-saving techniques are enabled.
+ | method | memory usage (enabled) | memory usage (disabled) |
+ |---|---|---|
+ | enable_model_cpu_offload | 19GB | 33GB |
+ | enable_sequential_cpu_offload | <4GB | ~33GB (very slow inference speed) |
+ | enable_tiling | 11GB (with enable_model_cpu_offload) | --- |
+
## CogVideoXPipeline
[[autodoc]] CogVideoXPipeline
diff --git a/docs/source/en/api/pipelines/cogview3.md b/docs/source/en/api/pipelines/cogview3.md
index 277edca4cf33..5ee02e1a7039 100644
--- a/docs/source/en/api/pipelines/cogview3.md
+++ b/docs/source/en/api/pipelines/cogview3.md
@@ -1,4 +1,4 @@
-
+
+# ControlNet
+
+
+
+
+
+ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
+
+With a ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
+
+The abstract from the paper is:
+
+*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
+
+This pipeline was contributed by [ishan24](https://huggingface.co/ishan24). ❤️
+The original codebase can be found at [NVlabs/Sana](https://github.com/NVlabs/Sana), and you can find official ControlNet checkpoints on [Efficient-Large-Model's](https://huggingface.co/Efficient-Large-Model) Hub profile.
+
+## SanaControlNetPipeline
+[[autodoc]] SanaControlNetPipeline
+ - all
+ - __call__
+
+## SanaPipelineOutput
+[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput
\ No newline at end of file
diff --git a/docs/source/en/api/pipelines/controlnet_sd3.md b/docs/source/en/api/pipelines/controlnet_sd3.md
index cee52ef5d76e..8cdada9edf43 100644
--- a/docs/source/en/api/pipelines/controlnet_sd3.md
+++ b/docs/source/en/api/pipelines/controlnet_sd3.md
@@ -1,4 +1,4 @@
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# ControlNet-XS
@@ -28,11 +31,8 @@ Here's the overview from the [project page](https://vislearn.github.io/ControlNe
This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## StableDiffusionControlNetXSPipeline
[[autodoc]] StableDiffusionControlNetXSPipeline
diff --git a/docs/source/en/api/pipelines/controlnetxs_sdxl.md b/docs/source/en/api/pipelines/controlnetxs_sdxl.md
index 0862a5d79878..7ae0e2a2a178 100644
--- a/docs/source/en/api/pipelines/controlnetxs_sdxl.md
+++ b/docs/source/en/api/pipelines/controlnetxs_sdxl.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# ControlNet-XS with Stable Diffusion XL
ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
@@ -24,17 +27,11 @@ Here's the overview from the [project page](https://vislearn.github.io/ControlNe
This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
-
-
-🧪 Many of the SDXL ControlNet checkpoints are experimental, and there is a lot of room for improvement. Feel free to open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and leave us feedback on how we can improve!
-
-
-
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+> [!WARNING]
+> 🧪 Many of the SDXL ControlNet checkpoints are experimental, and there is a lot of room for improvement. Feel free to open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and leave us feedback on how we can improve!
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## StableDiffusionXLControlNetXSPipeline
[[autodoc]] StableDiffusionXLControlNetXSPipeline
diff --git a/docs/source/en/api/pipelines/cosmos.md b/docs/source/en/api/pipelines/cosmos.md
new file mode 100644
index 000000000000..fb9453480e74
--- /dev/null
+++ b/docs/source/en/api/pipelines/cosmos.md
@@ -0,0 +1,79 @@
+
+
+# Cosmos
+
+[Cosmos World Foundation Model Platform for Physical AI](https://huggingface.co/papers/2501.03575) by NVIDIA.
+
+*Physical AI needs to be trained digitally first. It needs a digital twin of itself, the policy model, and a digital twin of the world, the world model. In this paper, we present the Cosmos World Foundation Model Platform to help developers build customized world models for their Physical AI setups. We position a world foundation model as a general-purpose world model that can be fine-tuned into customized world models for downstream applications. Our platform covers a video curation pipeline, pre-trained world foundation models, examples of post-training of pre-trained world foundation models, and video tokenizers. To help Physical AI builders solve the most critical problems of our society, we make our platform open-source and our models open-weight with permissive licenses available via https://github.com/NVIDIA/Cosmos.*
+
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+## Loading original format checkpoints
+
+Original format checkpoints that have not been converted to diffusers-expected format can be loaded using the `from_single_file` method.
+
+```python
+import torch
+from diffusers import Cosmos2TextToImagePipeline, CosmosTransformer3DModel
+
+model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
+transformer = CosmosTransformer3DModel.from_single_file(
+ "https://huggingface.co/nvidia/Cosmos-Predict2-2B-Text2Image/blob/main/model.pt",
+ torch_dtype=torch.bfloat16,
+).to("cuda")
+pipe = Cosmos2TextToImagePipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
+negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
+
+output = pipe(
+ prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
+).images[0]
+output.save("output.png")
+```
+
+## CosmosTextToWorldPipeline
+
+[[autodoc]] CosmosTextToWorldPipeline
+ - all
+ - __call__
+
+## CosmosVideoToWorldPipeline
+
+[[autodoc]] CosmosVideoToWorldPipeline
+ - all
+ - __call__
+
+## Cosmos2TextToImagePipeline
+
+[[autodoc]] Cosmos2TextToImagePipeline
+ - all
+ - __call__
+
+## Cosmos2VideoToWorldPipeline
+
+[[autodoc]] Cosmos2VideoToWorldPipeline
+ - all
+ - __call__
+
+## CosmosPipelineOutput
+
+[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
+
+## CosmosImagePipelineOutput
+
+[[autodoc]] pipelines.cosmos.pipeline_output.CosmosImagePipelineOutput
diff --git a/docs/source/en/api/pipelines/dance_diffusion.md b/docs/source/en/api/pipelines/dance_diffusion.md
index 9b6e7b66e198..0434f6319592 100644
--- a/docs/source/en/api/pipelines/dance_diffusion.md
+++ b/docs/source/en/api/pipelines/dance_diffusion.md
@@ -1,4 +1,4 @@
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Dance Diffusion
[Dance Diffusion](https://github.com/Harmonai-org/sample-generator) is by Zach Evans.
@@ -17,11 +20,8 @@ specific language governing permissions and limitations under the License.
Dance Diffusion is the first in a suite of generative audio tools for producers and musicians released by [Harmonai](https://github.com/Harmonai-org).
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## DanceDiffusionPipeline
[[autodoc]] DanceDiffusionPipeline
diff --git a/docs/source/en/api/pipelines/ddim.md b/docs/source/en/api/pipelines/ddim.md
index 6802da739cd5..3e8cbae4fb60 100644
--- a/docs/source/en/api/pipelines/ddim.md
+++ b/docs/source/en/api/pipelines/ddim.md
@@ -1,4 +1,4 @@
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# DiffEdit
[DiffEdit: Diffusion-based semantic image editing with mask guidance](https://huggingface.co/papers/2210.11427) is by Guillaume Couairon, Jakob Verbeek, Holger Schwenk, and Matthieu Cord.
diff --git a/docs/source/en/api/pipelines/dit.md b/docs/source/en/api/pipelines/dit.md
index 2ee45b631c77..16d0c999619d 100644
--- a/docs/source/en/api/pipelines/dit.md
+++ b/docs/source/en/api/pipelines/dit.md
@@ -1,4 +1,4 @@
-
+
+# Flux2
+
+
+
+
+
+
+Flux.2 is the recent series of image generation models from Black Forest Labs, preceded by the [Flux.1](./flux.md) series. It is an entirely new model with a new architecture and pre-training done from scratch!
+
+Original model checkpoints for Flux can be found [here](https://huggingface.co/black-forest-labs). Original inference code can be found [here](https://github.com/black-forest-labs/flux2).
+
+> [!TIP]
+> Flux2 can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more.
+>
+> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
+
+## Caption upsampling
+
+Flux.2 can potentially generate better better outputs with better prompts. We can "upsample"
+an input prompt by setting the `caption_upsample_temperature` argument in the pipeline call arguments.
+The [official implementation](https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L140) recommends this value to be 0.15.
+
+## Flux2Pipeline
+
+[[autodoc]] Flux2Pipeline
+ - all
+ - __call__
\ No newline at end of file
diff --git a/docs/source/en/api/pipelines/framepack.md b/docs/source/en/api/pipelines/framepack.md
new file mode 100644
index 000000000000..a25cfe24a4ba
--- /dev/null
+++ b/docs/source/en/api/pipelines/framepack.md
@@ -0,0 +1,206 @@
+
+
+# Framepack
+
+
+
+
+
+[Packing Input Frame Context in Next-Frame Prediction Models for Video Generation](https://huggingface.co/papers/2504.12626) by Lvmin Zhang and Maneesh Agrawala.
+
+*We present a neural network structure, FramePack, to train next-frame (or next-frame-section) prediction models for video generation. The FramePack compresses input frames to make the transformer context length a fixed number regardless of the video length. As a result, we are able to process a large number of frames using video diffusion with computation bottleneck similar to image diffusion. This also makes the training video batch sizes significantly higher (batch sizes become comparable to image diffusion training). We also propose an anti-drifting sampling method that generates frames in inverted temporal order with early-established endpoints to avoid exposure bias (error accumulation over iterations). Finally, we show that existing video diffusion models can be finetuned with FramePack, and their visual quality may be improved because the next-frame prediction supports more balanced diffusion schedulers with less extreme flow shift timesteps.*
+
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+## Available models
+
+| Model name | Description |
+|:---|:---|
+- [`lllyasviel/FramePackI2V_HY`](https://huggingface.co/lllyasviel/FramePackI2V_HY) | Trained with the "inverted anti-drifting" strategy as described in the paper. Inference requires setting `sampling_type="inverted_anti_drifting"` when running the pipeline. |
+- [`lllyasviel/FramePack_F1_I2V_HY_20250503`](https://huggingface.co/lllyasviel/FramePack_F1_I2V_HY_20250503) | Trained with a novel anti-drifting strategy but inference is performed in "vanilla" strategy as described in the paper. Inference requires setting `sampling_type="vanilla"` when running the pipeline. |
+
+## Usage
+
+Refer to the pipeline documentation for basic usage examples. The following section contains examples of offloading, different sampling methods, quantization, and more.
+
+### First and last frame to video
+
+The following example shows how to use Framepack with start and end image controls, using the inverted anti-drifiting sampling model.
+
+```python
+import torch
+from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel
+from diffusers.utils import export_to_video, load_image
+from transformers import SiglipImageProcessor, SiglipVisionModel
+
+transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained(
+ "lllyasviel/FramePackI2V_HY", torch_dtype=torch.bfloat16
+)
+feature_extractor = SiglipImageProcessor.from_pretrained(
+ "lllyasviel/flux_redux_bfl", subfolder="feature_extractor"
+)
+image_encoder = SiglipVisionModel.from_pretrained(
+ "lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16
+)
+pipe = HunyuanVideoFramepackPipeline.from_pretrained(
+ "hunyuanvideo-community/HunyuanVideo",
+ transformer=transformer,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ torch_dtype=torch.float16,
+)
+
+# Enable memory optimizations
+pipe.enable_model_cpu_offload()
+pipe.vae.enable_tiling()
+
+prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
+first_image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png"
+)
+last_image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png"
+)
+output = pipe(
+ image=first_image,
+ last_image=last_image,
+ prompt=prompt,
+ height=512,
+ width=512,
+ num_frames=91,
+ num_inference_steps=30,
+ guidance_scale=9.0,
+ generator=torch.Generator().manual_seed(0),
+ sampling_type="inverted_anti_drifting",
+).frames[0]
+export_to_video(output, "output.mp4", fps=30)
+```
+
+### Vanilla sampling
+
+The following example shows how to use Framepack with the F1 model trained with vanilla sampling but new regulation approach for anti-drifting.
+
+```python
+import torch
+from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel
+from diffusers.utils import export_to_video, load_image
+from transformers import SiglipImageProcessor, SiglipVisionModel
+
+transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained(
+ "lllyasviel/FramePack_F1_I2V_HY_20250503", torch_dtype=torch.bfloat16
+)
+feature_extractor = SiglipImageProcessor.from_pretrained(
+ "lllyasviel/flux_redux_bfl", subfolder="feature_extractor"
+)
+image_encoder = SiglipVisionModel.from_pretrained(
+ "lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16
+)
+pipe = HunyuanVideoFramepackPipeline.from_pretrained(
+ "hunyuanvideo-community/HunyuanVideo",
+ transformer=transformer,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ torch_dtype=torch.float16,
+)
+
+# Enable memory optimizations
+pipe.enable_model_cpu_offload()
+pipe.vae.enable_tiling()
+
+image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
+)
+output = pipe(
+ image=image,
+ prompt="A penguin dancing in the snow",
+ height=832,
+ width=480,
+ num_frames=91,
+ num_inference_steps=30,
+ guidance_scale=9.0,
+ generator=torch.Generator().manual_seed(0),
+ sampling_type="vanilla",
+).frames[0]
+export_to_video(output, "output.mp4", fps=30)
+```
+
+### Group offloading
+
+Group offloading ([`~hooks.apply_group_offloading`]) provides aggressive memory optimizations for offloading internal parts of any model to the CPU, with possibly no additional overhead to generation time. If you have very low VRAM available, this approach may be suitable for you depending on the amount of CPU RAM available.
+
+```python
+import torch
+from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel
+from diffusers.hooks import apply_group_offloading
+from diffusers.utils import export_to_video, load_image
+from transformers import SiglipImageProcessor, SiglipVisionModel
+
+transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained(
+ "lllyasviel/FramePack_F1_I2V_HY_20250503", torch_dtype=torch.bfloat16
+)
+feature_extractor = SiglipImageProcessor.from_pretrained(
+ "lllyasviel/flux_redux_bfl", subfolder="feature_extractor"
+)
+image_encoder = SiglipVisionModel.from_pretrained(
+ "lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16
+)
+pipe = HunyuanVideoFramepackPipeline.from_pretrained(
+ "hunyuanvideo-community/HunyuanVideo",
+ transformer=transformer,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ torch_dtype=torch.float16,
+)
+
+# Enable group offloading
+onload_device = torch.device("cuda")
+offload_device = torch.device("cpu")
+list(map(
+ lambda x: apply_group_offloading(x, onload_device, offload_device, offload_type="leaf_level", use_stream=True, low_cpu_mem_usage=True),
+ [pipe.text_encoder, pipe.text_encoder_2, pipe.transformer]
+))
+pipe.image_encoder.to(onload_device)
+pipe.vae.to(onload_device)
+pipe.vae.enable_tiling()
+
+image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
+)
+output = pipe(
+ image=image,
+ prompt="A penguin dancing in the snow",
+ height=832,
+ width=480,
+ num_frames=91,
+ num_inference_steps=30,
+ guidance_scale=9.0,
+ generator=torch.Generator().manual_seed(0),
+ sampling_type="vanilla",
+).frames[0]
+print(f"Max memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")
+export_to_video(output, "output.mp4", fps=30)
+```
+
+## HunyuanVideoFramepackPipeline
+
+[[autodoc]] HunyuanVideoFramepackPipeline
+ - all
+ - __call__
+
+## HunyuanVideoPipelineOutput
+
+[[autodoc]] pipelines.hunyuan_video.pipeline_output.HunyuanVideoPipelineOutput
+
diff --git a/docs/source/en/api/pipelines/hidream.md b/docs/source/en/api/pipelines/hidream.md
new file mode 100644
index 000000000000..add4ad313231
--- /dev/null
+++ b/docs/source/en/api/pipelines/hidream.md
@@ -0,0 +1,40 @@
+
+
+# HiDreamImage
+
+[HiDream-I1](https://huggingface.co/HiDream-ai) by HiDream.ai
+
+> [!TIP]
+> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
+
+## Available models
+
+The following models are available for the [`HiDreamImagePipeline`] pipeline:
+
+| Model name | Description |
+|:---|:---|
+| [`HiDream-ai/HiDream-I1-Full`](https://huggingface.co/HiDream-ai/HiDream-I1-Full) | - |
+| [`HiDream-ai/HiDream-I1-Dev`](https://huggingface.co/HiDream-ai/HiDream-I1-Dev) | - |
+| [`HiDream-ai/HiDream-I1-Fast`](https://huggingface.co/HiDream-ai/HiDream-I1-Fast) | - |
+
+## HiDreamImagePipeline
+
+[[autodoc]] HiDreamImagePipeline
+ - all
+ - __call__
+
+## HiDreamImagePipelineOutput
+
+[[autodoc]] pipelines.hidream_image.pipeline_output.HiDreamImagePipelineOutput
diff --git a/docs/source/en/api/pipelines/hunyuan_video.md b/docs/source/en/api/pipelines/hunyuan_video.md
index 5d068c8b6ef8..cdd81495b621 100644
--- a/docs/source/en/api/pipelines/hunyuan_video.md
+++ b/docs/source/en/api/pipelines/hunyuan_video.md
@@ -1,4 +1,4 @@
-
-# HunyuanVideo
-
-
-
+
-[HunyuanVideo](https://www.arxiv.org/abs/2412.03603) by Tencent.
+# HunyuanVideo
-*Recent advancements in video generation have significantly impacted daily life for both individuals and industries. However, the leading video generation models remain closed-source, resulting in a notable performance gap between industry capabilities and those available to the public. In this report, we introduce HunyuanVideo, an innovative open-source video foundation model that demonstrates performance in video generation comparable to, or even surpassing, that of leading closed-source models. HunyuanVideo encompasses a comprehensive framework that integrates several key elements, including data curation, advanced architectural design, progressive model scaling and training, and an efficient infrastructure tailored for large-scale model training and inference. As a result, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models. We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion dynamics, text-video alignment, and advanced filming techniques. According to evaluations by professionals, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and three top-performing Chinese video generative models. By releasing the code for the foundation model and its applications, we aim to bridge the gap between closed-source and open-source communities. This initiative will empower individuals within the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem. The code is publicly available at [this https URL](https://github.com/tencent/HunyuanVideo).*
+[HunyuanVideo](https://huggingface.co/papers/2412.03603) is a 13B parameter diffusion transformer model designed to be competitive with closed-source video foundation models and enable wider community access. This model uses a "dual-stream to single-stream" architecture to separately process the video and text tokens first, before concatenating and feeding them to the transformer to fuse the multimodal information. A pretrained multimodal large language model (MLLM) is used as the encoder because it has better image-text alignment, better image detail description and reasoning, and it can be used as a zero-shot learner if system instructions are added to user prompts. Finally, HunyuanVideo uses a 3D causal variational autoencoder to more efficiently process video data at the original resolution and frame rate.
-
+You can find all the original HunyuanVideo checkpoints under the [Tencent](https://huggingface.co/tencent) organization.
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+> [!TIP]
+> Click on the HunyuanVideo models in the right sidebar for more examples of video generation tasks.
+>
+> The examples below use a checkpoint from [hunyuanvideo-community](https://huggingface.co/hunyuanvideo-community) because the weights are stored in a layout compatible with Diffusers.
-
+The example below demonstrates how to generate a video optimized for memory or inference speed.
-Recommendations for inference:
-- Both text encoders should be in `torch.float16`.
-- Transformer should be in `torch.bfloat16`.
-- VAE should be in `torch.float16`.
-- `num_frames` should be of the form `4 * k + 1`, for example `49` or `129`.
-- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
-- For more information about supported resolutions and other details, please refer to the original repository [here](https://github.com/Tencent/HunyuanVideo/).
+
+
-## Available models
+Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.
-The following models are available for the [`HunyuanVideoPipeline`](text-to-video) pipeline:
+The quantized HunyuanVideo model below requires ~14GB of VRAM.
-| Model name | Description |
-|:---|:---|
-| [`hunyuanvideo-community/HunyuanVideo`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo) | Official HunyuanVideo (guidance-distilled). Performs best at multiple resolutions and frames. Performs best with `guidance_scale=6.0`, `true_cfg_scale=1.0` and without a negative prompt. |
-| [`https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-T2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-T2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
+```py
+import torch
+from diffusers import AutoModel, HunyuanVideoPipeline
+from diffusers.quantizers import PipelineQuantizationConfig
+from diffusers.utils import export_to_video
-The following models are available for the image-to-video pipeline:
+# quantize weights to int4 with bitsandbytes
+pipeline_quant_config = PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_4bit",
+ quant_kwargs={
+ "load_in_4bit": True,
+ "bnb_4bit_quant_type": "nf4",
+ "bnb_4bit_compute_dtype": torch.bfloat16
+ },
+ components_to_quantize="transformer"
+)
-| Model name | Description |
-|:---|:---|
-| [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution. Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
-| [`hunyuanvideo-community/HunyuanVideo-I2V-33ch`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 33-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20). |
-| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 16-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) |
+pipeline = HunyuanVideoPipeline.from_pretrained(
+ "hunyuanvideo-community/HunyuanVideo",
+ quantization_config=pipeline_quant_config,
+ torch_dtype=torch.bfloat16,
+)
-## Quantization
+# model-offloading and tiling
+pipeline.enable_model_cpu_offload()
+pipeline.vae.enable_tiling()
-Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
+prompt = "A fluffy teddy bear sits on a bed of soft pillows surrounded by children's toys."
+video = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]
+export_to_video(video, "output.mp4", fps=15)
+```
+
+
+
-Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`HunyuanVideoPipeline`] for inference with bitsandbytes.
+[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster.
```py
import torch
-from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, HunyuanVideoTransformer3DModel, HunyuanVideoPipeline
+from diffusers import AutoModel, HunyuanVideoPipeline
+from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import export_to_video
-quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
-transformer_8bit = HunyuanVideoTransformer3DModel.from_pretrained(
- "hunyuanvideo-community/HunyuanVideo",
- subfolder="transformer",
- quantization_config=quant_config,
- torch_dtype=torch.bfloat16,
+# quantize weights to int4 with bitsandbytes
+pipeline_quant_config = PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_4bit",
+ quant_kwargs={
+ "load_in_4bit": True,
+ "bnb_4bit_quant_type": "nf4",
+ "bnb_4bit_compute_dtype": torch.bfloat16
+ },
+ components_to_quantize="transformer"
)
pipeline = HunyuanVideoPipeline.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
- transformer=transformer_8bit,
- torch_dtype=torch.float16,
- device_map="balanced",
+ quantization_config=pipeline_quant_config,
+ torch_dtype=torch.bfloat16,
)
-prompt = "A cat walks on the grass, realistic style."
+# model-offloading and tiling
+pipeline.enable_model_cpu_offload()
+pipeline.vae.enable_tiling()
+
+# torch.compile
+pipeline.transformer.to(memory_format=torch.channels_last)
+pipeline.transformer = torch.compile(
+ pipeline.transformer, mode="max-autotune", fullgraph=True
+)
+
+prompt = "A fluffy teddy bear sits on a bed of soft pillows surrounded by children's toys."
video = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]
-export_to_video(video, "cat.mp4", fps=15)
+export_to_video(video, "output.mp4", fps=15)
```
+
+
+
+## Notes
+
+- HunyuanVideo supports LoRAs with [`~loaders.HunyuanVideoLoraLoaderMixin.load_lora_weights`].
+
+
+ Show example code
+
+ ```py
+ import torch
+ from diffusers import AutoModel, HunyuanVideoPipeline
+ from diffusers.quantizers import PipelineQuantizationConfig
+ from diffusers.utils import export_to_video
+
+ # quantize weights to int4 with bitsandbytes
+ pipeline_quant_config = PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_4bit",
+ quant_kwargs={
+ "load_in_4bit": True,
+ "bnb_4bit_quant_type": "nf4",
+ "bnb_4bit_compute_dtype": torch.bfloat16
+ },
+ components_to_quantize="transformer"
+ )
+
+ pipeline = HunyuanVideoPipeline.from_pretrained(
+ "hunyuanvideo-community/HunyuanVideo",
+ quantization_config=pipeline_quant_config,
+ torch_dtype=torch.bfloat16,
+ )
+
+ # load LoRA weights
+ pipeline.load_lora_weights("https://huggingface.co/lucataco/hunyuan-steamboat-willie-10", adapter_name="steamboat-willie")
+ pipeline.set_adapters("steamboat-willie", 0.9)
+
+ # model-offloading and tiling
+ pipeline.enable_model_cpu_offload()
+ pipeline.vae.enable_tiling()
+
+ # use "In the style of SWR" to trigger the LoRA
+ prompt = """
+ In the style of SWR. A black and white animated scene featuring a fluffy teddy bear sits on a bed of soft pillows surrounded by children's toys.
+ """
+ video = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]
+ export_to_video(video, "output.mp4", fps=15)
+ ```
+
+
+
+- Refer to the table below for recommended inference values.
+
+ | parameter | recommended value |
+ |---|---|
+ | text encoder dtype | `torch.float16` |
+ | transformer dtype | `torch.bfloat16` |
+ | vae dtype | `torch.float16` |
+ | `num_frames (k)` | 4 * `k` + 1 |
+
+- Try lower `shift` values (`2.0` to `5.0`) for lower resolution videos and higher `shift` values (`7.0` to `12.0`) for higher resolution images.
+
## HunyuanVideoPipeline
[[autodoc]] HunyuanVideoPipeline
diff --git a/docs/source/en/api/pipelines/hunyuan_video15.md b/docs/source/en/api/pipelines/hunyuan_video15.md
new file mode 100644
index 000000000000..d77e72bb0f71
--- /dev/null
+++ b/docs/source/en/api/pipelines/hunyuan_video15.md
@@ -0,0 +1,120 @@
+
+
+
+# HunyuanVideo-1.5
+
+HunyuanVideo-1.5 is a lightweight yet powerful video generation model that achieves state-of-the-art visual quality and motion coherence with only 8.3 billion parameters, enabling efficient inference on consumer-grade GPUs. This achievement is built upon several key components, including meticulous data curation, an advanced DiT architecture with selective and sliding tile attention (SSTA), enhanced bilingual understanding through glyph-aware text encoding, progressive pre-training and post-training, and an efficient video super-resolution network. Leveraging these designs, we developed a unified framework capable of high-quality text-to-video and image-to-video generation across multiple durations and resolutions. Extensive experiments demonstrate that this compact and proficient model establishes a new state-of-the-art among open-source models.
+
+You can find all the original HunyuanVideo checkpoints under the [Tencent](https://huggingface.co/tencent) organization.
+
+> [!TIP]
+> Click on the HunyuanVideo models in the right sidebar for more examples of video generation tasks.
+>
+> The examples below use a checkpoint from [hunyuanvideo-community](https://huggingface.co/hunyuanvideo-community) because the weights are stored in a layout compatible with Diffusers.
+
+The example below demonstrates how to generate a video optimized for memory or inference speed.
+
+
+
+
+Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.
+
+
+```py
+import torch
+from diffusers import AutoModel, HunyuanVideo15Pipeline
+from diffusers.utils import export_to_video
+
+
+pipeline = HunyuanVideo15Pipeline.from_pretrained(
+ "HunyuanVideo-1.5-Diffusers-480p_t2v",
+ torch_dtype=torch.bfloat16,
+)
+
+# model-offloading and tiling
+pipeline.enable_model_cpu_offload()
+pipeline.vae.enable_tiling()
+
+prompt = "A fluffy teddy bear sits on a bed of soft pillows surrounded by children's toys."
+video = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]
+export_to_video(video, "output.mp4", fps=15)
+```
+
+## Notes
+
+- HunyuanVideo1.5 use attention masks with variable-length sequences. For best performance, we recommend using an attention backend that handles padding efficiently.
+
+ - **H100/H800:** `_flash_3_hub` or `_flash_3_varlen_hub`
+ - **A100/A800/RTX 4090:** `flash_hub` or `flash_varlen_hub`
+ - **Other GPUs:** `sage_hub`
+
+Refer to the [Attention backends](../../optimization/attention_backends) guide for more details about using a different backend.
+
+
+```py
+pipe.transformer.set_attention_backend("flash_hub") # or your preferred backend
+```
+
+- [`HunyuanVideo15Pipeline`] use guider and does not take `guidance_scale` parameter at runtime.
+
+You can check the default guider configuration using `pipe.guider`:
+
+```py
+>>> pipe.guider
+ClassifierFreeGuidance {
+ "_class_name": "ClassifierFreeGuidance",
+ "_diffusers_version": "0.36.0.dev0",
+ "enabled": true,
+ "guidance_rescale": 0.0,
+ "guidance_scale": 6.0,
+ "start": 0.0,
+ "stop": 1.0,
+ "use_original_formulation": false
+}
+
+State:
+ step: None
+ num_inference_steps: None
+ timestep: None
+ count_prepared: 0
+ enabled: True
+ num_conditions: 2
+```
+
+To update guider configuration, you can run `pipe.guider = pipe.guider.new(...)`
+
+```py
+pipe.guider = pipe.guider.new(guidance_scale=5.0)
+```
+
+Read more on Guider [here](../../modular_diffusers/guiders).
+
+
+
+## HunyuanVideo15Pipeline
+
+[[autodoc]] HunyuanVideo15Pipeline
+ - all
+ - __call__
+
+## HunyuanVideo15ImageToVideoPipeline
+
+[[autodoc]] HunyuanVideo15ImageToVideoPipeline
+ - all
+ - __call__
+
+## HunyuanVideo15PipelineOutput
+
+[[autodoc]] pipelines.hunyuan_video1_5.pipeline_output.HunyuanVideo15PipelineOutput
diff --git a/docs/source/en/api/pipelines/hunyuandit.md b/docs/source/en/api/pipelines/hunyuandit.md
index d593259a09ed..3f4db66c6c94 100644
--- a/docs/source/en/api/pipelines/hunyuandit.md
+++ b/docs/source/en/api/pipelines/hunyuandit.md
@@ -1,4 +1,4 @@
-
+
+# HunyuanImage2.1
+
+
+HunyuanImage-2.1 is a 17B text-to-image model that is capable of generating 2K (2048 x 2048) resolution images
+
+HunyuanImage-2.1 comes in the following variants:
+
+| model type | model id |
+|:----------:|:--------:|
+| HunyuanImage-2.1 | [hunyuanvideo-community/HunyuanImage-2.1-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Diffusers) |
+| HunyuanImage-2.1-Distilled | [hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers) |
+| HunyuanImage-2.1-Refiner | [hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers) |
+
+> [!TIP]
+> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
+
+## HunyuanImage-2.1
+
+HunyuanImage-2.1 applies [Adaptive Projected Guidance (APG)](https://huggingface.co/papers/2410.02416) combined with Classifier-Free Guidance (CFG) in the denoising loop. `HunyuanImagePipeline` has a `guider` component (read more about [Guider](../modular_diffusers/guiders.md)) and does not take a `guidance_scale` parameter at runtime. To change guider-related parameters, e.g., `guidance_scale`, you can update the `guider` configuration instead.
+
+```python
+import torch
+from diffusers import HunyuanImagePipeline
+
+pipe = HunyuanImagePipeline.from_pretrained(
+ "hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
+ torch_dtype=torch.bfloat16
+)
+pipe = pipe.to("cuda")
+```
+
+You can inspect the `guider` object:
+
+```py
+>>> pipe.guider
+AdaptiveProjectedMixGuidance {
+ "_class_name": "AdaptiveProjectedMixGuidance",
+ "_diffusers_version": "0.36.0.dev0",
+ "adaptive_projected_guidance_momentum": -0.5,
+ "adaptive_projected_guidance_rescale": 10.0,
+ "adaptive_projected_guidance_scale": 10.0,
+ "adaptive_projected_guidance_start_step": 5,
+ "enabled": true,
+ "eta": 0.0,
+ "guidance_rescale": 0.0,
+ "guidance_scale": 3.5,
+ "start": 0.0,
+ "stop": 1.0,
+ "use_original_formulation": false
+}
+
+State:
+ step: None
+ num_inference_steps: None
+ timestep: None
+ count_prepared: 0
+ enabled: True
+ num_conditions: 2
+ momentum_buffer: None
+ is_apg_enabled: False
+ is_cfg_enabled: True
+```
+
+To update the guider with a different configuration, use the `new()` method. For example, to generate an image with `guidance_scale=5.0` while keeping all other default guidance parameters:
+
+```py
+import torch
+from diffusers import HunyuanImagePipeline
+
+pipe = HunyuanImagePipeline.from_pretrained(
+ "hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
+ torch_dtype=torch.bfloat16
+)
+pipe = pipe.to("cuda")
+
+# Update the guider configuration
+pipe.guider = pipe.guider.new(guidance_scale=5.0)
+
+prompt = (
+ "A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, "
+ "wearing a red knitted scarf and a red beret with the word 'Tencent' on it, holding a paintbrush with a "
+ "focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
+)
+
+image = pipe(
+ prompt=prompt,
+ num_inference_steps=50,
+ height=2048,
+ width=2048,
+).images[0]
+image.save("image.png")
+```
+
+
+## HunyuanImage-2.1-Distilled
+
+use `distilled_guidance_scale` with the guidance-distilled checkpoint,
+
+```py
+import torch
+from diffusers import HunyuanImagePipeline
+pipe = HunyuanImagePipeline.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers", torch_dtype=torch.bfloat16)
+pipe = pipe.to("cuda")
+
+prompt = (
+ "A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, "
+ "wearing a red knitted scarf and a red beret with the word 'Tencent' on it, holding a paintbrush with a "
+ "focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
+)
+
+out = pipe(
+ prompt,
+ num_inference_steps=8,
+ distilled_guidance_scale=3.25,
+ height=2048,
+ width=2048,
+ generator=generator,
+).images[0]
+
+```
+
+
+## HunyuanImagePipeline
+
+[[autodoc]] HunyuanImagePipeline
+ - all
+ - __call__
+
+## HunyuanImageRefinerPipeline
+
+[[autodoc]] HunyuanImageRefinerPipeline
+ - all
+ - __call__
+
+
+## HunyuanImagePipelineOutput
+
+[[autodoc]] pipelines.hunyuan_image.pipeline_output.HunyuanImagePipelineOutput
\ No newline at end of file
diff --git a/docs/source/en/api/pipelines/i2vgenxl.md b/docs/source/en/api/pipelines/i2vgenxl.md
index 3994f91d2cd0..711a5625f99c 100644
--- a/docs/source/en/api/pipelines/i2vgenxl.md
+++ b/docs/source/en/api/pipelines/i2vgenxl.md
@@ -1,4 +1,4 @@
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# I2VGen-XL
[I2VGen-XL: High-Quality Image-to-Video Synthesis via Cascaded Diffusion Models](https://hf.co/papers/2311.04145.pdf) by Shiwei Zhang, Jiayu Wang, Yingya Zhang, Kang Zhao, Hangjie Yuan, Zhiwu Qin, Xiang Wang, Deli Zhao, and Jingren Zhou.
@@ -20,11 +23,8 @@ The abstract from the paper is:
The original codebase can be found [here](https://github.com/ali-vilab/i2vgen-xl/). The model checkpoints can be found [here](https://huggingface.co/ali-vilab/).
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. Also, to know more about reducing the memory usage of this pipeline, refer to the ["Reduce memory usage"] section [here](../../using-diffusers/svd#reduce-memory-usage).
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. Also, to know more about reducing the memory usage of this pipeline, refer to the ["Reduce memory usage"] section [here](../../using-diffusers/svd#reduce-memory-usage).
Sample output with I2VGenXL:
@@ -47,7 +47,7 @@ Sample output with I2VGenXL:
* Unlike SVD, it additionally accepts text prompts as inputs.
* It can generate higher resolution videos.
* When using the [`DDIMScheduler`] (which is default for this pipeline), less than 50 steps for inference leads to bad results.
-* This implementation is 1-stage variant of I2VGenXL. The main figure in the [I2VGen-XL](https://arxiv.org/abs/2311.04145) paper shows a 2-stage variant, however, 1-stage variant works well. See [this discussion](https://github.com/huggingface/diffusers/discussions/7952) for more details.
+* This implementation is 1-stage variant of I2VGenXL. The main figure in the [I2VGen-XL](https://huggingface.co/papers/2311.04145) paper shows a 2-stage variant, however, 1-stage variant works well. See [this discussion](https://github.com/huggingface/diffusers/discussions/7952) for more details.
## I2VGenXLPipeline
[[autodoc]] I2VGenXLPipeline
diff --git a/docs/source/en/api/pipelines/kandinsky.md b/docs/source/en/api/pipelines/kandinsky.md
index 72cbf3fb474d..7717f2db69a5 100644
--- a/docs/source/en/api/pipelines/kandinsky.md
+++ b/docs/source/en/api/pipelines/kandinsky.md
@@ -1,4 +1,4 @@
-
+
+# Kandinsky 5.0 Image
+
+[Kandinsky 5.0](https://arxiv.org/abs/2511.14993) is a family of diffusion models for Video & Image generation.
+
+Kandinsky 5.0 Image Lite is a lightweight image generation model (6B parameters).
+
+The model introduces several key innovations:
+- **Latent diffusion pipeline** with **Flow Matching** for improved training stability
+- **Diffusion Transformer (DiT)** as the main generative backbone with cross-attention to text embeddings
+- Dual text encoding using **Qwen2.5-VL** and **CLIP** for comprehensive text understanding
+- **Flux VAE** for efficient image encoding and decoding
+
+The original codebase can be found at [kandinskylab/Kandinsky-5](https://github.com/kandinskylab/Kandinsky-5).
+
+> [!TIP]
+> Check out the [Kandinsky Lab](https://huggingface.co/kandinskylab) organization on the Hub for the official model checkpoints for text-to-video generation, including pretrained, SFT, no-CFG, and distilled variants.
+
+
+## Available Models
+
+Kandinsky 5.0 Image Lite:
+
+| model_id | Description | Use Cases |
+|------------|-------------|-----------|
+| [**kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers**](https://huggingface.co/kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers) | 6B image Supervised Fine-Tuned model | Highest generation quality |
+| [**kandinskylab/Kandinsky-5.0-I2I-Lite-sft-Diffusers**](https://huggingface.co/kandinskylab/Kandinsky-5.0-I2I-Lite-sft-Diffusers) | 6B image editing Supervised Fine-Tuned model | Highest generation quality |
+| [**kandinskylab/Kandinsky-5.0-T2I-Lite-pretrain-Diffusers**](https://huggingface.co/kandinskylab/Kandinsky-5.0-T2I-Lite-pretrain-Diffusers) | 6B image Base pretrained model | Research and fine-tuning |
+| [**kandinskylab/Kandinsky-5.0-I2I-Lite-pretrain-Diffusers**](https://huggingface.co/kandinskylab/Kandinsky-5.0-I2I-Lite-pretrain-Diffusers) | 6B image editing Base pretrained model | Research and fine-tuning |
+
+## Usage Examples
+
+### Basic Text-to-Image Generation
+
+```python
+import torch
+from diffusers import Kandinsky5T2IPipeline
+
+# Load the pipeline
+model_id = "kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers"
+pipe = Kandinsky5T2IPipeline.from_pretrained(model_id)
+_ = pipe.to(device='cuda',dtype=torch.bfloat16)
+
+# Generate image
+prompt = "A fluffy, expressive cat wearing a bright red hat with a soft, slightly textured fabric. The hat should look cozy and well-fitted on the cat’s head. On the front of the hat, add clean, bold white text that reads “SWEET”, clearly visible and neatly centered. Ensure the overall lighting highlights the hat’s color and the cat’s fur details."
+
+output = pipe(
+ prompt=prompt,
+ negative_prompt="",
+ height=1024,
+ width=1024,
+ num_inference_steps=50,
+ guidance_scale=3.5,
+).image[0]
+```
+
+### Basic Image-to-Image Generation
+
+```python
+import torch
+from diffusers import Kandinsky5I2IPipeline
+from diffusers.utils import load_image
+# Load the pipeline
+model_id = "kandinskylab/Kandinsky-5.0-I2I-Lite-sft-Diffusers"
+pipe = Kandinsky5I2IPipeline.from_pretrained(model_id)
+
+_ = pipe.to(device='cuda',dtype=torch.bfloat16)
+pipe.enable_model_cpu_offload() # <--- Enable CPU offloading for single GPU inference
+
+# Edit the input image
+image = load_image(
+ "https://huggingface.co/kandinsky-community/kandinsky-3/resolve/main/assets/title.jpg?download=true"
+)
+
+prompt = "Change the background from a winter night scene to a bright summer day. Place the character on a sandy beach with clear blue sky, soft sunlight, and gentle waves in the distance. Replace the winter clothing with a light short-sleeved T-shirt (in soft pastel colors) and casual shorts. Ensure the character’s fur reflects warm daylight instead of cold winter tones. Add small beach details such as seashells, footprints in the sand, and a few scattered beach toys nearby. Keep the oranges in the scene, but place them naturally on the sand."
+negative_prompt = ""
+
+output = pipe(
+ image=image,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ guidance_scale=3.5,
+).image[0]
+```
+
+
+## Kandinsky5T2IPipeline
+
+[[autodoc]] Kandinsky5T2IPipeline
+ - all
+ - __call__
+
+## Kandinsky5I2IPipeline
+
+[[autodoc]] Kandinsky5I2IPipeline
+ - all
+ - __call__
+
+
+## Citation
+```bibtex
+@misc{kandinsky2025,
+ author = {Alexander Belykh and Alexander Varlamov and Alexey Letunovskiy and Anastasia Aliaskina and Anastasia Maltseva and Anastasiia Kargapoltseva and Andrey Shutkin and Anna Averchenkova and Anna Dmitrienko and Bulat Akhmatov and Denis Dimitrov and Denis Koposov and Denis Parkhomenko and Dmitrii and Ilya Vasiliev and Ivan Kirillov and Julia Agafonova and Kirill Chernyshev and Kormilitsyn Semen and Lev Novitskiy and Maria Kovaleva and Mikhail Mamaev and Mikhailov and Nikita Kiselev and Nikita Osterov and Nikolai Gerasimenko and Nikolai Vaulin and Olga Kim and Olga Vdovchenko and Polina Gavrilova and Polina Mikhailova and Tatiana Nikulina and Viacheslav Vasilev and Vladimir Arkhipkin and Vladimir Korviakov and Vladimir Polovnikov and Yury Kolabushin},
+ title = {Kandinsky 5.0: A family of diffusion models for Video & Image generation},
+ howpublished = {\url{https://github.com/kandinskylab/Kandinsky-5}},
+ year = 2025
+}
+```
diff --git a/docs/source/en/api/pipelines/kandinsky5_video.md b/docs/source/en/api/pipelines/kandinsky5_video.md
new file mode 100644
index 000000000000..733e2481732a
--- /dev/null
+++ b/docs/source/en/api/pipelines/kandinsky5_video.md
@@ -0,0 +1,310 @@
+
+
+# Kandinsky 5.0 Video
+
+[Kandinsky 5.0](https://arxiv.org/abs/2511.14993) is a family of diffusion models for Video & Image generation.
+
+Kandinsky 5.0 Lite line-up of lightweight video generation models (2B parameters) that ranks #1 among open-source models in its class. It outperforms larger models and offers the best understanding of Russian concepts in the open-source ecosystem.
+
+Kandinsky 5.0 Pro line-up of large high quality video generation models (19B parameters). It offers high qualty generation in HD and more generation formats like I2V.
+
+The model introduces several key innovations:
+- **Latent diffusion pipeline** with **Flow Matching** for improved training stability
+- **Diffusion Transformer (DiT)** as the main generative backbone with cross-attention to text embeddings
+- Dual text encoding using **Qwen2.5-VL** and **CLIP** for comprehensive text understanding
+- **HunyuanVideo 3D VAE** for efficient video encoding and decoding
+- **Sparse attention mechanisms** (NABLA) for efficient long-sequence processing
+
+The original codebase can be found at [kandinskylab/Kandinsky-5](https://github.com/kandinskylab/Kandinsky-5).
+
+> [!TIP]
+> Check out the [Kandinsky Lab](https://huggingface.co/kandinskylab) organization on the Hub for the official model checkpoints for text-to-video generation, including pretrained, SFT, no-CFG, and distilled variants.
+
+## Available Models
+
+Kandinsky 5.0 T2V Pro:
+
+| model_id | Description | Use Cases |
+|------------|-------------|-----------|
+| **kandinskylab/Kandinsky-5.0-T2V-Pro-sft-5s-Diffusers** | 5 second Text-to-Video Pro model | High-quality text-to-video generation |
+| **kandinskylab/Kandinsky-5.0-I2V-Pro-sft-5s-Diffusers** | 5 second Image-to-Video Pro model | High-quality image-to-video generation |
+
+Kandinsky 5.0 T2V Lite:
+| model_id | Description | Use Cases |
+|------------|-------------|-----------|
+| **kandinskylab/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers** | 5 second Supervised Fine-Tuned model | Highest generation quality |
+| **kandinskylab/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers** | 10 second Supervised Fine-Tuned model | Highest generation quality |
+| **kandinskylab/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers** | 5 second Classifier-Free Guidance distilled | 2× faster inference |
+| **kandinskylab/Kandinsky-5.0-T2V-Lite-nocfg-10s-Diffusers** | 10 second Classifier-Free Guidance distilled | 2× faster inference |
+| **kandinskylab/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers** | 5 second Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss |
+| **kandinskylab/Kandinsky-5.0-T2V-Lite-distilled16steps-10s-Diffusers** | 10 second Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss |
+| **kandinskylab/Kandinsky-5.0-T2V-Lite-pretrain-5s-Diffusers** | 5 second Base pretrained model | Research and fine-tuning |
+| **kandinskylab/Kandinsky-5.0-T2V-Lite-pretrain-10s-Diffusers** | 10 second Base pretrained model | Research and fine-tuning |
+
+
+## Usage Examples
+
+### Basic Text-to-Video Generation
+
+#### Pro
+**⚠️ Warning!** all Pro models should be infered with pipeline.enable_model_cpu_offload()
+```python
+import torch
+from diffusers import Kandinsky5T2VPipeline
+from diffusers.utils import export_to_video
+
+# Load the pipeline
+model_id = "kandinskylab/Kandinsky-5.0-T2V-Pro-sft-5s-Diffusers"
+pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+
+pipe = pipe.to("cuda")
+pipeline.transformer.set_attention_backend("flex") # <--- Set attention bakend to Flex
+pipeline.enable_model_cpu_offload() # <--- Enable cpu offloading for single GPU inference
+pipeline.transformer.compile(mode="max-autotune-no-cudagraphs", dynamic=True) # <--- Compile with max-autotune-no-cudagraphs
+
+# Generate video
+prompt = "A cat and a dog baking a cake together in a kitchen."
+negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
+
+output = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=768,
+ width=1024,
+ num_frames=121, # ~5 seconds at 24fps
+ num_inference_steps=50,
+ guidance_scale=5.0,
+).frames[0]
+
+export_to_video(output, "output.mp4", fps=24, quality=9)
+```
+
+#### Lite
+```python
+import torch
+from diffusers import Kandinsky5T2VPipeline
+from diffusers.utils import export_to_video
+
+# Load the pipeline
+model_id = "kandinskylab/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers"
+pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+pipe = pipe.to("cuda")
+
+# Generate video
+prompt = "A cat and a dog baking a cake together in a kitchen."
+negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
+
+output = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=512,
+ width=768,
+ num_frames=121, # ~5 seconds at 24fps
+ num_inference_steps=50,
+ guidance_scale=5.0,
+).frames[0]
+
+export_to_video(output, "output.mp4", fps=24, quality=9)
+```
+
+### 10 second Models
+**⚠️ Warning!** all 10 second models should be used with Flex attention and max-autotune-no-cudagraphs compilation:
+
+```python
+pipe = Kandinsky5T2VPipeline.from_pretrained(
+ "kandinskylab/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers",
+ torch_dtype=torch.bfloat16
+)
+pipe = pipe.to("cuda")
+
+pipe.transformer.set_attention_backend(
+ "flex"
+) # <--- Set attention bakend to Flex
+pipe.transformer.compile(
+ mode="max-autotune-no-cudagraphs",
+ dynamic=True
+) # <--- Compile with max-autotune-no-cudagraphs
+
+prompt = "A cat and a dog baking a cake together in a kitchen."
+negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
+
+output = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=512,
+ width=768,
+ num_frames=241,
+ num_inference_steps=50,
+ guidance_scale=5.0,
+).frames[0]
+
+export_to_video(output, "output.mp4", fps=24, quality=9)
+```
+
+### Diffusion Distilled model
+**⚠️ Warning!** all nocfg and diffusion distilled models should be infered wothout CFG (```guidance_scale=1.0```):
+
+```python
+model_id = "kandinskylab/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers"
+pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+pipe = pipe.to("cuda")
+
+output = pipe(
+ prompt="A beautiful sunset over mountains",
+ num_inference_steps=16, # <--- Model is distilled in 16 steps
+ guidance_scale=1.0, # <--- no CFG
+).frames[0]
+
+export_to_video(output, "output.mp4", fps=24, quality=9)
+```
+
+
+### Basic Image-to-Video Generation
+**⚠️ Warning!** all Pro models should be infered with pipeline.enable_model_cpu_offload()
+```python
+import torch
+from diffusers import Kandinsky5T2VPipeline
+from diffusers.utils import export_to_video
+
+# Load the pipeline
+model_id = "kandinskylab/Kandinsky-5.0-I2V-Pro-sft-5s-Diffusers"
+pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+
+pipe = pipe.to("cuda")
+pipeline.transformer.set_attention_backend("flex") # <--- Set attention bakend to Flex
+pipeline.enable_model_cpu_offload() # <--- Enable cpu offloading for single GPU inference
+pipeline.transformer.compile(mode="max-autotune-no-cudagraphs", dynamic=True) # <--- Compile with max-autotune-no-cudagraphs
+
+# Generate video
+image = load_image(
+ "https://huggingface.co/kandinsky-community/kandinsky-3/resolve/main/assets/title.jpg?download=true"
+)
+height = 896
+width = 896
+image = image.resize((width, height))
+
+prompt = "An funny furry creture smiles happily and holds a sign that says 'Kandinsky'"
+negative_prompt = ""
+
+output = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=height,
+ width=width,
+ num_frames=121, # ~5 seconds at 24fps
+ num_inference_steps=50,
+ guidance_scale=5.0,
+).frames[0]
+
+export_to_video(output, "output.mp4", fps=24, quality=9)
+```
+
+
+
+## Kandinsky 5.0 Pro Side-by-Side evaluation
+
+
+
+
+
+
+
+
+
+
+
+
+ Comparison with Veo 3
+
+
+ Comparison with Veo 3 fast
+
+
+
+
+
+
+
+
+
+
+ Comparison with Wan 2.2 A14B Text-to-Video mode
+
+
+ Comparison with Wan 2.2 A14B Image-to-Video mode
+
+
+
+
+
+## Kandinsky 5.0 Lite Side-by-Side evaluation
+
+The evaluation is based on the expanded prompts from the [Movie Gen benchmark](https://github.com/facebookresearch/MovieGenBench), which are available in the expanded_prompt column of the benchmark/moviegen_bench.csv file.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+## Kandinsky 5.0 Lite Distill Side-by-Side evaluation
+
+
+
+
+
+
+
+
+
+
+
+
+## Kandinsky5T2VPipeline
+
+[[autodoc]] Kandinsky5T2VPipeline
+ - all
+ - __call__
+
+## Kandinsky5I2VPipeline
+
+[[autodoc]] Kandinsky5I2VPipeline
+ - all
+ - __call__
+
+
+## Citation
+```bibtex
+@misc{kandinsky2025,
+ author = {Alexander Belykh and Alexander Varlamov and Alexey Letunovskiy and Anastasia Aliaskina and Anastasia Maltseva and Anastasiia Kargapoltseva and Andrey Shutkin and Anna Averchenkova and Anna Dmitrienko and Bulat Akhmatov and Denis Dimitrov and Denis Koposov and Denis Parkhomenko and Dmitrii and Ilya Vasiliev and Ivan Kirillov and Julia Agafonova and Kirill Chernyshev and Kormilitsyn Semen and Lev Novitskiy and Maria Kovaleva and Mikhail Mamaev and Mikhailov and Nikita Kiselev and Nikita Osterov and Nikolai Gerasimenko and Nikolai Vaulin and Olga Kim and Olga Vdovchenko and Polina Gavrilova and Polina Mikhailova and Tatiana Nikulina and Viacheslav Vasilev and Vladimir Arkhipkin and Vladimir Korviakov and Vladimir Polovnikov and Yury Kolabushin},
+ title = {Kandinsky 5.0: A family of diffusion models for Video & Image generation},
+ howpublished = {\url{https://github.com/kandinskylab/Kandinsky-5}},
+ year = 2025
+}
+```
diff --git a/docs/source/en/api/pipelines/kandinsky_v22.md b/docs/source/en/api/pipelines/kandinsky_v22.md
index f097a085ef7f..0e0ed80db61c 100644
--- a/docs/source/en/api/pipelines/kandinsky_v22.md
+++ b/docs/source/en/api/pipelines/kandinsky_v22.md
@@ -1,4 +1,4 @@
-
-# LTX Video
-
-
-
+
-[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases.
-
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+# LTX-Video
-Available models:
+[LTX-Video](https://huggingface.co/Lightricks/LTX-Video) is a diffusion transformer designed for fast and real-time generation of high-resolution videos from text and images. The main feature of LTX-Video is the Video-VAE. The Video-VAE has a higher pixel to latent compression ratio (1:192) which enables more efficient video data processing and faster generation speed. To support and prevent finer details from being lost during generation, the Video-VAE decoder performs the latent to pixel conversion *and* the last denoising step.
-| Model name | Recommended dtype |
-|:-------------:|:-----------------:|
-| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
-| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
+You can find all the original LTX-Video checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.
-Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository.
+> [!TIP]
+> Click on the LTX-Video models in the right sidebar for more examples of other video generation tasks.
-## Loading Single Files
+The example below demonstrates how to generate a video optimized for memory or inference speed.
-Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. We recommend using `from_single_file` for the Lightricks series of models, as they plan to release multiple models in the future in the single file format.
+
+
-```python
-import torch
-from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel
+Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.
-# `single_file_url` could also be https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.1.safetensors
-single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
-transformer = LTXVideoTransformer3DModel.from_single_file(
- single_file_url, torch_dtype=torch.bfloat16
-)
-vae = AutoencoderKLLTXVideo.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
-pipe = LTXImageToVideoPipeline.from_pretrained(
- "Lightricks/LTX-Video", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16
-)
-
-# ... inference code ...
-```
-
-Alternatively, the pipeline can be used to load the weights with [`~FromSingleFileMixin.from_single_file`].
-
-```python
-import torch
-from diffusers import LTXImageToVideoPipeline
-from transformers import T5EncoderModel, T5Tokenizer
-
-single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
-text_encoder = T5EncoderModel.from_pretrained(
- "Lightricks/LTX-Video", subfolder="text_encoder", torch_dtype=torch.bfloat16
-)
-tokenizer = T5Tokenizer.from_pretrained(
- "Lightricks/LTX-Video", subfolder="tokenizer", torch_dtype=torch.bfloat16
-)
-pipe = LTXImageToVideoPipeline.from_single_file(
- single_file_url, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=torch.bfloat16
-)
-```
-
-Loading [LTX GGUF checkpoints](https://huggingface.co/city96/LTX-Video-gguf) are also supported:
+The LTX-Video model below requires ~10GB of VRAM.
```py
import torch
+from diffusers import LTXPipeline, AutoModel
+from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video
-from diffusers import LTXPipeline, LTXVideoTransformer3DModel, GGUFQuantizationConfig
-ckpt_path = (
- "https://huggingface.co/city96/LTX-Video-gguf/blob/main/ltx-video-2b-v0.9-Q3_K_S.gguf"
-)
-transformer = LTXVideoTransformer3DModel.from_single_file(
- ckpt_path,
- quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
- torch_dtype=torch.bfloat16,
-)
-pipe = LTXPipeline.from_pretrained(
+# fp8 layerwise weight-casting
+transformer = AutoModel.from_pretrained(
"Lightricks/LTX-Video",
- transformer=transformer,
- torch_dtype=torch.bfloat16,
+ subfolder="transformer",
+ torch_dtype=torch.bfloat16
+)
+transformer.enable_layerwise_casting(
+ storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
-pipe.enable_model_cpu_offload()
-prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
+pipeline = LTXPipeline.from_pretrained("Lightricks/LTX-Video", transformer=transformer, torch_dtype=torch.bfloat16)
+
+# group-offloading
+onload_device = torch.device("cuda")
+offload_device = torch.device("cpu")
+pipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True)
+apply_group_offloading(pipeline.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2)
+apply_group_offloading(pipeline.vae, onload_device=onload_device, offload_type="leaf_level")
+
+prompt = """
+A woman with long brown hair and light skin smiles at another woman with long blonde hair.
+The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek.
+The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and
+natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage
+"""
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
-video = pipe(
+video = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
- width=704,
- height=480,
+ width=768,
+ height=512,
num_frames=161,
+ decode_timestep=0.03,
+ decode_noise_scale=0.025,
num_inference_steps=50,
).frames[0]
-export_to_video(video, "output_gguf_ltx.mp4", fps=24)
+export_to_video(video, "output.mp4", fps=24)
```
-Make sure to read the [documentation on GGUF](../../quantization/gguf) to learn more about our GGUF support.
-
-
+
+
-Loading and running inference with [LTX Video 0.9.1](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) weights.
+[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
-```python
+```py
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video
-pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16)
-pipe.to("cuda")
+pipeline = LTXPipeline.from_pretrained(
+ "Lightricks/LTX-Video", torch_dtype=torch.bfloat16
+)
+
+# torch.compile
+pipeline.transformer.to(memory_format=torch.channels_last)
+pipeline.transformer = torch.compile(
+ pipeline.transformer, mode="max-autotune", fullgraph=True
+)
-prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
+prompt = """
+A woman with long brown hair and light skin smiles at another woman with long blonde hair.
+The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek.
+The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and
+natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage
+"""
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
-video = pipe(
+video = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
width=768,
@@ -141,48 +126,353 @@ video = pipe(
export_to_video(video, "output.mp4", fps=24)
```
-Refer to [this section](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox#memory-optimization) to learn more about optimizing memory consumption.
-
-## Quantization
-
-Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
-
-Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`LTXPipeline`] for inference with bitsandbytes.
-
-```py
-import torch
-from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, LTXVideoTransformer3DModel, LTXPipeline
-from diffusers.utils import export_to_video
-from transformers import BitsAndBytesConfig as BitsAndBytesConfig, T5EncoderModel
-
-quant_config = BitsAndBytesConfig(load_in_8bit=True)
-text_encoder_8bit = T5EncoderModel.from_pretrained(
- "Lightricks/LTX-Video",
- subfolder="text_encoder",
- quantization_config=quant_config,
- torch_dtype=torch.float16,
-)
-
-quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
-transformer_8bit = LTXVideoTransformer3DModel.from_pretrained(
- "Lightricks/LTX-Video",
- subfolder="transformer",
- quantization_config=quant_config,
- torch_dtype=torch.float16,
-)
-
-pipeline = LTXPipeline.from_pretrained(
- "Lightricks/LTX-Video",
- text_encoder=text_encoder_8bit,
- transformer=transformer_8bit,
- torch_dtype=torch.float16,
- device_map="balanced",
-)
-
-prompt = "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting."
-video = pipeline(prompt=prompt, num_frames=161, num_inference_steps=50).frames[0]
-export_to_video(video, "ship.mp4", fps=24)
-```
+
+
+
+## Notes
+
+- Refer to the following recommended settings for generation from the [LTX-Video](https://github.com/Lightricks/LTX-Video) repository.
+
+ - The recommended dtype for the transformer, VAE, and text encoder is `torch.bfloat16`. The VAE and text encoder can also be `torch.float32` or `torch.float16`.
+ - For guidance-distilled variants of LTX-Video, set `guidance_scale` to `1.0`. The `guidance_scale` for any other model should be set higher, like `5.0`, for good generation quality.
+ - For timestep-aware VAE variants (LTX-Video 0.9.1 and above), set `decode_timestep` to `0.05` and `image_cond_noise_scale` to `0.025`.
+ - For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitionts in the generated video.
+
+- LTX-Video 0.9.7 includes a spatial latent upscaler and a 13B parameter transformer. During inference, a low resolution video is quickly generated first and then upscaled and refined.
+
+
+ Show example code
+
+ ```py
+ import torch
+ from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
+ from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
+ from diffusers.utils import export_to_video, load_video
+
+ pipeline = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.7-dev", torch_dtype=torch.bfloat16)
+ pipeline_upsample = LTXLatentUpsamplePipeline.from_pretrained("Lightricks/ltxv-spatial-upscaler-0.9.7", vae=pipeline.vae, torch_dtype=torch.bfloat16)
+ pipeline.to("cuda")
+ pipe_upsample.to("cuda")
+ pipeline.vae.enable_tiling()
+
+ def round_to_nearest_resolution_acceptable_by_vae(height, width):
+ height = height - (height % pipeline.vae_temporal_compression_ratio)
+ width = width - (width % pipeline.vae_temporal_compression_ratio)
+ return height, width
+
+ video = load_video(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
+ )[:21] # only use the first 21 frames as conditioning
+ condition1 = LTXVideoCondition(video=video, frame_index=0)
+
+ prompt = """
+ The video depicts a winding mountain road covered in snow, with a single vehicle
+ traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation.
+ The landscape is characterized by rugged terrain and a river visible in the distance.
+ The scene captures the solitude and beauty of a winter drive through a mountainous region.
+ """
+ negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+ expected_height, expected_width = 768, 1152
+ downscale_factor = 2 / 3
+ num_frames = 161
+
+ # 1. Generate video at smaller resolution
+ # Text-only conditioning is also supported without the need to pass `conditions`
+ downscaled_height, downscaled_width = int(expected_height * downscale_factor), int(expected_width * downscale_factor)
+ downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(downscaled_height, downscaled_width)
+ latents = pipeline(
+ conditions=[condition1],
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ width=downscaled_width,
+ height=downscaled_height,
+ num_frames=num_frames,
+ num_inference_steps=30,
+ decode_timestep=0.05,
+ decode_noise_scale=0.025,
+ image_cond_noise_scale=0.0,
+ guidance_scale=5.0,
+ guidance_rescale=0.7,
+ generator=torch.Generator().manual_seed(0),
+ output_type="latent",
+ ).frames
+
+ # 2. Upscale generated video using latent upsampler with fewer inference steps
+ # The available latent upsampler upscales the height/width by 2x
+ upscaled_height, upscaled_width = downscaled_height * 2, downscaled_width * 2
+ upscaled_latents = pipe_upsample(
+ latents=latents,
+ output_type="latent"
+ ).frames
+
+ # 3. Denoise the upscaled video with few steps to improve texture (optional, but recommended)
+ video = pipeline(
+ conditions=[condition1],
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ width=upscaled_width,
+ height=upscaled_height,
+ num_frames=num_frames,
+ denoise_strength=0.4, # Effectively, 4 inference steps out of 10
+ num_inference_steps=10,
+ latents=upscaled_latents,
+ decode_timestep=0.05,
+ decode_noise_scale=0.025,
+ image_cond_noise_scale=0.0,
+ guidance_scale=5.0,
+ guidance_rescale=0.7,
+ generator=torch.Generator().manual_seed(0),
+ output_type="pil",
+ ).frames[0]
+
+ # 4. Downscale the video to the expected resolution
+ video = [frame.resize((expected_width, expected_height)) for frame in video]
+
+ export_to_video(video, "output.mp4", fps=24)
+ ```
+
+
+
+- LTX-Video 0.9.7 distilled model is guidance and timestep-distilled to speedup generation. It requires `guidance_scale` to be set to `1.0` and `num_inference_steps` should be set between `4` and `10` for good generation quality. You should also use the following custom timesteps for the best results.
+
+ - Base model inference to prepare for upscaling: `[1000, 993, 987, 981, 975, 909, 725, 0.03]`.
+ - Upscaling: `[1000, 909, 725, 421, 0]`.
+
+
+ Show example code
+
+ ```py
+ import torch
+ from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
+ from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
+ from diffusers.utils import export_to_video, load_video
+
+ pipeline = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.7-distilled", torch_dtype=torch.bfloat16)
+ pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained("Lightricks/ltxv-spatial-upscaler-0.9.7", vae=pipeline.vae, torch_dtype=torch.bfloat16)
+ pipeline.to("cuda")
+ pipe_upsample.to("cuda")
+ pipeline.vae.enable_tiling()
+
+ def round_to_nearest_resolution_acceptable_by_vae(height, width):
+ height = height - (height % pipeline.vae_spatial_compression_ratio)
+ width = width - (width % pipeline.vae_spatial_compression_ratio)
+ return height, width
+
+ prompt = """
+ artistic anatomical 3d render, utlra quality, human half full male body with transparent
+ skin revealing structure instead of organs, muscular, intricate creative patterns,
+ monochromatic with backlighting, lightning mesh, scientific concept art, blending biology
+ with botany, surreal and ethereal quality, unreal engine 5, ray tracing, ultra realistic,
+ 16K UHD, rich details. camera zooms out in a rotating fashion
+ """
+ negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+ expected_height, expected_width = 768, 1152
+ downscale_factor = 2 / 3
+ num_frames = 161
+
+ # 1. Generate video at smaller resolution
+ downscaled_height, downscaled_width = int(expected_height * downscale_factor), int(expected_width * downscale_factor)
+ downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(downscaled_height, downscaled_width)
+ latents = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ width=downscaled_width,
+ height=downscaled_height,
+ num_frames=num_frames,
+ timesteps=[1000, 993, 987, 981, 975, 909, 725, 0.03],
+ decode_timestep=0.05,
+ decode_noise_scale=0.025,
+ image_cond_noise_scale=0.0,
+ guidance_scale=1.0,
+ guidance_rescale=0.7,
+ generator=torch.Generator().manual_seed(0),
+ output_type="latent",
+ ).frames
+
+ # 2. Upscale generated video using latent upsampler with fewer inference steps
+ # The available latent upsampler upscales the height/width by 2x
+ upscaled_height, upscaled_width = downscaled_height * 2, downscaled_width * 2
+ upscaled_latents = pipe_upsample(
+ latents=latents,
+ adain_factor=1.0,
+ output_type="latent"
+ ).frames
+
+ # 3. Denoise the upscaled video with few steps to improve texture (optional, but recommended)
+ video = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ width=upscaled_width,
+ height=upscaled_height,
+ num_frames=num_frames,
+ denoise_strength=0.999, # Effectively, 4 inference steps out of 5
+ timesteps=[1000, 909, 725, 421, 0],
+ latents=upscaled_latents,
+ decode_timestep=0.05,
+ decode_noise_scale=0.025,
+ image_cond_noise_scale=0.0,
+ guidance_scale=1.0,
+ guidance_rescale=0.7,
+ generator=torch.Generator().manual_seed(0),
+ output_type="pil",
+ ).frames[0]
+
+ # 4. Downscale the video to the expected resolution
+ video = [frame.resize((expected_width, expected_height)) for frame in video]
+
+ export_to_video(video, "output.mp4", fps=24)
+ ```
+
+
+
+- LTX-Video 0.9.8 distilled model is similar to the 0.9.7 variant. It is guidance and timestep-distilled, and similar inference code can be used as above. An improvement of this version is that it supports generating very long videos. Additionally, it supports using tone mapping to improve the quality of the generated video using the `tone_map_compression_ratio` parameter. The default value of `0.6` is recommended.
+
+
+ Show example code
+
+ ```python
+ import torch
+ from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
+ from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
+ from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel
+ from diffusers.utils import export_to_video, load_video
+
+ pipeline = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.8-13B-distilled", torch_dtype=torch.bfloat16)
+ # TODO: Update the checkpoint here once updated in LTX org
+ upsampler = LTXLatentUpsamplerModel.from_pretrained("a-r-r-o-w/LTX-0.9.8-Latent-Upsampler", torch_dtype=torch.bfloat16)
+ pipe_upsample = LTXLatentUpsamplePipeline(vae=pipeline.vae, latent_upsampler=upsampler).to(torch.bfloat16)
+ pipeline.to("cuda")
+ pipe_upsample.to("cuda")
+ pipeline.vae.enable_tiling()
+
+ def round_to_nearest_resolution_acceptable_by_vae(height, width):
+ height = height - (height % pipeline.vae_spatial_compression_ratio)
+ width = width - (width % pipeline.vae_spatial_compression_ratio)
+ return height, width
+
+ prompt = """The camera pans over a snow-covered mountain range, revealing a vast expanse of snow-capped peaks and valleys.The mountains are covered in a thick layer of snow, with some areas appearing almost white while others have a slightly darker, almost grayish hue. The peaks are jagged and irregular, with some rising sharply into the sky while others are more rounded. The valleys are deep and narrow, with steep slopes that are also covered in snow. The trees in the foreground are mostly bare, with only a few leaves remaining on their branches. The sky is overcast, with thick clouds obscuring the sun. The overall impression is one of peace and tranquility, with the snow-covered mountains standing as a testament to the power and beauty of nature."""
+ # prompt = """A woman walks away from a white Jeep parked on a city street at night, then ascends a staircase and knocks on a door. The woman, wearing a dark jacket and jeans, walks away from the Jeep parked on the left side of the street, her back to the camera; she walks at a steady pace, her arms swinging slightly by her sides; the street is dimly lit, with streetlights casting pools of light on the wet pavement; a man in a dark jacket and jeans walks past the Jeep in the opposite direction; the camera follows the woman from behind as she walks up a set of stairs towards a building with a green door; she reaches the top of the stairs and turns left, continuing to walk towards the building; she reaches the door and knocks on it with her right hand; the camera remains stationary, focused on the doorway; the scene is captured in real-life footage."""
+ negative_prompt = "bright colors, symbols, graffiti, watermarks, worst quality, inconsistent motion, blurry, jittery, distorted"
+ expected_height, expected_width = 480, 832
+ downscale_factor = 2 / 3
+ # num_frames = 161
+ num_frames = 361
+
+ # 1. Generate video at smaller resolution
+ downscaled_height, downscaled_width = int(expected_height * downscale_factor), int(expected_width * downscale_factor)
+ downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(downscaled_height, downscaled_width)
+ latents = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ width=downscaled_width,
+ height=downscaled_height,
+ num_frames=num_frames,
+ timesteps=[1000, 993, 987, 981, 975, 909, 725, 0.03],
+ decode_timestep=0.05,
+ decode_noise_scale=0.025,
+ image_cond_noise_scale=0.0,
+ guidance_scale=1.0,
+ guidance_rescale=0.7,
+ generator=torch.Generator().manual_seed(0),
+ output_type="latent",
+ ).frames
+
+ # 2. Upscale generated video using latent upsampler with fewer inference steps
+ # The available latent upsampler upscales the height/width by 2x
+ upscaled_height, upscaled_width = downscaled_height * 2, downscaled_width * 2
+ upscaled_latents = pipe_upsample(
+ latents=latents,
+ adain_factor=1.0,
+ tone_map_compression_ratio=0.6,
+ output_type="latent"
+ ).frames
+
+ # 3. Denoise the upscaled video with few steps to improve texture (optional, but recommended)
+ video = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ width=upscaled_width,
+ height=upscaled_height,
+ num_frames=num_frames,
+ denoise_strength=0.999, # Effectively, 4 inference steps out of 5
+ timesteps=[1000, 909, 725, 421, 0],
+ latents=upscaled_latents,
+ decode_timestep=0.05,
+ decode_noise_scale=0.025,
+ image_cond_noise_scale=0.0,
+ guidance_scale=1.0,
+ guidance_rescale=0.7,
+ generator=torch.Generator().manual_seed(0),
+ output_type="pil",
+ ).frames[0]
+
+ # 4. Downscale the video to the expected resolution
+ video = [frame.resize((expected_width, expected_height)) for frame in video]
+
+ export_to_video(video, "output.mp4", fps=24)
+ ```
+
+
+
+- LTX-Video supports LoRAs with [`~loaders.LTXVideoLoraLoaderMixin.load_lora_weights`].
+
+
+ Show example code
+
+ ```py
+ import torch
+ from diffusers import LTXConditionPipeline
+ from diffusers.utils import export_to_video, load_image
+
+ pipeline = LTXConditionPipeline.from_pretrained(
+ "Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16
+ )
+
+ pipeline.load_lora_weights("Lightricks/LTX-Video-Cakeify-LoRA", adapter_name="cakeify")
+ pipeline.set_adapters("cakeify")
+
+ # use "CAKEIFY" to trigger the LoRA
+ prompt = "CAKEIFY a person using a knife to cut a cake shaped like a Pikachu plushie"
+ image = load_image("https://huggingface.co/Lightricks/LTX-Video-Cakeify-LoRA/resolve/main/assets/images/pikachu.png")
+
+ video = pipeline(
+ prompt=prompt,
+ image=image,
+ width=576,
+ height=576,
+ num_frames=161,
+ decode_timestep=0.03,
+ decode_noise_scale=0.025,
+ num_inference_steps=50,
+ ).frames[0]
+ export_to_video(video, "output.mp4", fps=26)
+ ```
+
+
+
+- LTX-Video supports loading from single files, such as [GGUF checkpoints](../../quantization/gguf), with [`loaders.FromOriginalModelMixin.from_single_file`] or [`loaders.FromSingleFileMixin.from_single_file`].
+
+
+ Show example code
+
+ ```py
+ import torch
+ from diffusers.utils import export_to_video
+ from diffusers import LTXPipeline, AutoModel, GGUFQuantizationConfig
+
+ transformer = AutoModel.from_single_file(
+ "https://huggingface.co/city96/LTX-Video-gguf/blob/main/ltx-video-2b-v0.9-Q3_K_S.gguf"
+ quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+ torch_dtype=torch.bfloat16
+ )
+ pipeline = LTXPipeline.from_pretrained(
+ "Lightricks/LTX-Video",
+ transformer=transformer,
+ torch_dtype=torch.bfloat16
+ )
+ ```
+
+
## LTXPipeline
@@ -202,6 +492,12 @@ export_to_video(video, "ship.mp4", fps=24)
- all
- __call__
+## LTXLatentUpsamplePipeline
+
+[[autodoc]] LTXLatentUpsamplePipeline
+ - all
+ - __call__
+
## LTXPipelineOutput
[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
diff --git a/docs/source/en/api/pipelines/lumina.md b/docs/source/en/api/pipelines/lumina.md
index ce5cf8b103cc..0a236d213d6c 100644
--- a/docs/source/en/api/pipelines/lumina.md
+++ b/docs/source/en/api/pipelines/lumina.md
@@ -1,4 +1,4 @@
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# MusicLDM
MusicLDM was proposed in [MusicLDM: Enhancing Novelty in Text-to-Music Generation Using Beat-Synchronous Mixup Strategies](https://huggingface.co/papers/2308.01546) by Ke Chen, Yusong Wu, Haohe Liu, Marianna Nezhurina, Taylor Berg-Kirkpatrick, Shlomo Dubnov.
@@ -40,11 +43,8 @@ During inference:
* Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1 to enable. Automatic scoring will be performed between the generated waveforms and prompt text, and the audios ranked from best to worst accordingly.
* The _length_ of the generated audio sample can be controlled by varying the `audio_length_in_s` argument.
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## MusicLDMPipeline
[[autodoc]] MusicLDMPipeline
diff --git a/docs/source/en/api/pipelines/omnigen.md b/docs/source/en/api/pipelines/omnigen.md
index 114e3753e710..4fac5c789a25 100644
--- a/docs/source/en/api/pipelines/omnigen.md
+++ b/docs/source/en/api/pipelines/omnigen.md
@@ -1,4 +1,4 @@
-
+
+# Ovis-Image
+
+
+
+Ovis-Image is a 7B text-to-image model specifically optimized for high-quality text rendering, designed to operate efficiently under stringent computational constraints.
+
+[Ovis-Image Technical Report](https://arxiv.org/abs/2511.22982) from Alibaba Group, by Guo-Hua Wang, Liangfu Cao, Tianyu Cui, Minghao Fu, Xiaohao Chen, Pengxin Zhan, Jianshan Zhao, Lan Li, Bowen Fu, Jiaqi Liu, Qing-Guo Chen.
+
+The abstract from the paper is:
+
+*We introduce Ovis-Image, a 7B text-to-image model specifically optimized for high-quality text rendering, designed to operate efficiently under stringent computational constraints. Built upon our previous Ovis-U1 framework, Ovis-Image integrates a diffusion-based visual decoder with the stronger Ovis 2.5 multimodal backbone, leveraging a text-centric training pipeline that combines large-scale pre-training with carefully tailored post-training refinements. Despite its compact architecture, Ovis-Image achieves text rendering performance on par with significantly larger open models such as Qwen-Image and approaches closed-source systems like Seedream and GPT4o. Crucially, the model remains deployable on a single high-end GPU with moderate memory, narrowing the gap between frontier-level text rendering and practical deployment. Our results indicate that combining a strong multimodal backbone with a carefully designed, text-focused training recipe is sufficient to achieve reliable bilingual text rendering without resorting to oversized or proprietary models.*
+
+**Highlights**:
+
+* **Strong text rendering at a compact 7B scale**: Ovis-Image is a 7B text-to-image model that delivers text rendering quality comparable to much larger 20B-class systems such as Qwen-Image and competitive with leading closed-source models like GPT4o in text-centric scenarios, while remaining small enough to run on widely accessible hardware.
+* **High fidelity on text-heavy, layout-sensitive prompts**: The model excels on prompts that demand tight alignment between linguistic content and rendered typography (e.g., posters, banners, logos, UI mockups, infographics), producing legible, correctly spelled, and semantically consistent text across diverse fonts, sizes, and aspect ratios without compromising overall visual quality.
+* **Efficiency and deployability**: With its 7B parameter budget and streamlined architecture, Ovis-Image fits on a single high-end GPU with moderate memory, supports low-latency interactive use, and scales to batch production serving, bringing near–frontier text rendering to applications where tens-of-billions–parameter models are impractical.
+
+
+This pipeline was contributed by Ovis-Image Team. The original codebase can be found [here](https://github.com/AIDC-AI/Ovis-Image).
+
+Available models:
+
+| Model | Recommended dtype |
+|:-----:|:-----------------:|
+| [`AIDC-AI/Ovis-Image-7B`](https://huggingface.co/AIDC-AI/Ovis-Image-7B) | `torch.bfloat16` |
+
+Refer to [this](https://huggingface.co/collections/AIDC-AI/ovis-image) collection for more information.
+
+## OvisImagePipeline
+
+[[autodoc]] OvisImagePipeline
+ - all
+ - __call__
+
+## OvisImagePipelineOutput
+
+[[autodoc]] pipelines.ovis_image.pipeline_output.OvisImagePipelineOutput
diff --git a/docs/source/en/api/pipelines/pag.md b/docs/source/en/api/pipelines/pag.md
index 64aefdf7e78f..35004b6ad39c 100644
--- a/docs/source/en/api/pipelines/pag.md
+++ b/docs/source/en/api/pipelines/pag.md
@@ -1,4 +1,4 @@
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Paint by Example
[Paint by Example: Exemplar-based Image Editing with Diffusion Models](https://huggingface.co/papers/2211.13227) is by Binxin Yang, Shuyang Gu, Bo Zhang, Ting Zhang, Xuejin Chen, Xiaoyan Sun, Dong Chen, Fang Wen.
@@ -24,11 +27,8 @@ The original codebase can be found at [Fantasy-Studio/Paint-by-Example](https://
Paint by Example is supported by the official [Fantasy-Studio/Paint-by-Example](https://huggingface.co/Fantasy-Studio/Paint-by-Example) checkpoint. The checkpoint is warm-started from [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) to inpaint partly masked images conditioned on example and reference images.
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## PaintByExamplePipeline
[[autodoc]] PaintByExamplePipeline
diff --git a/docs/source/en/api/pipelines/panorama.md b/docs/source/en/api/pipelines/panorama.md
index cbd5aaf815db..b65e05dd0b51 100644
--- a/docs/source/en/api/pipelines/panorama.md
+++ b/docs/source/en/api/pipelines/panorama.md
@@ -1,4 +1,4 @@
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# MultiDiffusion
@@ -39,11 +42,8 @@ For example, without circular padding, there is a stitching artifact (default):
But with circular padding, the right and the left parts are matching (`circular_padding=True`):

-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## StableDiffusionPanoramaPipeline
[[autodoc]] StableDiffusionPanoramaPipeline
diff --git a/docs/source/en/api/pipelines/pia.md b/docs/source/en/api/pipelines/pia.md
index 86c0e8eb191a..eebfa4d4f8a6 100644
--- a/docs/source/en/api/pipelines/pia.md
+++ b/docs/source/en/api/pipelines/pia.md
@@ -1,4 +1,4 @@
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Image-to-Video Generation with PIA (Personalized Image Animator)
@@ -18,7 +21,7 @@ specific language governing permissions and limitations under the License.
## Overview
-[PIA: Your Personalized Image Animator via Plug-and-Play Modules in Text-to-Image Models](https://arxiv.org/abs/2312.13964) by Yiming Zhang, Zhening Xing, Yanhong Zeng, Youqing Fang, Kai Chen
+[PIA: Your Personalized Image Animator via Plug-and-Play Modules in Text-to-Image Models](https://huggingface.co/papers/2312.13964) by Yiming Zhang, Zhening Xing, Yanhong Zeng, Youqing Fang, Kai Chen
Recent advancements in personalized text-to-image (T2I) models have revolutionized content creation, empowering non-experts to generate stunning images with unique styles. While promising, adding realistic motions into these personalized images by text poses significant challenges in preserving distinct styles, high-fidelity details, and achieving motion controllability by text. In this paper, we present PIA, a Personalized Image Animator that excels in aligning with condition images, achieving motion controllability by text, and the compatibility with various personalized T2I models without specific tuning. To achieve these goals, PIA builds upon a base T2I model with well-trained temporal alignment layers, allowing for the seamless transformation of any personalized T2I model into an image animation model. A key component of PIA is the introduction of the condition module, which utilizes the condition frame and inter-frame affinity as input to transfer appearance information guided by the affinity hint for individual frame synthesis in the latent space. This design mitigates the challenges of appearance-related image alignment within and allows for a stronger focus on aligning with motion-related guidance.
@@ -84,15 +87,12 @@ Here are some sample outputs:
-
-
-If you plan on using a scheduler that can clip samples, make sure to disable it by setting `clip_sample=False` in the scheduler as this can also have an adverse effect on generated samples. Additionally, the PIA checkpoints can be sensitive to the beta schedule of the scheduler. We recommend setting this to `linear`.
-
-
+> [!TIP]
+> If you plan on using a scheduler that can clip samples, make sure to disable it by setting `clip_sample=False` in the scheduler as this can also have an adverse effect on generated samples. Additionally, the PIA checkpoints can be sensitive to the beta schedule of the scheduler. We recommend setting this to `linear`.
## Using FreeInit
-[FreeInit: Bridging Initialization Gap in Video Diffusion Models](https://arxiv.org/abs/2312.07537) by Tianxing Wu, Chenyang Si, Yuming Jiang, Ziqi Huang, Ziwei Liu.
+[FreeInit: Bridging Initialization Gap in Video Diffusion Models](https://huggingface.co/papers/2312.07537) by Tianxing Wu, Chenyang Si, Yuming Jiang, Ziqi Huang, Ziwei Liu.
FreeInit is an effective method that improves temporal consistency and overall quality of videos generated using video-diffusion-models without any addition training. It can be applied to PIA, AnimateDiff, ModelScope, VideoCrafter and various other video generation models seamlessly at inference time, and works by iteratively refining the latent-initialization noise. More details can be found it the paper.
@@ -146,11 +146,8 @@ export_to_gif(frames, "pia-freeinit-animation.gif")
-
-
-FreeInit is not really free - the improved quality comes at the cost of extra computation. It requires sampling a few extra times depending on the `num_iters` parameter that is set when enabling it. Setting the `use_fast_sampling` parameter to `True` can improve the overall performance (at the cost of lower quality compared to when `use_fast_sampling=False` but still better results than vanilla video generation models).
-
-
+> [!WARNING]
+> FreeInit is not really free - the improved quality comes at the cost of extra computation. It requires sampling a few extra times depending on the `num_iters` parameter that is set when enabling it. Setting the `use_fast_sampling` parameter to `True` can improve the overall performance (at the cost of lower quality compared to when `use_fast_sampling=False` but still better results than vanilla video generation models).
## PIAPipeline
diff --git a/docs/source/en/api/pipelines/pix2pix.md b/docs/source/en/api/pipelines/pix2pix.md
index d0b3bf32b823..84eb0cb5e5d3 100644
--- a/docs/source/en/api/pipelines/pix2pix.md
+++ b/docs/source/en/api/pipelines/pix2pix.md
@@ -1,4 +1,4 @@
-
+
+# PRX
+
+
+PRX generates high-quality images from text using a simplified MMDIT architecture where text tokens don't update through transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for latent compression and faster processing.
+
+## Available models
+
+PRX offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts.
+
+
+| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |
+|:-----:|:-----------------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:|
+| [`Photoroom/prx-256-t2i`](https://huggingface.co/Photoroom/prx-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
+| [`Photoroom/prx-256-t2i-sft`](https://huggingface.co/Photoroom/prx-256-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` |
+| [`Photoroom/prx-512-t2i`](https://huggingface.co/Photoroom/prx-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE |Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
+| [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
+| [`Photoroom/prx-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |
+| [`Photoroom/prx-512-t2i-dc-ae`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae)| 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae)|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
+| [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
+| [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |s
+
+Refer to [this](https://huggingface.co/collections/Photoroom/prx-models-68e66254c202ebfab99ad38e) collection for more information.
+
+## Loading the pipeline
+
+Load the pipeline with [`~DiffusionPipeline.from_pretrained`].
+
+```py
+from diffusers.pipelines.prx import PRXPipeline
+
+# Load pipeline - VAE and text encoder will be loaded from HuggingFace
+pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+prompt = "A front-facing portrait of a lion the golden savanna at sunset."
+image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
+image.save("prx_output.png")
+```
+
+### Manual Component Loading
+
+Load components individually to customize the pipeline for instance to use quantized models.
+
+```py
+import torch
+from diffusers.pipelines.prx import PRXPipeline
+from diffusers.models import AutoencoderKL, AutoencoderDC
+from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from transformers import T5GemmaModel, GemmaTokenizerFast
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
+from transformers import BitsAndBytesConfig as BitsAndBytesConfig
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
+# Load transformer
+transformer = PRXTransformer2DModel.from_pretrained(
+ "checkpoints/prx-512-t2i-sft",
+ subfolder="transformer",
+ quantization_config=quant_config,
+ torch_dtype=torch.bfloat16,
+)
+
+# Load scheduler
+scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ "checkpoints/prx-512-t2i-sft", subfolder="scheduler"
+)
+
+# Load T5Gemma text encoder
+t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2",
+ quantization_config=quant_config,
+ torch_dtype=torch.bfloat16)
+text_encoder = t5gemma_model.encoder.to(dtype=torch.bfloat16)
+tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
+tokenizer.model_max_length = 256
+
+# Load VAE - choose either Flux VAE or DC-AE
+# Flux VAE
+vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev",
+ subfolder="vae",
+ quantization_config=quant_config,
+ torch_dtype=torch.bfloat16)
+
+pipe = PRXPipeline(
+ transformer=transformer,
+ scheduler=scheduler,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae
+)
+pipe.to("cuda")
+```
+
+
+## Memory Optimization
+
+For memory-constrained environments:
+
+```py
+import torch
+from diffusers.pipelines.prx import PRXPipeline
+
+pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
+pipe.enable_model_cpu_offload() # Offload components to CPU when not in use
+
+# Or use sequential CPU offload for even lower memory
+pipe.enable_sequential_cpu_offload()
+```
+
+## PRXPipeline
+
+[[autodoc]] PRXPipeline
+ - all
+ - __call__
+
+## PRXPipelineOutput
+
+[[autodoc]] pipelines.prx.pipeline_output.PRXPipelineOutput
diff --git a/docs/source/en/api/pipelines/qwenimage.md b/docs/source/en/api/pipelines/qwenimage.md
new file mode 100644
index 000000000000..b3dd3dd93618
--- /dev/null
+++ b/docs/source/en/api/pipelines/qwenimage.md
@@ -0,0 +1,161 @@
+
+
+# QwenImage
+
+
+
+
+
+Qwen-Image from the Qwen team is an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing. Experiments show strong general capabilities in both image generation and editing, with exceptional performance in text rendering, especially for Chinese.
+
+Qwen-Image comes in the following variants:
+
+| model type | model id |
+|:----------:|:--------:|
+| Qwen-Image | [`Qwen/Qwen-Image`](https://huggingface.co/Qwen/Qwen-Image) |
+| Qwen-Image-Edit | [`Qwen/Qwen-Image-Edit`](https://huggingface.co/Qwen/Qwen-Image-Edit) |
+| Qwen-Image-Edit Plus | [Qwen/Qwen-Image-Edit-2509](https://huggingface.co/Qwen/Qwen-Image-Edit-2509) |
+
+> [!TIP]
+> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
+
+## LoRA for faster inference
+
+Use a LoRA from `lightx2v/Qwen-Image-Lightning` to speed up inference by reducing the
+number of steps. Refer to the code snippet below:
+
+
+Code
+
+```py
+from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
+import torch
+import math
+
+ckpt_id = "Qwen/Qwen-Image"
+
+# From
+# https://github.com/ModelTC/Qwen-Image-Lightning/blob/342260e8f5468d2f24d084ce04f55e101007118b/generate_with_diffusers.py#L82C9-L97C10
+scheduler_config = {
+ "base_image_seq_len": 256,
+ "base_shift": math.log(3), # We use shift=3 in distillation
+ "invert_sigmas": False,
+ "max_image_seq_len": 8192,
+ "max_shift": math.log(3), # We use shift=3 in distillation
+ "num_train_timesteps": 1000,
+ "shift": 1.0,
+ "shift_terminal": None, # set shift_terminal to None
+ "stochastic_sampling": False,
+ "time_shift_type": "exponential",
+ "use_beta_sigmas": False,
+ "use_dynamic_shifting": True,
+ "use_exponential_sigmas": False,
+ "use_karras_sigmas": False,
+}
+scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
+pipe = DiffusionPipeline.from_pretrained(
+ ckpt_id, scheduler=scheduler, torch_dtype=torch.bfloat16
+).to("cuda")
+pipe.load_lora_weights(
+ "lightx2v/Qwen-Image-Lightning", weight_name="Qwen-Image-Lightning-8steps-V1.0.safetensors"
+)
+
+prompt = "a tiny astronaut hatching from an egg on the moon, Ultra HD, 4K, cinematic composition."
+negative_prompt = " "
+image = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ width=1024,
+ height=1024,
+ num_inference_steps=8,
+ true_cfg_scale=1.0,
+ generator=torch.manual_seed(0),
+).images[0]
+image.save("qwen_fewsteps.png")
+```
+
+
+
+> [!TIP]
+> The `guidance_scale` parameter in the pipeline is there to support future guidance-distilled models when they come up. Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance, please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should enable classifier-free guidance computations.
+
+## Multi-image reference with QwenImageEditPlusPipeline
+
+With [`QwenImageEditPlusPipeline`], one can provide multiple images as input reference.
+
+```
+import torch
+from PIL import Image
+from diffusers import QwenImageEditPlusPipeline
+from diffusers.utils import load_image
+
+pipe = QwenImageEditPlusPipeline.from_pretrained(
+ "Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
+).to("cuda")
+
+image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg")
+image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png")
+image = pipe(
+ image=[image_1, image_2],
+ prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''',
+ num_inference_steps=50
+).images[0]
+```
+
+## QwenImagePipeline
+
+[[autodoc]] QwenImagePipeline
+ - all
+ - __call__
+
+## QwenImageImg2ImgPipeline
+
+[[autodoc]] QwenImageImg2ImgPipeline
+ - all
+ - __call__
+
+## QwenImageInpaintPipeline
+
+[[autodoc]] QwenImageInpaintPipeline
+ - all
+ - __call__
+
+## QwenImageEditPipeline
+
+[[autodoc]] QwenImageEditPipeline
+ - all
+ - __call__
+
+## QwenImageEditInpaintPipeline
+
+[[autodoc]] QwenImageEditInpaintPipeline
+ - all
+ - __call__
+
+## QwenImageControlNetPipeline
+
+[[autodoc]] QwenImageControlNetPipeline
+ - all
+ - __call__
+
+## QwenImageEditPlusPipeline
+
+[[autodoc]] QwenImageEditPlusPipeline
+ - all
+ - __call__
+
+## QwenImagePipelineOutput
+
+[[autodoc]] pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput
\ No newline at end of file
diff --git a/docs/source/en/api/pipelines/sana.md b/docs/source/en/api/pipelines/sana.md
index 3702b2771974..a948620f96cb 100644
--- a/docs/source/en/api/pipelines/sana.md
+++ b/docs/source/en/api/pipelines/sana.md
@@ -1,4 +1,4 @@
-
-# SanaSprintPipeline
+# SANA-Sprint
@@ -24,12 +24,6 @@ The abstract from the paper is:
*This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.*
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
-
This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj), [shuchen Xue](https://github.com/scxue) and [Enze Xie](https://github.com/xieenze). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model/).
Available models:
@@ -88,12 +82,46 @@ image.save("sana.png")
Users can tweak the `max_timesteps` value for experimenting with the visual quality of the generated outputs. The default `max_timesteps` value was obtained with an inference-time search process. For more details about it, check out the paper.
+## Image to Image
+
+The [`SanaSprintImg2ImgPipeline`] is a pipeline for image-to-image generation. It takes an input image and a prompt, and generates a new image based on the input image and the prompt.
+
+```py
+import torch
+from diffusers import SanaSprintImg2ImgPipeline
+from diffusers.utils.loading_utils import load_image
+
+image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
+)
+
+pipe = SanaSprintImg2ImgPipeline.from_pretrained(
+ "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
+ torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+image = pipe(
+ prompt="a cute pink bear",
+ image=image,
+ strength=0.5,
+ height=832,
+ width=480
+).images[0]
+image.save("output.png")
+```
+
## SanaSprintPipeline
[[autodoc]] SanaSprintPipeline
- all
- __call__
+## SanaSprintImg2ImgPipeline
+
+[[autodoc]] SanaSprintImg2ImgPipeline
+ - all
+ - __call__
+
## SanaPipelineOutput
diff --git a/docs/source/en/api/pipelines/sana_video.md b/docs/source/en/api/pipelines/sana_video.md
new file mode 100644
index 000000000000..9e330c758318
--- /dev/null
+++ b/docs/source/en/api/pipelines/sana_video.md
@@ -0,0 +1,189 @@
+
+
+# Sana-Video
+
+
+
+
+
+
+[SANA-Video: Efficient Video Generation with Block Linear Diffusion Transformer](https://huggingface.co/papers/2509.24695) from NVIDIA and MIT HAN Lab, by Junsong Chen, Yuyang Zhao, Jincheng Yu, Ruihang Chu, Junyu Chen, Shuai Yang, Xianbang Wang, Yicheng Pan, Daquan Zhou, Huan Ling, Haozhe Liu, Hongwei Yi, Hao Zhang, Muyang Li, Yukang Chen, Han Cai, Sanja Fidler, Ping Luo, Song Han, Enze Xie.
+
+The abstract from the paper is:
+
+*We introduce SANA-Video, a small diffusion model that can efficiently generate videos up to 720x1280 resolution and minute-length duration. SANA-Video synthesizes high-resolution, high-quality and long videos with strong text-video alignment at a remarkably fast speed, deployable on RTX 5090 GPU. Two core designs ensure our efficient, effective and long video generation: (1) Linear DiT: We leverage linear attention as the core operation, which is more efficient than vanilla attention given the large number of tokens processed in video generation. (2) Constant-Memory KV cache for Block Linear Attention: we design block-wise autoregressive approach for long video generation by employing a constant-memory state, derived from the cumulative properties of linear attention. This KV cache provides the Linear DiT with global context at a fixed memory cost, eliminating the need for a traditional KV cache and enabling efficient, minute-long video generation. In addition, we explore effective data filters and model training strategies, narrowing the training cost to 12 days on 64 H100 GPUs, which is only 1% of the cost of MovieGen. Given its low cost, SANA-Video achieves competitive performance compared to modern state-of-the-art small diffusion models (e.g., Wan 2.1-1.3B and SkyReel-V2-1.3B) while being 16x faster in measured latency. Moreover, SANA-Video can be deployed on RTX 5090 GPUs with NVFP4 precision, accelerating the inference speed of generating a 5-second 720p video from 71s to 29s (2.4x speedup). In summary, SANA-Video enables low-cost, high-quality video generation. [this https URL](https://github.com/NVlabs/SANA).*
+
+This pipeline was contributed by SANA Team. The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://hf.co/collections/Efficient-Large-Model/sana-video).
+
+Available models:
+
+| Model | Recommended dtype |
+|:-----:|:-----------------:|
+| [`Efficient-Large-Model/SANA-Video_2B_480p_diffusers`](https://huggingface.co/Efficient-Large-Model/ANA-Video_2B_480p_diffusers) | `torch.bfloat16` |
+
+Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-video) collection for more information.
+
+Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
+
+
+## Generation Pipelines
+
+
`
+
+
+The example below demonstrates how to use the text-to-video pipeline to generate a video using a text description.
+
+```python
+pipe = SanaVideoPipeline.from_pretrained(
+ "Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
+ torch_dtype=torch.bfloat16,
+)
+pipe.text_encoder.to(torch.bfloat16)
+pipe.vae.to(torch.float32)
+pipe.to("cuda")
+
+prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
+motion_scale = 30
+motion_prompt = f" motion score: {motion_scale}."
+prompt = prompt + motion_prompt
+
+video = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=480,
+ width=832,
+ frames=81,
+ guidance_scale=6,
+ num_inference_steps=50,
+ generator=torch.Generator(device="cuda").manual_seed(0),
+).frames[0]
+
+export_to_video(video, "sana_video.mp4", fps=16)
+```
+
+
+
+
+The example below demonstrates how to use the image-to-video pipeline to generate a video using a text description and a starting frame.
+
+```python
+pipe = SanaImageToVideoPipeline.from_pretrained(
+ "Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
+ torch_dtype=torch.bfloat16,
+)
+pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
+pipe.vae.to(torch.float32)
+pipe.text_encoder.to(torch.bfloat16)
+pipe.to("cuda")
+
+image = load_image("https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/samples/i2v-1.png")
+prompt = "A woman stands against a stunning sunset backdrop, her long, wavy brown hair gently blowing in the breeze. She wears a sleeveless, light-colored blouse with a deep V-neckline, which accentuates her graceful posture. The warm hues of the setting sun cast a golden glow across her face and hair, creating a serene and ethereal atmosphere. The background features a blurred landscape with soft, rolling hills and scattered clouds, adding depth to the scene. The camera remains steady, capturing the tranquil moment from a medium close-up angle."
+negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
+motion_scale = 30
+motion_prompt = f" motion score: {motion_scale}."
+prompt = prompt + motion_prompt
+
+motion_scale = 30.0
+
+video = pipe(
+ image=image,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=480,
+ width=832,
+ frames=81,
+ guidance_scale=6,
+ num_inference_steps=50,
+ generator=torch.Generator(device="cuda").manual_seed(0),
+).frames[0]
+
+export_to_video(video, "sana-i2v.mp4", fps=16)
+```
+
+
+
+
+
+## Quantization
+
+Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
+
+Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaVideoPipeline`] for inference with bitsandbytes.
+
+```py
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaVideoTransformer3DModel, SanaVideoPipeline
+from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel
+
+quant_config = BitsAndBytesConfig(load_in_8bit=True)
+text_encoder_8bit = AutoModel.from_pretrained(
+ "Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
+ subfolder="text_encoder",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
+transformer_8bit = SanaVideoTransformer3DModel.from_pretrained(
+ "Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
+ subfolder="transformer",
+ quantization_config=quant_config,
+ torch_dtype=torch.float16,
+)
+
+pipeline = SanaVideoPipeline.from_pretrained(
+ "Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
+ text_encoder=text_encoder_8bit,
+ transformer=transformer_8bit,
+ torch_dtype=torch.float16,
+ device_map="balanced",
+)
+
+model_score = 30
+prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional."
+negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
+motion_prompt = f" motion score: {model_score}."
+prompt = prompt + motion_prompt
+
+output = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=480,
+ width=832,
+ num_frames=81,
+ guidance_scale=6.0,
+ num_inference_steps=50
+).frames[0]
+export_to_video(output, "sana-video-output.mp4", fps=16)
+```
+
+## SanaVideoPipeline
+
+[[autodoc]] SanaVideoPipeline
+ - all
+ - __call__
+
+
+## SanaImageToVideoPipeline
+
+[[autodoc]] SanaImageToVideoPipeline
+ - all
+ - __call__
+
+
+## SanaVideoPipelineOutput
+
+[[autodoc]] pipelines.sana_video.pipeline_sana_video.SanaVideoPipelineOutput
diff --git a/docs/source/en/api/pipelines/self_attention_guidance.md b/docs/source/en/api/pipelines/self_attention_guidance.md
index d656ce93f104..8d411598ae6d 100644
--- a/docs/source/en/api/pipelines/self_attention_guidance.md
+++ b/docs/source/en/api/pipelines/self_attention_guidance.md
@@ -1,4 +1,4 @@
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Self-Attention Guidance
[Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://huggingface.co/papers/2210.00939) is by Susung Hong et al.
@@ -20,11 +23,8 @@ The abstract from the paper is:
You can find additional information about Self-Attention Guidance on the [project page](https://ku-cvlab.github.io/Self-Attention-Guidance), [original codebase](https://github.com/KU-CVLAB/Self-Attention-Guidance), and try it out in a [demo](https://huggingface.co/spaces/susunghong/Self-Attention-Guidance) or [notebook](https://colab.research.google.com/github/SusungHong/Self-Attention-Guidance/blob/main/SAG_Stable.ipynb).
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## StableDiffusionSAGPipeline
[[autodoc]] StableDiffusionSAGPipeline
diff --git a/docs/source/en/api/pipelines/semantic_stable_diffusion.md b/docs/source/en/api/pipelines/semantic_stable_diffusion.md
index b9aacd3518d8..dda428e80f8f 100644
--- a/docs/source/en/api/pipelines/semantic_stable_diffusion.md
+++ b/docs/source/en/api/pipelines/semantic_stable_diffusion.md
@@ -1,4 +1,4 @@
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Semantic Guidance
Semantic Guidance for Diffusion Models was proposed in [SEGA: Instructing Text-to-Image Models using Semantic Guidance](https://huggingface.co/papers/2301.12247) and provides strong semantic control over image generation.
@@ -19,11 +22,8 @@ The abstract from the paper is:
*Text-to-image diffusion models have recently received a lot of interest for their astonishing ability to produce high-fidelity images from text only. However, achieving one-shot generation that aligns with the user's intent is nearly impossible, yet small changes to the input prompt often result in very different images. This leaves the user with little semantic control. To put the user in control, we show how to interact with the diffusion process to flexibly steer it along semantic directions. This semantic guidance (SEGA) generalizes to any generative architecture using classifier-free guidance. More importantly, it allows for subtle and extensive edits, changes in composition and style, as well as optimizing the overall artistic conception. We demonstrate SEGA's effectiveness on both latent and pixel-based diffusion models such as Stable Diffusion, Paella, and DeepFloyd-IF using a variety of tasks, thus providing strong evidence for its versatility, flexibility, and improvements over existing methods.*
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## SemanticStableDiffusionPipeline
[[autodoc]] SemanticStableDiffusionPipeline
diff --git a/docs/source/en/api/pipelines/shap_e.md b/docs/source/en/api/pipelines/shap_e.md
index 3c1f939c1fce..3e505894ca80 100644
--- a/docs/source/en/api/pipelines/shap_e.md
+++ b/docs/source/en/api/pipelines/shap_e.md
@@ -1,4 +1,4 @@
-
+
+
+
+# SkyReels-V2: Infinite-length Film Generative model
+
+[SkyReels-V2](https://huggingface.co/papers/2504.13074) by the SkyReels Team from Skywork AI.
+
+*Recent advances in video generation have been driven by diffusion models and autoregressive frameworks, yet critical challenges persist in harmonizing prompt adherence, visual quality, motion dynamics, and duration: compromises in motion dynamics to enhance temporal visual quality, constrained video duration (5-10 seconds) to prioritize resolution, and inadequate shot-aware generation stemming from general-purpose MLLMs' inability to interpret cinematic grammar, such as shot composition, actor expressions, and camera motions. These intertwined limitations hinder realistic long-form synthesis and professional film-style generation. To address these limitations, we propose SkyReels-V2, an Infinite-length Film Generative Model, that synergizes Multi-modal Large Language Model (MLLM), Multi-stage Pretraining, Reinforcement Learning, and Diffusion Forcing Framework. Firstly, we design a comprehensive structural representation of video that combines the general descriptions by the Multi-modal LLM and the detailed shot language by sub-expert models. Aided with human annotation, we then train a unified Video Captioner, named SkyCaptioner-V1, to efficiently label the video data. Secondly, we establish progressive-resolution pretraining for the fundamental video generation, followed by a four-stage post-training enhancement: Initial concept-balanced Supervised Fine-Tuning (SFT) improves baseline quality; Motion-specific Reinforcement Learning (RL) training with human-annotated and synthetic distortion data addresses dynamic artifacts; Our diffusion forcing framework with non-decreasing noise schedules enables long-video synthesis in an efficient search space; Final high-quality SFT refines visual fidelity. All the code and models are available at [this https URL](https://github.com/SkyworkAI/SkyReels-V2).*
+
+You can find all the original SkyReels-V2 checkpoints under the [Skywork](https://huggingface.co/collections/Skywork/skyreels-v2-6801b1b93df627d441d0d0d9) organization.
+
+The following SkyReels-V2 models are supported in Diffusers:
+- [SkyReels-V2 DF 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers)
+- [SkyReels-V2 DF 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-DF-14B-540P-Diffusers)
+- [SkyReels-V2 DF 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-DF-14B-720P-Diffusers)
+- [SkyReels-V2 T2V 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-T2V-14B-540P-Diffusers)
+- [SkyReels-V2 T2V 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-T2V-14B-720P-Diffusers)
+- [SkyReels-V2 I2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers)
+- [SkyReels-V2 I2V 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-540P-Diffusers)
+- [SkyReels-V2 I2V 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-720P-Diffusers)
+- [SkyReels-V2 FLF2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-FLF2V-1.3B-540P-Diffusers)
+
+> [!TIP]
+> Click on the SkyReels-V2 models in the right sidebar for more examples of video generation.
+
+### A _Visual_ Demonstration
+
+The example below has the following parameters:
+
+- `base_num_frames=97`
+- `num_frames=97`
+- `num_inference_steps=30`
+- `ar_step=5`
+- `causal_block_size=5`
+
+With `vae_scale_factor_temporal=4`, expect `5` blocks of `5` frames each as calculated by:
+
+`num_latent_frames: (97-1)//vae_scale_factor_temporal+1 = 25 frames -> 5 blocks of 5 frames each`
+
+And the maximum context length in the latent space is calculated with `base_num_latent_frames`:
+
+`base_num_latent_frames = (97-1)//vae_scale_factor_temporal+1 = 25 -> 25//5 = 5 blocks`
+
+Asynchronous Processing Timeline:
+```text
+┌─────────────────────────────────────────────────────────────────┐
+│ Steps: 1 6 11 16 21 26 31 36 41 46 50 │
+│ Block 1: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
+│ Block 2: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
+│ Block 3: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
+│ Block 4: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
+│ Block 5: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
+└─────────────────────────────────────────────────────────────────┘
+```
+
+For Long Videos (`num_frames` > `base_num_frames`):
+`base_num_frames` acts as the "sliding window size" for processing long videos.
+
+Example: `257`-frame video with `base_num_frames=97`, `overlap_history=17`
+```text
+┌──── Iteration 1 (frames 1-97) ────┐
+│ Processing window: 97 frames │ → 5 blocks,
+│ Generates: frames 1-97 │ async processing
+└───────────────────────────────────┘
+ ┌────── Iteration 2 (frames 81-177) ──────┐
+ │ Processing window: 97 frames │
+ │ Overlap: 17 frames (81-97) from prev │ → 5 blocks,
+ │ Generates: frames 98-177 │ async processing
+ └─────────────────────────────────────────┘
+ ┌────── Iteration 3 (frames 161-257) ──────┐
+ │ Processing window: 97 frames │
+ │ Overlap: 17 frames (161-177) from prev │ → 5 blocks,
+ │ Generates: frames 178-257 │ async processing
+ └──────────────────────────────────────────┘
+```
+
+Each iteration independently runs the asynchronous processing with its own `5` blocks.
+`base_num_frames` controls:
+1. Memory usage (larger window = more VRAM)
+2. Model context length (must match training constraints)
+3. Number of blocks per iteration (`base_num_latent_frames // causal_block_size`)
+
+Each block takes `30` steps to complete denoising.
+Block N starts at step: `1 + (N-1) x ar_step`
+Total steps: `30 + (5-1) x 5 = 50` steps
+
+
+Synchronous mode (`ar_step=0`) would process all blocks/frames simultaneously:
+```text
+┌──────────────────────────────────────────────┐
+│ Steps: 1 ... 30 │
+│ All blocks: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
+└──────────────────────────────────────────────┘
+```
+Total steps: `30` steps
+
+
+An example on how the step matrix is constructed for asynchronous processing:
+Given the parameters: (`num_inference_steps=30, flow_shift=8, num_frames=97, ar_step=5, causal_block_size=5`)
+```
+- num_latent_frames = (97 frames - 1) // (4 temporal downsampling) + 1 = 25
+- step_template = [999, 995, 991, 986, 980, 975, 969, 963, 956, 948,
+ 941, 932, 922, 912, 901, 888, 874, 859, 841, 822,
+ 799, 773, 743, 708, 666, 615, 551, 470, 363, 216]
+```
+
+The algorithm creates a `50x25` `step_matrix` where:
+```
+- Row 1: [999×5, 999×5, 999×5, 999×5, 999×5]
+- Row 2: [995×5, 999×5, 999×5, 999×5, 999×5]
+- Row 3: [991×5, 999×5, 999×5, 999×5, 999×5]
+- ...
+- Row 7: [969×5, 995×5, 999×5, 999×5, 999×5]
+- ...
+- Row 21: [799×5, 888×5, 941×5, 975×5, 999×5]
+- ...
+- Row 35: [ 0×5, 216×5, 666×5, 822×5, 901×5]
+- ...
+- Row 42: [ 0×5, 0×5, 0×5, 551×5, 773×5]
+- ...
+- Row 50: [ 0×5, 0×5, 0×5, 0×5, 216×5]
+```
+
+Detailed Row `6` Analysis:
+```
+- step_matrix[5]: [ 975×5, 999×5, 999×5, 999×5, 999×5]
+- step_index[5]: [ 6×5, 1×5, 0×5, 0×5, 0×5]
+- step_update_mask[5]: [True×5, True×5, False×5, False×5, False×5]
+- valid_interval[5]: (0, 25)
+```
+
+Key Pattern: Block `i` lags behind Block `i-1` by exactly `ar_step=5` timesteps, creating the
+staggered "diffusion forcing" effect where later blocks condition on cleaner earlier blocks.
+
+
+### Text-to-Video Generation
+
+The example below demonstrates how to generate a video from text.
+
+
+
+
+Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.
+
+From the original repo:
+>You can use --ar_step 5 to enable asynchronous inference. When asynchronous inference, --causal_block_size 5 is recommended while it is not supposed to be set for synchronous generation... Asynchronous inference will take more steps to diffuse the whole sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous inference may improve the instruction following and visual consistent performance.
+
+```py
+import torch
+from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline, UniPCMultistepScheduler
+from diffusers.utils import export_to_video
+
+
+model_id = "Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers"
+vae = AutoModel.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+
+pipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained(
+ model_id,
+ vae=vae,
+ torch_dtype=torch.bfloat16,
+)
+pipeline.to("cuda")
+flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V
+pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)
+
+prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+
+output = pipeline(
+ prompt=prompt,
+ num_inference_steps=30,
+ height=544, # 720 for 720P
+ width=960, # 1280 for 720P
+ num_frames=97,
+ base_num_frames=97, # 121 for 720P
+ ar_step=5, # Controls asynchronous inference (0 for synchronous mode)
+ causal_block_size=5, # Number of frames in each block for asynchronous processing
+ overlap_history=None, # Number of frames to overlap for smooth transitions in long videos; 17 for long video generations
+ addnoise_condition=20, # Improves consistency in long video generation
+).frames[0]
+export_to_video(output, "video.mp4", fps=24, quality=8)
+```
+
+
+
+
+### First-Last-Frame-to-Video Generation
+
+The example below demonstrates how to use the image-to-video pipeline to generate a video using a text description, a starting frame, and an ending frame.
+
+
+
+
+```python
+import numpy as np
+import torch
+import torchvision.transforms.functional as TF
+from diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingImageToVideoPipeline, UniPCMultistepScheduler
+from diffusers.utils import export_to_video, load_image
+
+
+model_id = "Skywork/SkyReels-V2-DF-1.3B-720P-Diffusers"
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+pipeline = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained(
+ model_id, vae=vae, torch_dtype=torch.bfloat16
+)
+pipeline.to("cuda")
+flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V
+pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)
+
+first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png")
+last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png")
+
+def aspect_ratio_resize(image, pipeline, max_area=720 * 1280):
+ aspect_ratio = image.height / image.width
+ mod_value = pipeline.vae_scale_factor_spatial * pipeline.transformer.config.patch_size[1]
+ height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+ width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+ image = image.resize((width, height))
+ return image, height, width
+
+def center_crop_resize(image, height, width):
+ # Calculate resize ratio to match first frame dimensions
+ resize_ratio = max(width / image.width, height / image.height)
+
+ # Resize the image
+ width = round(image.width * resize_ratio)
+ height = round(image.height * resize_ratio)
+ size = [width, height]
+ image = TF.center_crop(image, size)
+
+ return image, height, width
+
+first_frame, height, width = aspect_ratio_resize(first_frame, pipeline)
+if last_frame.size != first_frame.size:
+ last_frame, _, _ = center_crop_resize(last_frame, height, width)
+
+prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
+
+output = pipeline(
+ image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.0
+).frames[0]
+export_to_video(output, "video.mp4", fps=24, quality=8)
+```
+
+
+
+
+
+### Video-to-Video Generation
+
+
+
+
+`SkyReelsV2DiffusionForcingVideoToVideoPipeline` extends a given video.
+
+```python
+import numpy as np
+import torch
+import torchvision.transforms.functional as TF
+from diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingVideoToVideoPipeline, UniPCMultistepScheduler
+from diffusers.utils import export_to_video, load_video
+
+
+model_id = "Skywork/SkyReels-V2-DF-1.3B-720P-Diffusers"
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+pipeline = SkyReelsV2DiffusionForcingVideoToVideoPipeline.from_pretrained(
+ model_id, vae=vae, torch_dtype=torch.bfloat16
+)
+pipeline.to("cuda")
+flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V
+pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)
+
+video = load_video("input_video.mp4")
+
+prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
+
+output = pipeline(
+ video=video, prompt=prompt, height=720, width=1280, guidance_scale=5.0, overlap_history=17,
+ num_inference_steps=30, num_frames=257, base_num_frames=121#, ar_step=5, causal_block_size=5,
+).frames[0]
+export_to_video(output, "video.mp4", fps=24, quality=8)
+# Total frames will be the number of frames of the given video + 257
+```
+
+
+
+
+## Notes
+
+- SkyReels-V2 supports LoRAs with [`~loaders.SkyReelsV2LoraLoaderMixin.load_lora_weights`].
+
+`SkyReelsV2Pipeline` and `SkyReelsV2ImageToVideoPipeline` are also available without Diffusion Forcing framework applied.
+
+
+## SkyReelsV2DiffusionForcingPipeline
+
+[[autodoc]] SkyReelsV2DiffusionForcingPipeline
+ - all
+ - __call__
+
+## SkyReelsV2DiffusionForcingImageToVideoPipeline
+
+[[autodoc]] SkyReelsV2DiffusionForcingImageToVideoPipeline
+ - all
+ - __call__
+
+## SkyReelsV2DiffusionForcingVideoToVideoPipeline
+
+[[autodoc]] SkyReelsV2DiffusionForcingVideoToVideoPipeline
+ - all
+ - __call__
+
+## SkyReelsV2Pipeline
+
+[[autodoc]] SkyReelsV2Pipeline
+ - all
+ - __call__
+
+## SkyReelsV2ImageToVideoPipeline
+
+[[autodoc]] SkyReelsV2ImageToVideoPipeline
+ - all
+ - __call__
+
+## SkyReelsV2PipelineOutput
+
+[[autodoc]] pipelines.skyreels_v2.pipeline_output.SkyReelsV2PipelineOutput
diff --git a/docs/source/en/api/pipelines/stable_audio.md b/docs/source/en/api/pipelines/stable_audio.md
index 1acb72b3968a..82763a52a942 100644
--- a/docs/source/en/api/pipelines/stable_audio.md
+++ b/docs/source/en/api/pipelines/stable_audio.md
@@ -1,4 +1,4 @@
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# GLIGEN (Grounded Language-to-Image Generation)
The GLIGEN model was created by researchers and engineers from [University of Wisconsin-Madison, Columbia University, and Microsoft](https://github.com/gligen/GLIGEN). The [`StableDiffusionGLIGENPipeline`] and [`StableDiffusionGLIGENTextImagePipeline`] can generate photorealistic images conditioned on grounding inputs. Along with text and bounding boxes with [`StableDiffusionGLIGENPipeline`], if input images are given, [`StableDiffusionGLIGENTextImagePipeline`] can insert objects described by text at the region defined by bounding boxes. Otherwise, it'll generate an image described by the caption/prompt and insert objects described by text at the region defined by bounding boxes. It's trained on COCO2014D and COCO2014CD datasets, and the model uses a frozen CLIP ViT-L/14 text encoder to condition itself on grounding inputs.
@@ -18,13 +21,10 @@ The abstract from the [paper](https://huggingface.co/papers/2301.07093) is:
*Large-scale text-to-image diffusion models have made amazing advances. However, the status quo is to use text input alone, which can impede controllability. In this work, we propose GLIGEN, Grounded-Language-to-Image Generation, a novel approach that builds upon and extends the functionality of existing pre-trained text-to-image diffusion models by enabling them to also be conditioned on grounding inputs. To preserve the vast concept knowledge of the pre-trained model, we freeze all of its weights and inject the grounding information into new trainable layers via a gated mechanism. Our model achieves open-world grounded text2img generation with caption and bounding box condition inputs, and the grounding ability generalizes well to novel spatial configurations and concepts. GLIGEN’s zeroshot performance on COCO and LVIS outperforms existing supervised layout-to-image baselines by a large margin.*
-
-
-Make sure to check out the Stable Diffusion [Tips](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality and how to reuse pipeline components efficiently!
-
-If you want to use one of the official checkpoints for a task, explore the [gligen](https://huggingface.co/gligen) Hub organizations!
-
-
+> [!TIP]
+> Make sure to check out the Stable Diffusion [Tips](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality and how to reuse pipeline components efficiently!
+>
+> If you want to use one of the official checkpoints for a task, explore the [gligen](https://huggingface.co/gligen) Hub organizations!
[`StableDiffusionGLIGENPipeline`] was contributed by [Nikhil Gajendrakumar](https://github.com/nikhil-masterful) and [`StableDiffusionGLIGENTextImagePipeline`] was contributed by [Nguyễn Công Tú Anh](https://github.com/tuanh123789).
diff --git a/docs/source/en/api/pipelines/stable_diffusion/image_variation.md b/docs/source/en/api/pipelines/stable_diffusion/image_variation.md
index 57dd2f0d5b39..b1b7146b336f 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/image_variation.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/image_variation.md
@@ -1,4 +1,4 @@
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# K-Diffusion
[k-diffusion](https://github.com/crowsonkb/k-diffusion) is a popular library created by [Katherine Crowson](https://github.com/crowsonkb/). We provide `StableDiffusionKDiffusionPipeline` and `StableDiffusionXLKDiffusionPipeline` that allow you to run Stable DIffusion with samplers from k-diffusion.
diff --git a/docs/source/en/api/pipelines/stable_diffusion/latent_upscale.md b/docs/source/en/api/pipelines/stable_diffusion/latent_upscale.md
index 9abccd6e1347..19eae9a9ce44 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/latent_upscale.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/latent_upscale.md
@@ -1,4 +1,4 @@
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Text-to-(RGB, depth)
@@ -19,7 +22,7 @@ specific language governing permissions and limitations under the License.
LDM3D was proposed in [LDM3D: Latent Diffusion Model for 3D](https://huggingface.co/papers/2305.10853) by Gabriela Ben Melech Stan, Diana Wofk, Scottie Fox, Alex Redden, Will Saxton, Jean Yu, Estelle Aflalo, Shao-Yen Tseng, Fabio Nonato, Matthias Muller, and Vasudev Lal. LDM3D generates an image and a depth map from a given text prompt unlike the existing text-to-image diffusion models such as [Stable Diffusion](./overview) which only generates an image. With almost the same number of parameters, LDM3D achieves to create a latent space that can compress both the RGB images and the depth maps.
Two checkpoints are available for use:
-- [ldm3d-original](https://huggingface.co/Intel/ldm3d). The original checkpoint used in the [paper](https://arxiv.org/pdf/2305.10853.pdf)
+- [ldm3d-original](https://huggingface.co/Intel/ldm3d). The original checkpoint used in the [paper](https://huggingface.co/papers/2305.10853)
- [ldm3d-4c](https://huggingface.co/Intel/ldm3d-4c). The new version of LDM3D using 4 channels inputs instead of 6-channels inputs and finetuned on higher resolution images.
@@ -27,11 +30,8 @@ The abstract from the paper is:
*This research paper proposes a Latent Diffusion Model for 3D (LDM3D) that generates both image and depth map data from a given text prompt, allowing users to generate RGBD images from text prompts. The LDM3D model is fine-tuned on a dataset of tuples containing an RGB image, depth map and caption, and validated through extensive experiments. We also develop an application called DepthFusion, which uses the generated RGB images and depth maps to create immersive and interactive 360-degree-view experiences using TouchDesigner. This technology has the potential to transform a wide range of industries, from entertainment and gaming to architecture and design. Overall, this paper presents a significant contribution to the field of generative AI and computer vision, and showcases the potential of LDM3D and DepthFusion to revolutionize content creation and digital experiences. A short video summarizing the approach can be found at [this url](https://t.ly/tdi2).*
-
-
-Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!
-
-
+> [!TIP]
+> Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!
## StableDiffusionLDM3DPipeline
@@ -48,7 +48,7 @@ Make sure to check out the Stable Diffusion [Tips](overview#tips) section to lea
# Upscaler
-[LDM3D-VR](https://arxiv.org/pdf/2311.03226.pdf) is an extended version of LDM3D.
+[LDM3D-VR](https://huggingface.co/papers/2311.03226) is an extended version of LDM3D.
The abstract from the paper is:
*Latent diffusion models have proven to be state-of-the-art in the creation and manipulation of visual outputs. However, as far as we know, the generation of depth maps jointly with RGB is still limited. We introduce LDM3D-VR, a suite of diffusion models targeting virtual reality development that includes LDM3D-pano and LDM3D-SR. These models enable the generation of panoramic RGBD based on textual prompts and the upscaling of low-resolution inputs to high-resolution RGBD, respectively. Our models are fine-tuned from existing pretrained models on datasets containing panoramic/high-resolution RGB images, depth maps and captions. Both models are evaluated in comparison to existing related methods*
diff --git a/docs/source/en/api/pipelines/stable_diffusion/overview.md b/docs/source/en/api/pipelines/stable_diffusion/overview.md
index 25984091215c..2d2de39c91a8 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/overview.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/overview.md
@@ -1,4 +1,4 @@
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Safe Stable Diffusion
Safe Stable Diffusion was proposed in [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://huggingface.co/papers/2211.05105) and mitigates inappropriate degeneration from Stable Diffusion models because they're trained on unfiltered web-crawled datasets. For instance Stable Diffusion may unexpectedly generate nudity, violence, images depicting self-harm, and otherwise offensive content. Safe Stable Diffusion is an extension of Stable Diffusion that drastically reduces this type of content.
@@ -42,11 +45,8 @@ There are 4 configurations (`SafetyConfig.WEAK`, `SafetyConfig.MEDIUM`, `SafetyC
>>> out = pipeline(prompt=prompt, **SafetyConfig.MAX)
```
-
-
-Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!
-
-
+> [!TIP]
+> Make sure to check out the Stable Diffusion [Tips](overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality, and how to reuse pipeline components efficiently!
## StableDiffusionPipelineSafe
diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md
index 485ee7d7fc28..6863d408b5fd 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md
@@ -1,4 +1,4 @@
-
-
-
-🧪 This pipeline is for research purposes only.
-
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# Text-to-video
@@ -22,7 +19,7 @@ specific language governing permissions and limitations under the License.
-[ModelScope Text-to-Video Technical Report](https://arxiv.org/abs/2308.06571) is by Jiuniu Wang, Hangjie Yuan, Dayou Chen, Yingya Zhang, Xiang Wang, Shiwei Zhang.
+[ModelScope Text-to-Video Technical Report](https://huggingface.co/papers/2308.06571) is by Jiuniu Wang, Hangjie Yuan, Dayou Chen, Yingya Zhang, Xiang Wang, Shiwei Zhang.
The abstract from the paper is:
@@ -175,13 +172,10 @@ Here are some sample outputs:
Video generation is memory-intensive and one way to reduce your memory usage is to set `enable_forward_chunking` on the pipeline's UNet so you don't run the entire feedforward layer at once. Breaking it up into chunks in a loop is more efficient.
-Check out the [Text or image-to-video](text-img2vid) guide for more details about how certain parameters can affect video generation and how to optimize inference by reducing memory usage.
-
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+Check out the [Text or image-to-video](../../using-diffusers/text-img2vid) guide for more details about how certain parameters can affect video generation and how to optimize inference by reducing memory usage.
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## TextToVideoSDPipeline
[[autodoc]] TextToVideoSDPipeline
diff --git a/docs/source/en/api/pipelines/text_to_video_zero.md b/docs/source/en/api/pipelines/text_to_video_zero.md
index 44d9a6670af4..50e7620760f3 100644
--- a/docs/source/en/api/pipelines/text_to_video_zero.md
+++ b/docs/source/en/api/pipelines/text_to_video_zero.md
@@ -1,4 +1,4 @@
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Text2Video-Zero
@@ -34,7 +37,7 @@ Our key modifications include (i) enriching the latent codes of the generated fr
Experiments show that this leads to low overhead, yet high-quality and remarkably consistent video generation. Moreover, our approach is not limited to text-to-video synthesis but is also applicable to other tasks such as conditional and content-specialized video generation, and Video Instruct-Pix2Pix, i.e., instruction-guided video editing.
As experiments show, our method performs comparably or sometimes better than recent approaches, despite not being trained on additional video data.*
-You can find additional information about Text2Video-Zero on the [project page](https://text2video-zero.github.io/), [paper](https://arxiv.org/abs/2303.13439), and [original codebase](https://github.com/Picsart-AI-Research/Text2Video-Zero).
+You can find additional information about Text2Video-Zero on the [project page](https://text2video-zero.github.io/), [paper](https://huggingface.co/papers/2303.13439), and [original codebase](https://github.com/Picsart-AI-Research/Text2Video-Zero).
## Usage example
@@ -55,9 +58,9 @@ result = [(r * 255).astype("uint8") for r in result]
imageio.mimsave("video.mp4", result, fps=4)
```
You can change these parameters in the pipeline call:
-* Motion field strength (see the [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1):
+* Motion field strength (see the [paper](https://huggingface.co/papers/2303.13439), Sect. 3.3.1):
* `motion_field_strength_x` and `motion_field_strength_y`. Default: `motion_field_strength_x=12`, `motion_field_strength_y=12`
-* `T` and `T'` (see the [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1)
+* `T` and `T'` (see the [paper](https://huggingface.co/papers/2303.13439), Sect. 3.3.1)
* `t0` and `t1` in the range `{0, ..., num_inference_steps}`. Default: `t0=45`, `t1=48`
* Video length:
* `video_length`, the number of frames video_length to be generated. Default: `video_length=8`
@@ -286,11 +289,8 @@ can run with custom [DreamBooth](../../training/dreambooth) models, as shown bel
You can filter out some available DreamBooth-trained models with [this link](https://huggingface.co/models?search=dreambooth).
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## TextToVideoZeroPipeline
[[autodoc]] TextToVideoZeroPipeline
diff --git a/docs/source/en/api/pipelines/unclip.md b/docs/source/en/api/pipelines/unclip.md
index 943cebdb28a2..7c5c2b0d9ab9 100644
--- a/docs/source/en/api/pipelines/unclip.md
+++ b/docs/source/en/api/pipelines/unclip.md
@@ -1,4 +1,4 @@
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# unCLIP
[Hierarchical Text-Conditional Image Generation with CLIP Latents](https://huggingface.co/papers/2204.06125) is by Aditya Ramesh, Prafulla Dhariwal, Alex Nichol, Casey Chu, Mark Chen. The unCLIP model in 🤗 Diffusers comes from kakaobrain's [karlo](https://github.com/kakaobrain/karlo).
@@ -17,11 +20,8 @@ The abstract from the paper is following:
You can find lucidrains' DALL-E 2 recreation at [lucidrains/DALLE2-pytorch](https://github.com/lucidrains/DALLE2-pytorch).
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## UnCLIPPipeline
[[autodoc]] UnCLIPPipeline
diff --git a/docs/source/en/api/pipelines/unidiffuser.md b/docs/source/en/api/pipelines/unidiffuser.md
index 802aefea6be5..2ff700e4b8be 100644
--- a/docs/source/en/api/pipelines/unidiffuser.md
+++ b/docs/source/en/api/pipelines/unidiffuser.md
@@ -1,4 +1,4 @@
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# UniDiffuser
@@ -24,11 +27,8 @@ The abstract from the paper is:
You can find the original codebase at [thu-ml/unidiffuser](https://github.com/thu-ml/unidiffuser) and additional checkpoints at [thu-ml](https://huggingface.co/thu-ml).
-
-
-There is currently an issue on PyTorch 1.X where the output images are all black or the pixel values become `NaNs`. This issue can be mitigated by switching to PyTorch 2.X.
-
-
+> [!WARNING]
+> There is currently an issue on PyTorch 1.X where the output images are all black or the pixel values become `NaNs`. This issue can be mitigated by switching to PyTorch 2.X.
This pipeline was contributed by [dg845](https://github.com/dg845). ❤️
@@ -194,11 +194,8 @@ final_prompt = sample.text[0]
print(final_prompt)
```
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## UniDiffuserPipeline
[[autodoc]] UniDiffuserPipeline
diff --git a/docs/source/en/api/pipelines/value_guided_sampling.md b/docs/source/en/api/pipelines/value_guided_sampling.md
index 5aaee9090cef..d050ea309ca5 100644
--- a/docs/source/en/api/pipelines/value_guided_sampling.md
+++ b/docs/source/en/api/pipelines/value_guided_sampling.md
@@ -1,4 +1,4 @@
-
+
+# VisualCloze
+
+[VisualCloze: A Universal Image Generation Framework via Visual In-Context Learning](https://huggingface.co/papers/2504.07960) is an innovative in-context learning based universal image generation framework that offers key capabilities:
+1. Support for various in-domain tasks
+2. Generalization to unseen tasks through in-context learning
+3. Unify multiple tasks into one step and generate both target image and intermediate results
+4. Support reverse-engineering conditions from target images
+
+## Overview
+
+The abstract from the paper is:
+
+*Recent progress in diffusion models significantly advances various image generation tasks. However, the current mainstream approach remains focused on building task-specific models, which have limited efficiency when supporting a wide range of different needs. While universal models attempt to address this limitation, they face critical challenges, including generalizable task instruction, appropriate task distributions, and unified architectural design. To tackle these challenges, we propose VisualCloze, a universal image generation framework, which supports a wide range of in-domain tasks, generalization to unseen ones, unseen unification of multiple tasks, and reverse generation. Unlike existing methods that rely on language-based task instruction, leading to task ambiguity and weak generalization, we integrate visual in-context learning, allowing models to identify tasks from visual demonstrations. Meanwhile, the inherent sparsity of visual task distributions hampers the learning of transferable knowledge across tasks. To this end, we introduce Graph200K, a graph-structured dataset that establishes various interrelated tasks, enhancing task density and transferable knowledge. Furthermore, we uncover that our unified image generation formulation shared a consistent objective with image infilling, enabling us to leverage the strong generative priors of pre-trained infilling models without modifying the architectures. The codes, dataset, and models are available at https://visualcloze.github.io.*
+
+## Inference
+
+### Model loading
+
+VisualCloze is a two-stage cascade pipeline, containing `VisualClozeGenerationPipeline` and `VisualClozeUpsamplingPipeline`.
+- In `VisualClozeGenerationPipeline`, each image is downsampled before concatenating images into a grid layout, avoiding excessively high resolutions. VisualCloze releases two models suitable for diffusers, i.e., [VisualClozePipeline-384](https://huggingface.co/VisualCloze/VisualClozePipeline-384) and [VisualClozePipeline-512](https://huggingface.co/VisualCloze/VisualClozePipeline-384), which downsample images to resolutions of 384 and 512, respectively.
+- `VisualClozeUpsamplingPipeline` uses [SDEdit](https://huggingface.co/papers/2108.01073) to enable high-resolution image synthesis.
+
+The `VisualClozePipeline` integrates both stages to support convenient end-to-end sampling, while also allowing users to utilize each pipeline independently as needed.
+
+### Input Specifications
+
+#### Task and Content Prompts
+- Task prompt: Required to describe the generation task intention
+- Content prompt: Optional description or caption of the target image
+- When content prompt is not needed, pass `None`
+- For batch inference, pass `List[str|None]`
+
+#### Image Input Format
+- Format: `List[List[Image|None]]`
+- Structure:
+ - All rows except the last represent in-context examples
+ - Last row represents the current query (target image set to `None`)
+- For batch inference, pass `List[List[List[Image|None]]]`
+
+#### Resolution Control
+- Default behavior:
+ - Initial generation in the first stage: area of ${pipe.resolution}^2$
+ - Upsampling in the second stage: 3x factor
+- Custom resolution: Adjust using `upsampling_height` and `upsampling_width` parameters
+
+### Examples
+
+For comprehensive examples covering a wide range of tasks, please refer to the [Online Demo](https://huggingface.co/spaces/VisualCloze/VisualCloze) and [GitHub Repository](https://github.com/lzyhha/VisualCloze). Below are simple examples for three cases: mask-to-image conversion, edge detection, and subject-driven generation.
+
+#### Example for mask2image
+
+```python
+import torch
+from diffusers import VisualClozePipeline
+from diffusers.utils import load_image
+
+pipe = VisualClozePipeline.from_pretrained("VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+# Load in-context images (make sure the paths are correct and accessible)
+image_paths = [
+ # in-context examples
+ [
+ load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg'),
+ load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg'),
+ ],
+ # query with the target image
+ [
+ load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg'),
+ None, # No image needed for the target image
+ ],
+]
+
+# Task and content prompt
+task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding."
+content_prompt = """Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape.
+The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible.
+Its plumage is a mix of dark brown and golden hues, with intricate feather details.
+The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere.
+The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field,
+soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background,
+tranquil, majestic, wildlife photography."""
+
+# Run the pipeline
+image_result = pipe(
+ task_prompt=task_prompt,
+ content_prompt=content_prompt,
+ image=image_paths,
+ upsampling_width=1344,
+ upsampling_height=768,
+ upsampling_strength=0.4,
+ guidance_scale=30,
+ num_inference_steps=30,
+ max_sequence_length=512,
+ generator=torch.Generator("cpu").manual_seed(0)
+).images[0][0]
+
+# Save the resulting image
+image_result.save("visualcloze.png")
+```
+
+#### Example for edge-detection
+
+```python
+import torch
+from diffusers import VisualClozePipeline
+from diffusers.utils import load_image
+
+pipe = VisualClozePipeline.from_pretrained("VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+# Load in-context images (make sure the paths are correct and accessible)
+image_paths = [
+ # in-context examples
+ [
+ load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-1_image.jpg'),
+ load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-1_edge.jpg'),
+ ],
+ [
+ load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-2_image.jpg'),
+ load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-2_edge.jpg'),
+ ],
+ # query with the target image
+ [
+ load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_query_image.jpg'),
+ None, # No image needed for the target image
+ ],
+]
+
+# Task and content prompt
+task_prompt = "Each row illustrates a pathway from [IMAGE1] a sharp and beautifully composed photograph to [IMAGE2] edge map with natural well-connected outlines using a clear logical task."
+content_prompt = ""
+
+# Run the pipeline
+image_result = pipe(
+ task_prompt=task_prompt,
+ content_prompt=content_prompt,
+ image=image_paths,
+ upsampling_width=864,
+ upsampling_height=1152,
+ upsampling_strength=0.4,
+ guidance_scale=30,
+ num_inference_steps=30,
+ max_sequence_length=512,
+ generator=torch.Generator("cpu").manual_seed(0)
+).images[0][0]
+
+# Save the resulting image
+image_result.save("visualcloze.png")
+```
+
+#### Example for subject-driven generation
+
+```python
+import torch
+from diffusers import VisualClozePipeline
+from diffusers.utils import load_image
+
+pipe = VisualClozePipeline.from_pretrained("VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+# Load in-context images (make sure the paths are correct and accessible)
+image_paths = [
+ # in-context examples
+ [
+ load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-1_reference.jpg'),
+ load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-1_depth.jpg'),
+ load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-1_image.jpg'),
+ ],
+ [
+ load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-2_reference.jpg'),
+ load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-2_depth.jpg'),
+ load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-2_image.jpg'),
+ ],
+ # query with the target image
+ [
+ load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_query_reference.jpg'),
+ load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_query_depth.jpg'),
+ None, # No image needed for the target image
+ ],
+]
+
+# Task and content prompt
+task_prompt = """Each row describes a process that begins with [IMAGE1] an image containing the key object,
+[IMAGE2] depth map revealing gray-toned spatial layers and results in
+[IMAGE3] an image with artistic qualitya high-quality image with exceptional detail."""
+content_prompt = """A vintage porcelain collector's item. Beneath a blossoming cherry tree in early spring,
+this treasure is photographed up close, with soft pink petals drifting through the air and vibrant blossoms framing the scene."""
+
+# Run the pipeline
+image_result = pipe(
+ task_prompt=task_prompt,
+ content_prompt=content_prompt,
+ image=image_paths,
+ upsampling_width=1024,
+ upsampling_height=1024,
+ upsampling_strength=0.2,
+ guidance_scale=30,
+ num_inference_steps=30,
+ max_sequence_length=512,
+ generator=torch.Generator("cpu").manual_seed(0)
+).images[0][0]
+
+# Save the resulting image
+image_result.save("visualcloze.png")
+```
+
+#### Utilize each pipeline independently
+
+```python
+import torch
+from diffusers import VisualClozeGenerationPipeline, FluxFillPipeline as VisualClozeUpsamplingPipeline
+from diffusers.utils import load_image
+from PIL import Image
+
+pipe = VisualClozeGenerationPipeline.from_pretrained(
+ "VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+image_paths = [
+ # in-context examples
+ [
+ load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg"
+ ),
+ load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg"
+ ),
+ ],
+ # query with the target image
+ [
+ load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg"
+ ),
+ None, # No image needed for the target image
+ ],
+]
+task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding."
+content_prompt = "Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape. The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible. Its plumage is a mix of dark brown and golden hues, with intricate feather details. The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere. The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field, soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background, tranquil, majestic, wildlife photography."
+
+# Stage 1: Generate initial image
+image = pipe(
+ task_prompt=task_prompt,
+ content_prompt=content_prompt,
+ image=image_paths,
+ guidance_scale=30,
+ num_inference_steps=30,
+ max_sequence_length=512,
+ generator=torch.Generator("cpu").manual_seed(0),
+).images[0][0]
+
+# Stage 2 (optional): Upsample the generated image
+pipe_upsample = VisualClozeUpsamplingPipeline.from_pipe(pipe)
+pipe_upsample.to("cuda")
+
+mask_image = Image.new("RGB", image.size, (255, 255, 255))
+
+image = pipe_upsample(
+ image=image,
+ mask_image=mask_image,
+ prompt=content_prompt,
+ width=1344,
+ height=768,
+ strength=0.4,
+ guidance_scale=30,
+ num_inference_steps=30,
+ max_sequence_length=512,
+ generator=torch.Generator("cpu").manual_seed(0),
+).images[0]
+
+image.save("visualcloze.png")
+```
+
+## VisualClozePipeline
+
+[[autodoc]] VisualClozePipeline
+ - all
+ - __call__
+
+## VisualClozeGenerationPipeline
+
+[[autodoc]] VisualClozeGenerationPipeline
+ - all
+ - __call__
diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md
index cb856fe0acfc..6aab6c5b33b9 100644
--- a/docs/source/en/api/pipelines/wan.md
+++ b/docs/source/en/api/pipelines/wan.md
@@ -1,4 +1,4 @@
-
+
+
# Wan
-
-
-
+[Wan-2.1](https://huggingface.co/papers/2503.20314) by the Wan Team.
-[Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.
+*This report presents Wan, a comprehensive and open suite of video foundation models designed to push the boundaries of video generation. Built upon the mainstream diffusion transformer paradigm, Wan achieves significant advancements in generative capabilities through a series of innovations, including our novel VAE, scalable pre-training strategies, large-scale data curation, and automated evaluation metrics. These contributions collectively enhance the model's performance and versatility. Specifically, Wan is characterized by four key features: Leading Performance: The 14B model of Wan, trained on a vast dataset comprising billions of images and videos, demonstrates the scaling laws of video generation with respect to both data and model size. It consistently outperforms the existing open-source models as well as state-of-the-art commercial solutions across multiple internal and external benchmarks, demonstrating a clear and significant performance superiority. Comprehensiveness: Wan offers two capable models, i.e., 1.3B and 14B parameters, for efficiency and effectiveness respectively. It also covers multiple downstream applications, including image-to-video, instruction-guided video editing, and personal video generation, encompassing up to eight tasks. Consumer-Grade Efficiency: The 1.3B model demonstrates exceptional resource efficiency, requiring only 8.19 GB VRAM, making it compatible with a wide range of consumer-grade GPUs. Openness: We open-source the entire series of Wan, including source code and all models, with the goal of fostering the growth of the video generation community. This openness seeks to significantly expand the creative possibilities of video production in the industry and provide academia with high-quality video foundation models. All the code and models are available at [this https URL](https://github.com/Wan-Video/Wan2.1).*
-
+You can find all the original Wan2.1 checkpoints under the [Wan-AI](https://huggingface.co/Wan-AI) organization.
-## Generating Videos with Wan 2.1
+The following Wan models are supported in Diffusers:
-We will first need to install some addtional dependencies.
+- [Wan 2.1 T2V 1.3B](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers)
+- [Wan 2.1 T2V 14B](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B-Diffusers)
+- [Wan 2.1 I2V 14B - 480P](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers)
+- [Wan 2.1 I2V 14B - 720P](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P-Diffusers)
+- [Wan 2.1 FLF2V 14B - 720P](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers)
+- [Wan 2.1 VACE 1.3B](https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B-diffusers)
+- [Wan 2.1 VACE 14B](https://huggingface.co/Wan-AI/Wan2.1-VACE-14B-diffusers)
+- [Wan 2.2 T2V 14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers)
+- [Wan 2.2 I2V 14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers)
+- [Wan 2.2 TI2V 5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers)
+- [Wan 2.2 Animate 14B](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B-Diffusers)
-```shell
-pip install -u ftfy imageio-ffmpeg imageio
-```
+> [!TIP]
+> Click on the Wan models in the right sidebar for more examples of video generation.
-### Text to Video Generation
+### Text-to-Video Generation
-The following example requires 11GB VRAM to run and uses the smaller `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` model. You can switch it out
-for the larger `Wan2.1-I2V-14B-720P-Diffusers` or `Wan-AI/Wan2.1-I2V-14B-480P-Diffusers` if you have at least 35GB VRAM available.
+The example below demonstrates how to generate a video from text optimized for memory or inference speed.
-```python
-from diffusers import WanPipeline
-from diffusers.utils import export_to_video
+
+
-# Available models: Wan-AI/Wan2.1-I2V-14B-720P-Diffusers or Wan-AI/Wan2.1-I2V-14B-480P-Diffusers
-model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
+Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.
-pipe = WanPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
-pipe.enable_model_cpu_offload()
+The Wan2.1 text-to-video model below requires ~13GB of VRAM.
-prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
-negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
-num_frames = 33
+```py
+# pip install ftfy
+import torch
+import numpy as np
+from diffusers import AutoModel, WanPipeline
+from diffusers.quantizers import PipelineQuantizationConfig
+from diffusers.hooks.group_offloading import apply_group_offloading
+from diffusers.utils import export_to_video, load_image
+from transformers import UMT5EncoderModel
+
+text_encoder = UMT5EncoderModel.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="text_encoder", torch_dtype=torch.bfloat16)
+vae = AutoModel.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="vae", torch_dtype=torch.float32)
+transformer = AutoModel.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
+
+# group-offloading
+onload_device = torch.device("cuda")
+offload_device = torch.device("cpu")
+apply_group_offloading(text_encoder,
+ onload_device=onload_device,
+ offload_device=offload_device,
+ offload_type="block_level",
+ num_blocks_per_group=4
+)
+transformer.enable_group_offload(
+ onload_device=onload_device,
+ offload_device=offload_device,
+ offload_type="leaf_level",
+ use_stream=True
+)
-frames = pipe(prompt=prompt, negative_prompt=negative_prompt, num_frames=num_frames).frames[0]
-export_to_video(frames, "wan-t2v.mp4", fps=16)
+pipeline = WanPipeline.from_pretrained(
+ "Wan-AI/Wan2.1-T2V-14B-Diffusers",
+ vae=vae,
+ transformer=transformer,
+ text_encoder=text_encoder,
+ torch_dtype=torch.bfloat16
+)
+pipeline.to("cuda")
+
+prompt = """
+The camera rushes from far to near in a low-angle shot,
+revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
+for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
+Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
+shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
+"""
+negative_prompt = """
+Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
+low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
+misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
+"""
+
+output = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_frames=81,
+ guidance_scale=5.0,
+).frames[0]
+export_to_video(output, "output.mp4", fps=16)
```
-
-You can improve the quality of the generated video by running the decoding step in full precision.
-
+
+
-```python
-from diffusers import WanPipeline, AutoencoderKLWan
-from diffusers.utils import export_to_video
+[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
-model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
+```py
+# pip install ftfy
+import torch
+import numpy as np
+from diffusers import AutoModel, WanPipeline
+from diffusers.hooks.group_offloading import apply_group_offloading
+from diffusers.utils import export_to_video, load_image
+from transformers import UMT5EncoderModel
-vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
-pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+text_encoder = UMT5EncoderModel.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="text_encoder", torch_dtype=torch.bfloat16)
+vae = AutoModel.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="vae", torch_dtype=torch.float32)
+transformer = AutoModel.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
-# replace this with pipe.to("cuda") if you have sufficient VRAM
-pipe.enable_model_cpu_offload()
+pipeline = WanPipeline.from_pretrained(
+ "Wan-AI/Wan2.1-T2V-14B-Diffusers",
+ vae=vae,
+ transformer=transformer,
+ text_encoder=text_encoder,
+ torch_dtype=torch.bfloat16
+)
+pipeline.to("cuda")
-prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
-negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
-num_frames = 33
+# torch.compile
+pipeline.transformer.to(memory_format=torch.channels_last)
+pipeline.transformer = torch.compile(
+ pipeline.transformer, mode="max-autotune", fullgraph=True
+)
-frames = pipe(prompt=prompt, num_frames=num_frames).frames[0]
-export_to_video(frames, "wan-t2v.mp4", fps=16)
+prompt = """
+The camera rushes from far to near in a low-angle shot,
+revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
+for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
+Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
+shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
+"""
+negative_prompt = """
+Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
+low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
+misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
+"""
+
+output = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_frames=81,
+ guidance_scale=5.0,
+).frames[0]
+export_to_video(output, "output.mp4", fps=16)
```
-### Image to Video Generation
+
+
-The Image to Video pipeline requires loading the `AutoencoderKLWan` and the `CLIPVisionModel` components in full precision. The following example will need at least
-35GB of VRAM to run.
+### First-Last-Frame-to-Video Generation
+
+The example below demonstrates how to use the image-to-video pipeline to generate a video using a text description, a starting frame, and an ending frame.
+
+
+
```python
-import torch
import numpy as np
+import torch
+import torchvision.transforms.functional as TF
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
-# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
-model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
-image_encoder = CLIPVisionModel.from_pretrained(
- model_id, subfolder="image_encoder", torch_dtype=torch.float32
-)
+
+model_id = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
+image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(
model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)
+pipe.to("cuda")
-# replace this with pipe.to("cuda") if you have sufficient VRAM
-pipe.enable_model_cpu_offload()
+first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png")
+last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png")
-image = load_image(
- "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
-)
+def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
+ aspect_ratio = image.height / image.width
+ mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+ height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+ width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+ image = image.resize((width, height))
+ return image, height, width
-max_area = 480 * 832
-aspect_ratio = image.height / image.width
-mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
-height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
-width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
-image = image.resize((width, height))
+def center_crop_resize(image, height, width):
+ # Calculate resize ratio to match first frame dimensions
+ resize_ratio = max(width / image.width, height / image.height)
-prompt = (
- "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
- "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
-)
-negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+ # Resize the image
+ width = round(image.width * resize_ratio)
+ height = round(image.height * resize_ratio)
+ size = [width, height]
+ image = TF.center_crop(image, size)
+
+ return image, height, width
-num_frames = 33
+first_frame, height, width = aspect_ratio_resize(first_frame, pipe)
+if last_frame.size != first_frame.size:
+ last_frame, _, _ = center_crop_resize(last_frame, height, width)
+
+prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
output = pipe(
- image=image,
- prompt=prompt,
- negative_prompt=negative_prompt,
- height=height,
- width=width,
- num_frames=num_frames,
- guidance_scale=5.0,
+ image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.5
).frames[0]
-export_to_video(output, "wan-i2v.mp4", fps=16)
+export_to_video(output, "output.mp4", fps=16)
```
-### Video to Video Generation
+
+
-```python
-import torch
-from diffusers.utils import load_video, export_to_video
-from diffusers import AutoencoderKLWan, WanVideoToVideoPipeline, UniPCMultistepScheduler
+### Any-to-Video Controllable Generation
-# Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
-model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
-vae = AutoencoderKLWan.from_pretrained(
- model_id, subfolder="vae", torch_dtype=torch.float32
-)
-pipe = WanVideoToVideoPipeline.from_pretrained(
- model_id, vae=vae, torch_dtype=torch.bfloat16
-)
-flow_shift = 3.0 # 5.0 for 720P, 3.0 for 480P
-pipe.scheduler = UniPCMultistepScheduler.from_config(
- pipe.scheduler.config, flow_shift=flow_shift
-)
-# change to pipe.to("cuda") if you have sufficient VRAM
-pipe.enable_model_cpu_offload()
+Wan VACE supports various generation techniques which achieve controllable video generation. Some of the capabilities include:
+- Control to Video (Depth, Pose, Sketch, Flow, Grayscale, Scribble, Layout, Boundary Box, etc.). Recommended library for preprocessing videos to obtain control videos: [huggingface/controlnet_aux]()
+- Image/Video to Video (first frame, last frame, starting clip, ending clip, random clips)
+- Inpainting and Outpainting
+- Subject to Video (faces, object, characters, etc.)
+- Composition to Video (reference anything, animate anything, swap anything, expand anything, move anything, etc.)
-prompt = "A robot standing on a mountain top. The sun is setting in the background"
-negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
-video = load_video(
- "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
-)
-output = pipe(
- video=video,
- prompt=prompt,
- negative_prompt=negative_prompt,
- height=480,
- width=512,
- guidance_scale=7.0,
- strength=0.7,
-).frames[0]
+The code snippets available in [this](https://github.com/huggingface/diffusers/pull/11582) pull request demonstrate some examples of how videos can be generated with controllability signals.
-export_to_video(output, "wan-v2v.mp4", fps=16)
-```
+The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color.
-## Memory Optimizations for Wan 2.1
+
+
-Base inference with the large 14B Wan 2.1 models can take up to 35GB of VRAM when generating videos at 720p resolution. We'll outline a few memory optimizations we can apply to reduce the VRAM required to run the model.
+### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication
-We'll use `Wan-AI/Wan2.1-I2V-14B-720P-Diffusers` model in these examples to demonstrate the memory savings, but the techniques are applicable to all model checkpoints.
+[Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team.
-### Group Offloading the Transformer and UMT5 Text Encoder
+*We introduce Wan-Animate, a unified framework for character animation and replacement. Given a character image and a reference video, Wan-Animate can animate the character by precisely replicating the expressions and movements of the character in the video to generate high-fidelity character videos. Alternatively, it can integrate the animated character into the reference video to replace the original character, replicating the scene's lighting and color tone to achieve seamless environmental integration. Wan-Animate is built upon the Wan model. To adapt it for character animation tasks, we employ a modified input paradigm to differentiate between reference conditions and regions for generation. This design unifies multiple tasks into a common symbolic representation. We use spatially-aligned skeleton signals to replicate body motion and implicit facial features extracted from source images to reenact expressions, enabling the generation of character videos with high controllability and expressiveness. Furthermore, to enhance environmental integration during character replacement, we develop an auxiliary Relighting LoRA. This module preserves the character's appearance consistency while applying the appropriate environmental lighting and color tone. Experimental results demonstrate that Wan-Animate achieves state-of-the-art performance. We are committed to open-sourcing the model weights and its source code.*
-Find more information about group offloading [here](../optimization/memory.md)
+The project page: https://humanaigc.github.io/wan-animate
-#### Block Level Group Offloading
+This model was mostly contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz).
-We can reduce our VRAM requirements by applying group offloading to the larger model components of the pipeline; the `WanTransformer3DModel` and `UMT5EncoderModel`. Group offloading will break up the individual modules of a model and offload/onload them onto your GPU as needed during inference. In this example, we'll apply `block_level` offloading, which will group the modules in a model into blocks of size `num_blocks_per_group` and offload/onload them to GPU. Moving to between CPU and GPU does add latency to the inference process. You can trade off between latency and memory savings by increasing or decreasing the `num_blocks_per_group`.
+#### Usage
-The following example will now only require 14GB of VRAM to run, but will take approximately 30 minutes to generate a video.
+The Wan-Animate pipeline supports two modes of operation:
-```python
-import torch
-import numpy as np
-from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
-from diffusers.hooks.group_offloading import apply_group_offloading
-from diffusers.utils import export_to_video, load_image
-from transformers import UMT5EncoderModel, CLIPVisionModel
+1. **Animation Mode** (default): Animates a character image based on motion and expression from reference videos
+2. **Replacement Mode**: Replaces a character in a background video with a new character while preserving the scene
-# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
-model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
-image_encoder = CLIPVisionModel.from_pretrained(
- model_id, subfolder="image_encoder", torch_dtype=torch.float32
-)
+##### Prerequisites
-text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
-vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
-transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
+Before using the pipeline, you need to preprocess your reference video to extract:
+- **Pose video**: Contains skeletal keypoints representing body motion
+- **Face video**: Contains facial feature representations for expression control
-onload_device = torch.device("cuda")
-offload_device = torch.device("cpu")
+For replacement mode, you additionally need:
+- **Background video**: The original video containing the scene
+- **Mask video**: A mask indicating where to generate content (white) vs. preserve original (black)
-apply_group_offloading(text_encoder,
- onload_device=onload_device,
- offload_device=offload_device,
- offload_type="block_level",
- num_blocks_per_group=4
-)
+> [!NOTE]
+> Raw videos should not be used for inputs such as `pose_video`, which the pipeline expects to be preprocessed to extract the proper information. Preprocessing scripts to prepare these inputs are available in the [original Wan-Animate repository](https://github.com/Wan-Video/Wan2.2?tab=readme-ov-file#1-preprocessing). Integration of these preprocessing steps into Diffusers is planned for a future release.
-transformer.enable_group_offload(
- onload_device=onload_device,
- offload_device=offload_device,
- offload_type="block_level",
- num_blocks_per_group=4,
-)
-pipe = WanImageToVideoPipeline.from_pretrained(
- model_id,
- vae=vae,
- transformer=transformer,
- text_encoder=text_encoder,
- image_encoder=image_encoder,
- torch_dtype=torch.bfloat16
-)
-# Since we've offloaded the larger models alrady, we can move the rest of the model components to GPU
+The example below demonstrates how to use the Wan-Animate pipeline:
+
+
+
+
+```python
+import numpy as np
+import torch
+from diffusers import AutoencoderKLWan, WanAnimatePipeline
+from diffusers.utils import export_to_video, load_image, load_video
+
+model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")
-image = load_image(
- "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
-)
+# Load character image and preprocessed videos
+image = load_image("path/to/character.jpg")
+pose_video = load_video("path/to/pose_video.mp4") # Preprocessed skeletal keypoints
+face_video = load_video("path/to/face_video.mp4") # Preprocessed facial features
-max_area = 720 * 832
-aspect_ratio = image.height / image.width
-mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
-height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
-width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
-image = image.resize((width, height))
+# Resize image to match VAE constraints
+def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
+ aspect_ratio = image.height / image.width
+ mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+ height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+ width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+ image = image.resize((width, height))
+ return image, height, width
-prompt = (
- "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
- "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
-)
-negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+image, height, width = aspect_ratio_resize(image, pipe)
-num_frames = 33
+prompt = "A person dancing energetically in a studio with dynamic lighting and professional camera work"
+negative_prompt = "blurry, low quality, distorted, deformed, static, poorly drawn"
+# Generate animated video
output = pipe(
image=image,
+ pose_video=pose_video,
+ face_video=face_video,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
- num_frames=num_frames,
- guidance_scale=5.0,
+ segment_frame_length=77,
+ guidance_scale=1.0,
+ mode="animate", # Animation mode (default)
).frames[0]
-
-export_to_video(output, "wan-i2v.mp4", fps=16)
+export_to_video(output, "animated_character.mp4", fps=30)
```
-#### Block Level Group Offloading with CUDA Streams
-
-We can speed up group offloading inference, by enabling the use of [CUDA streams](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html). However, using CUDA streams requires moving the model parameters into pinned memory. This allocation is handled by Pytorch under the hood, and can result in a significant spike in CPU RAM usage. Please consider this option if your CPU RAM is atleast 2X the size of the model you are group offloading.
-
-In the following example we will use CUDA streams when group offloading the `WanTransformer3DModel`. When testing on an A100, this example will require 14GB of VRAM, 52GB of CPU RAM, but will generate a video in approximately 9 minutes.
+
+
```python
-import torch
import numpy as np
-from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
-from diffusers.hooks.group_offloading import apply_group_offloading
-from diffusers.utils import export_to_video, load_image
-from transformers import UMT5EncoderModel, CLIPVisionModel
-
-# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
-model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
-image_encoder = CLIPVisionModel.from_pretrained(
- model_id, subfolder="image_encoder", torch_dtype=torch.float32
-)
+import torch
+from diffusers import AutoencoderKLWan, WanAnimatePipeline
+from diffusers.utils import export_to_video, load_image, load_video
-text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
+model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
-transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
-
-onload_device = torch.device("cuda")
-offload_device = torch.device("cpu")
-
-apply_group_offloading(text_encoder,
- onload_device=onload_device,
- offload_device=offload_device,
- offload_type="block_level",
- num_blocks_per_group=4
-)
-
-transformer.enable_group_offload(
- onload_device=onload_device,
- offload_device=offload_device,
- offload_type="leaf_level",
- use_stream=True
-)
-pipe = WanImageToVideoPipeline.from_pretrained(
- model_id,
- vae=vae,
- transformer=transformer,
- text_encoder=text_encoder,
- image_encoder=image_encoder,
- torch_dtype=torch.bfloat16
-)
-# Since we've offloaded the larger models alrady, we can move the rest of the model components to GPU
+pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")
-image = load_image(
- "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
-)
+# Load all required inputs for replacement mode
+image = load_image("path/to/new_character.jpg")
+pose_video = load_video("path/to/pose_video.mp4") # Preprocessed skeletal keypoints
+face_video = load_video("path/to/face_video.mp4") # Preprocessed facial features
+background_video = load_video("path/to/background_video.mp4") # Original scene
+mask_video = load_video("path/to/mask_video.mp4") # Black: preserve, White: generate
-max_area = 720 * 832
-aspect_ratio = image.height / image.width
-mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
-height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
-width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
-image = image.resize((width, height))
+# Resize image to match video dimensions
+def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
+ aspect_ratio = image.height / image.width
+ mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+ height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+ width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+ image = image.resize((width, height))
+ return image, height, width
-prompt = (
- "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
- "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
-)
-negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+image, height, width = aspect_ratio_resize(image, pipe)
-num_frames = 33
+prompt = "A person seamlessly integrated into the scene with consistent lighting and environment"
+negative_prompt = "blurry, low quality, inconsistent lighting, floating, disconnected from scene"
+# Replace character in background video
output = pipe(
image=image,
+ pose_video=pose_video,
+ face_video=face_video,
+ background_video=background_video,
+ mask_video=mask_video,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
- num_frames=num_frames,
- guidance_scale=5.0,
+ segment_frame_lengths=77,
+ guidance_scale=1.0,
+ mode="replace", # Replacement mode
).frames[0]
-
-export_to_video(output, "wan-i2v.mp4", fps=16)
+export_to_video(output, "character_replaced.mp4", fps=30)
```
-### Applying Layerwise Casting to the Transformer
-
-Find more information about layerwise casting [here](../optimization/memory.md)
-
-In this example, we will model offloading with layerwise casting. Layerwise casting will downcast each layer's weights to `torch.float8_e4m3fn`, temporarily upcast to `torch.bfloat16` during the forward pass of the layer, then revert to `torch.float8_e4m3fn` afterward. This approach reduces memory requirements by approximately 50% while introducing a minor quality reduction in the generated video due to the precision trade-off.
-
-This example will require 20GB of VRAM.
+
+
```python
-import torch
import numpy as np
-from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
-from diffusers.hooks.group_offloading import apply_group_offloading
-from diffusers.utils import export_to_video, load_image
-from transformers import UMT5EncoderModel, CLIPVisionModel
+import torch
+from diffusers import AutoencoderKLWan, WanAnimatePipeline
+from diffusers.utils import export_to_video, load_image, load_video
-model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
-image_encoder = CLIPVisionModel.from_pretrained(
- model_id, subfolder="image_encoder", torch_dtype=torch.float32
-)
-text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
+model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+pipe.to("cuda")
-transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
-transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
+image = load_image("path/to/character.jpg")
+pose_video = load_video("path/to/pose_video.mp4")
+face_video = load_video("path/to/face_video.mp4")
-pipe = WanImageToVideoPipeline.from_pretrained(
- model_id,
- vae=vae,
- transformer=transformer,
- text_encoder=text_encoder,
- image_encoder=image_encoder,
- torch_dtype=torch.bfloat16
-)
-pipe.enable_model_cpu_offload()
-image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg")
-
-max_area = 720 * 832
-aspect_ratio = image.height / image.width
-mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
-height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
-width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
-image = image.resize((width, height))
-prompt = (
- "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
- "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
-)
-negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
-num_frames = 33
+def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
+ aspect_ratio = image.height / image.width
+ mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+ height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+ width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+ image = image.resize((width, height))
+ return image, height, width
+
+image, height, width = aspect_ratio_resize(image, pipe)
+
+prompt = "A person dancing energetically in a studio"
+negative_prompt = "blurry, low quality"
+
+# Advanced: Use temporal guidance and custom callback
+def callback_fn(pipe, step_index, timestep, callback_kwargs):
+ # You can modify latents or other tensors here
+ print(f"Step {step_index}, Timestep {timestep}")
+ return callback_kwargs
output = pipe(
image=image,
+ pose_video=pose_video,
+ face_video=face_video,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
- num_frames=num_frames,
+ segment_frame_length=77,
num_inference_steps=50,
guidance_scale=5.0,
+ prev_segment_conditioning_frames=5, # Use 5 frames for temporal guidance (1 or 5 recommended)
+ callback_on_step_end=callback_fn,
+ callback_on_step_end_tensor_inputs=["latents"],
).frames[0]
-export_to_video(output, "wan-i2v.mp4", fps=16)
+export_to_video(output, "animated_advanced.mp4", fps=30)
```
-## Using a Custom Scheduler
+
+
-Wan can be used with many different schedulers, each with their own benefits regarding speed and generation quality. By default, Wan uses the `UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)` scheduler. You can use a different scheduler as follows:
+#### Key Parameters
-```python
-from diffusers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler, WanPipeline
+- **mode**: Choose between `"animate"` (default) or `"replace"`
+- **prev_segment_conditioning_frames**: Number of frames for temporal guidance (1 or 5 recommended). Using 5 provides better temporal consistency but requires more memory
+- **guidance_scale**: Controls how closely the output follows the text prompt. Higher values (5-7) produce results more aligned with the prompt. For Wan-Animate, CFG is disabled by default (`guidance_scale=1.0`) but can be enabled to support negative prompts and finer control over facial expressions. (Note that CFG will only target the text prompt and face conditioning.)
-scheduler_a = FlowMatchEulerDiscreteScheduler(shift=5.0)
-scheduler_b = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=4.0)
-pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", scheduler=
)
+## Notes
-# or,
-pipe.scheduler =
-```
+- Wan2.1 supports LoRAs with [`~loaders.WanLoraLoaderMixin.load_lora_weights`].
-## Using Single File Loading with Wan 2.1
+
+ Show example code
-The `WanTransformer3DModel` and `AutoencoderKLWan` models support loading checkpoints in their original format via the `from_single_file` loading
-method.
+ ```py
+ # pip install ftfy
+ import torch
+ from diffusers import AutoModel, WanPipeline
+ from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
+ from diffusers.utils import export_to_video
-```python
-import torch
-from diffusers import WanPipeline, WanTransformer3DModel
+ vae = AutoModel.from_pretrained(
+ "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
+ )
+ pipeline = WanPipeline.from_pretrained(
+ "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", vae=vae, torch_dtype=torch.bfloat16
+ )
+ pipeline.scheduler = UniPCMultistepScheduler.from_config(
+ pipeline.scheduler.config, flow_shift=5.0
+ )
+ pipeline.to("cuda")
-ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors"
-transformer = WanTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
+ pipeline.load_lora_weights("benjamin-paine/steamboat-willie-1.3b", adapter_name="steamboat-willie")
+ pipeline.set_adapters("steamboat-willie")
-pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer)
-```
+ pipeline.enable_model_cpu_offload()
-## Recommendations for Inference
-- Keep `AutencoderKLWan` in `torch.float32` for better decoding quality.
-- `num_frames` should satisfy the following constraint: `(num_frames - 1) % 4 == 0`
-- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan.
+ # use "steamboat willie style" to trigger the LoRA
+ prompt = """
+ steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
+ revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
+ for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
+ Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
+ shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
+ """
+
+ output = pipeline(
+ prompt=prompt,
+ num_frames=81,
+ guidance_scale=5.0,
+ ).frames[0]
+ export_to_video(output, "output.mp4", fps=16)
+ ```
+
+
+
+- [`WanTransformer3DModel`] and [`AutoencoderKLWan`] supports loading from single files with [`~loaders.FromSingleFileMixin.from_single_file`].
+
+
+ Show example code
+
+ ```py
+ # pip install ftfy
+ import torch
+ from diffusers import WanPipeline, WanTransformer3DModel, AutoencoderKLWan
+
+ vae = AutoencoderKLWan.from_single_file(
+ "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors"
+ )
+ transformer = WanTransformer3DModel.from_single_file(
+ "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors",
+ torch_dtype=torch.bfloat16
+ )
+ pipeline = WanPipeline.from_pretrained(
+ "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
+ vae=vae,
+ transformer=transformer,
+ torch_dtype=torch.bfloat16
+ )
+ ```
+
+
+
+- Set the [`AutoencoderKLWan`] dtype to `torch.float32` for better decoding quality.
+
+- The number of frames per second (fps) or `k` should be calculated by `4 * k + 1`.
+
+- Try lower `shift` values (`2.0` to `5.0`) for lower resolution videos and higher `shift` values (`7.0` to `12.0`) for higher resolution images.
+
+- Wan 2.1 and 2.2 support using [LightX2V LoRAs](https://huggingface.co/Kijai/WanVideo_comfy/tree/main/Lightx2v) to speed up inference. Using them on Wan 2.2 is slightly more involed. Refer to [this code snippet](https://github.com/huggingface/diffusers/pull/12040#issuecomment-3144185272) to learn more.
+
+- Wan 2.2 has two denoisers. By default, LoRAs are only loaded into the first denoiser. One can set `load_into_transformer_2=True` to load LoRAs into the second denoiser. Refer to [this](https://github.com/huggingface/diffusers/pull/12074#issue-3292620048) and [this](https://github.com/huggingface/diffusers/pull/12074#issuecomment-3155896144) examples to learn more.
## WanPipeline
@@ -460,6 +550,24 @@ pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transform
- all
- __call__
+## WanVACEPipeline
+
+[[autodoc]] WanVACEPipeline
+ - all
+ - __call__
+
+## WanVideoToVideoPipeline
+
+[[autodoc]] WanVideoToVideoPipeline
+ - all
+ - __call__
+
+## WanAnimatePipeline
+
+[[autodoc]] WanAnimatePipeline
+ - all
+ - __call__
+
## WanPipelineOutput
[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
diff --git a/docs/source/en/api/pipelines/wuerstchen.md b/docs/source/en/api/pipelines/wuerstchen.md
index da6ef2cffc28..2be3631d8456 100644
--- a/docs/source/en/api/pipelines/wuerstchen.md
+++ b/docs/source/en/api/pipelines/wuerstchen.md
@@ -1,4 +1,4 @@
-
+
+# Z-Image
+
+
+
+
+
+[Z-Image](https://huggingface.co/papers/2511.22699) is a powerful and highly efficient image generation model with 6B parameters. Currently there's only one model with two more to be released:
+
+|Model|Hugging Face|
+|---|---|
+|Z-Image-Turbo|https://huggingface.co/Tongyi-MAI/Z-Image-Turbo|
+
+## Z-Image-Turbo
+
+Z-Image-Turbo is a distilled version of Z-Image that matches or exceeds leading competitors with only 8 NFEs (Number of Function Evaluations). It offers sub-second inference latency on enterprise-grade H800 GPUs and fits comfortably within 16G VRAM consumer devices. It excels in photorealistic image generation, bilingual text rendering (English & Chinese), and robust instruction adherence.
+
+## Image-to-image
+
+Use [`ZImageImg2ImgPipeline`] to transform an existing image based on a text prompt.
+
+```python
+import torch
+from diffusers import ZImageImg2ImgPipeline
+from diffusers.utils import load_image
+
+pipe = ZImageImg2ImgPipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+init_image = load_image(url).resize((1024, 1024))
+
+prompt = "A fantasy landscape with mountains and a river, detailed, vibrant colors"
+image = pipe(
+ prompt,
+ image=init_image,
+ strength=0.6,
+ num_inference_steps=9,
+ guidance_scale=0.0,
+ generator=torch.Generator("cuda").manual_seed(42),
+).images[0]
+image.save("zimage_img2img.png")
+```
+
+## ZImagePipeline
+
+[[autodoc]] ZImagePipeline
+ - all
+ - __call__
+
+## ZImageImg2ImgPipeline
+
+[[autodoc]] ZImageImg2ImgPipeline
+ - all
+ - __call__
diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md
index 2c728cff3c07..7fa7c7c9d016 100644
--- a/docs/source/en/api/quantization.md
+++ b/docs/source/en/api/quantization.md
@@ -1,4 +1,4 @@
-
+
+# AutoPipelineBlocks
+
+[`~modular_pipelines.AutoPipelineBlocks`] are a multi-block type containing blocks that support different workflows. It automatically selects which sub-blocks to run based on the input provided at runtime. This is typically used to package multiple workflows - text-to-image, image-to-image, inpaint - into a single pipeline for convenience.
+
+This guide shows how to create [`~modular_pipelines.AutoPipelineBlocks`].
+
+Create three [`~modular_pipelines.ModularPipelineBlocks`] for text-to-image, image-to-image, and inpainting. These represent the different workflows available in the pipeline.
+
+
+
+
+```py
+import torch
+from diffusers.modular_pipelines import ModularPipelineBlocks, InputParam, OutputParam
+
+class TextToImageBlock(ModularPipelineBlocks):
+ model_name = "text2img"
+
+ @property
+ def inputs(self):
+ return [InputParam(name="prompt")]
+
+ @property
+ def intermediate_outputs(self):
+ return []
+
+ @property
+ def description(self):
+ return "I'm a text-to-image workflow!"
+
+ def __call__(self, components, state):
+ block_state = self.get_block_state(state)
+ print("running the text-to-image workflow")
+ # Add your text-to-image logic here
+ # For example: generate image from prompt
+ self.set_block_state(state, block_state)
+ return components, state
+```
+
+
+
+
+
+```py
+class ImageToImageBlock(ModularPipelineBlocks):
+ model_name = "img2img"
+
+ @property
+ def inputs(self):
+ return [InputParam(name="prompt"), InputParam(name="image")]
+
+ @property
+ def intermediate_outputs(self):
+ return []
+
+ @property
+ def description(self):
+ return "I'm an image-to-image workflow!"
+
+ def __call__(self, components, state):
+ block_state = self.get_block_state(state)
+ print("running the image-to-image workflow")
+ # Add your image-to-image logic here
+ # For example: transform input image based on prompt
+ self.set_block_state(state, block_state)
+ return components, state
+```
+
+
+
+
+
+```py
+class InpaintBlock(ModularPipelineBlocks):
+ model_name = "inpaint"
+
+ @property
+ def inputs(self):
+ return [InputParam(name="prompt"), InputParam(name="image"), InputParam(name="mask")]
+
+ @property
+ def intermediate_outputs(self):
+ return []
+
+ @property
+ def description(self):
+ return "I'm an inpaint workflow!"
+
+ def __call__(self, components, state):
+ block_state = self.get_block_state(state)
+ print("running the inpaint workflow")
+ # Add your inpainting logic here
+ # For example: fill masked areas based on prompt
+ self.set_block_state(state, block_state)
+ return components, state
+```
+
+
+
+
+Create an [`~modular_pipelines.AutoPipelineBlocks`] class that includes a list of the sub-block classes and their corresponding block names.
+
+You also need to include `block_trigger_inputs`, a list of input names that trigger the corresponding block. If a trigger input is provided at runtime, then that block is selected to run. Use `None` to specify the default block to run if no trigger inputs are detected.
+
+Lastly, it is important to include a `description` that clearly explains which inputs trigger which workflow. This helps users understand how to run specific workflows.
+
+```py
+from diffusers.modular_pipelines import AutoPipelineBlocks
+
+class AutoImageBlocks(AutoPipelineBlocks):
+ # List of sub-block classes to choose from
+ block_classes = [block_inpaint_cls, block_i2i_cls, block_t2i_cls]
+ # Names for each block in the same order
+ block_names = ["inpaint", "img2img", "text2img"]
+ # Trigger inputs that determine which block to run
+ # - "mask" triggers inpaint workflow
+ # - "image" triggers img2img workflow (but only if mask is not provided)
+ # - if none of above, runs the text2img workflow (default)
+ block_trigger_inputs = ["mask", "image", None]
+ # Description is extremely important for AutoPipelineBlocks
+
+ def description(self):
+ return (
+ "Pipeline generates images given different types of conditions!\n"
+ + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\n"
+ + " - inpaint workflow is run when `mask` is provided.\n"
+ + " - img2img workflow is run when `image` is provided (but only when `mask` is not provided).\n"
+ + " - text2img workflow is run when neither `image` nor `mask` is provided.\n"
+ )
+```
+
+It is **very** important to include a `description` to avoid any confusion over how to run a block and what inputs are required. While [`~modular_pipelines.AutoPipelineBlocks`] are convenient, it's conditional logic may be difficult to figure out if it isn't properly explained.
+
+Create an instance of `AutoImageBlocks`.
+
+```py
+auto_blocks = AutoImageBlocks()
+```
+
+For more complex compositions, such as nested [`~modular_pipelines.AutoPipelineBlocks`] blocks when they're used as sub-blocks in larger pipelines, use the [`~modular_pipelines.SequentialPipelineBlocks.get_execution_blocks`] method to extract the a block that is actually run based on your input.
+
+```py
+auto_blocks.get_execution_blocks("mask")
+```
\ No newline at end of file
diff --git a/docs/source/en/modular_diffusers/components_manager.md b/docs/source/en/modular_diffusers/components_manager.md
new file mode 100644
index 000000000000..af53411b9533
--- /dev/null
+++ b/docs/source/en/modular_diffusers/components_manager.md
@@ -0,0 +1,190 @@
+
+
+# ComponentsManager
+
+The [`ComponentsManager`] is a model registry and management system for Modular Diffusers. It adds and tracks models, stores useful metadata (model size, device placement, adapters), prevents duplicate model instances, and supports offloading.
+
+This guide will show you how to use [`ComponentsManager`] to manage components and device memory.
+
+## Add a component
+
+The [`ComponentsManager`] should be created alongside a [`ModularPipeline`] in either [`~ModularPipeline.from_pretrained`] or [`~ModularPipelineBlocks.init_pipeline`].
+
+> [!TIP]
+> The `collection` parameter is optional but makes it easier to organize and manage components.
+
+
+
+
+```py
+from diffusers import ModularPipeline, ComponentsManager
+
+comp = ComponentsManager()
+pipe = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test1")
+```
+
+
+
+
+```py
+from diffusers import ComponentsManager
+from diffusers.modular_pipelines import SequentialPipelineBlocks
+from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
+
+t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
+
+modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
+components = ComponentsManager()
+t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id, components_manager=components)
+```
+
+
+
+
+Components are only loaded and registered when using [`~ModularPipeline.load_components`] or [`~ModularPipeline.load_components`]. The example below uses [`~ModularPipeline.load_components`] to create a second pipeline that reuses all the components from the first one, and assigns it to a different collection
+
+```py
+pipe.load_components()
+pipe2 = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test2")
+```
+
+Use the [`~ModularPipeline.null_component_names`] property to identify any components that need to be loaded, retrieve them with [`~ComponentsManager.get_components_by_names`], and then call [`~ModularPipeline.update_components`] to add the missing components.
+
+```py
+pipe2.null_component_names
+['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'image_encoder', 'unet', 'vae', 'scheduler', 'controlnet']
+
+comp_dict = comp.get_components_by_names(names=pipe2.null_component_names)
+pipe2.update_components(**comp_dict)
+```
+
+To add individual components, use the [`~ComponentsManager.add`] method. This registers a component with a unique id.
+
+```py
+from diffusers import AutoModel
+
+text_encoder = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
+component_id = comp.add("text_encoder", text_encoder)
+comp
+```
+
+Use [`~ComponentsManager.remove`] to remove a component using their id.
+
+```py
+comp.remove("text_encoder_139917733042864")
+```
+
+## Retrieve a component
+
+The [`ComponentsManager`] provides several methods to retrieve registered components.
+
+### get_one
+
+The [`~ComponentsManager.get_one`] method returns a single component and supports pattern matching for the `name` parameter. If multiple components match, [`~ComponentsManager.get_one`] returns an error.
+
+| Pattern | Example | Description |
+|-------------|----------------------------------|-------------------------------------------|
+| exact | `comp.get_one(name="unet")` | exact name match |
+| wildcard | `comp.get_one(name="unet*")` | names starting with "unet" |
+| exclusion | `comp.get_one(name="!unet")` | exclude components named "unet" |
+| or | `comp.get_one(name="unet|vae")` | name is "unet" or "vae" |
+
+[`~ComponentsManager.get_one`] also filters components by the `collection` argument or `load_id` argument.
+
+```py
+comp.get_one(name="unet", collection="sdxl")
+```
+
+### get_components_by_names
+
+The [`~ComponentsManager.get_components_by_names`] method accepts a list of names and returns a dictionary mapping names to components. This is especially useful with [`ModularPipeline`] since they provide lists of required component names and the returned dictionary can be passed directly to [`~ModularPipeline.update_components`].
+
+```py
+component_dict = comp.get_components_by_names(names=["text_encoder", "unet", "vae"])
+{"text_encoder": component1, "unet": component2, "vae": component3}
+```
+
+## Duplicate detection
+
+It is recommended to load model components with [`ComponentSpec`] to assign components with a unique id that encodes their loading parameters. This allows [`ComponentsManager`] to automatically detect and prevent duplicate model instances even when different objects represent the same underlying checkpoint.
+
+```py
+from diffusers import ComponentSpec, ComponentsManager
+from transformers import CLIPTextModel
+
+comp = ComponentsManager()
+
+# Create ComponentSpec for the first text encoder
+spec = ComponentSpec(name="text_encoder", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=AutoModel)
+# Create ComponentSpec for a duplicate text encoder (it is same checkpoint, from the same repo/subfolder)
+spec_duplicated = ComponentSpec(name="text_encoder_duplicated", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=CLIPTextModel)
+
+# Load and add both components - the manager will detect they're the same model
+comp.add("text_encoder", spec.load())
+comp.add("text_encoder_duplicated", spec_duplicated.load())
+```
+
+This returns a warning with instructions for removing the duplicate.
+
+```py
+ComponentsManager: adding component 'text_encoder_duplicated_139917580682672', but it has duplicate load_id 'stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null' with existing components: text_encoder_139918506246832. To remove a duplicate, call `components_manager.remove('')`.
+'text_encoder_duplicated_139917580682672'
+```
+
+You could also add a component without using [`ComponentSpec`] and duplicate detection still works in most cases even if you're adding the same component under a different name.
+
+However, [`ComponentManager`] can't detect duplicates when you load the same component into different objects. In this case, you should load a model with [`ComponentSpec`].
+
+```py
+text_encoder_2 = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
+comp.add("text_encoder", text_encoder_2)
+'text_encoder_139917732983664'
+```
+
+## Collections
+
+Collections are labels assigned to components for better organization and management. Add a component to a collection with the `collection` argument in [`~ComponentsManager.add`].
+
+Only one component per name is allowed in each collection. Adding a second component with the same name automatically removes the first component.
+
+```py
+from diffusers import ComponentSpec, ComponentsManager
+
+comp = ComponentsManager()
+# Create ComponentSpec for the first UNet
+spec = ComponentSpec(name="unet", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", type_hint=AutoModel)
+# Create ComponentSpec for a different UNet
+spec2 = ComponentSpec(name="unet", repo="RunDiffusion/Juggernaut-XL-v9", subfolder="unet", type_hint=AutoModel, variant="fp16")
+
+# Add both UNets to the same collection - the second one will replace the first
+comp.add("unet", spec.load(), collection="sdxl")
+comp.add("unet", spec2.load(), collection="sdxl")
+```
+
+This makes it convenient to work with node-based systems because you can:
+
+- Mark all models as loaded from one node with the `collection` label.
+- Automatically replace models when new checkpoints are loaded under the same name.
+- Batch delete all models in a collection when a node is removed.
+
+## Offloading
+
+The [`~ComponentsManager.enable_auto_cpu_offload`] method is a global offloading strategy that works across all models regardless of which pipeline is using them. Once enabled, you don't need to worry about device placement if you add or remove components.
+
+```py
+comp.enable_auto_cpu_offload(device="cuda")
+```
+
+All models begin on the CPU and [`ComponentsManager`] moves them to the appropriate device right before they're needed, and moves other models back to the CPU when GPU memory is low.
+
+You can set your own rules for which models to offload first.
diff --git a/docs/source/en/modular_diffusers/custom_blocks.md b/docs/source/en/modular_diffusers/custom_blocks.md
new file mode 100644
index 000000000000..1c311582264e
--- /dev/null
+++ b/docs/source/en/modular_diffusers/custom_blocks.md
@@ -0,0 +1,492 @@
+
+
+
+# Building Custom Blocks
+
+[ModularPipelineBlocks](./pipeline_block) are the fundamental building blocks of a [`ModularPipeline`]. You can create custom blocks by defining their inputs, outputs, and computation logic. This guide demonstrates how to create and use a custom block.
+
+> [!TIP]
+> Explore the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for official custom modular blocks like Nano Banana.
+
+## Project Structure
+
+Your custom block project should use the following structure:
+
+```shell
+.
+├── block.py
+└── modular_config.json
+```
+
+- `block.py` contains the custom block implementation
+- `modular_config.json` contains the metadata needed to load the block
+
+## Example: Florence 2 Inpainting Block
+
+In this example we will create a custom block that uses the [Florence 2](https://huggingface.co/docs/transformers/model_doc/florence2) model to process an input image and generate a mask for inpainting.
+
+The first step is to define the components that the block will use. In this case, we will need to use the `Florence2ForConditionalGeneration` model and its corresponding processor `AutoProcessor`. When defining components, we must specify the name of the component within our pipeline, model class via `type_hint`, and provide a `pretrained_model_name_or_path` for the component if we intend to load the model weights from a specific repository on the Hub.
+
+```py
+# Inside block.py
+from diffusers.modular_pipelines import (
+ ModularPipelineBlocks,
+ ComponentSpec,
+)
+from transformers import AutoProcessor, Florence2ForConditionalGeneration
+
+
+class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
+
+ @property
+ def expected_components(self):
+ return [
+ ComponentSpec(
+ name="image_annotator",
+ type_hint=Florence2ForConditionalGeneration,
+ pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
+ ),
+ ComponentSpec(
+ name="image_annotator_processor",
+ type_hint=AutoProcessor,
+ pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
+ ),
+ ]
+```
+
+Next, we define the inputs and outputs of the block. The inputs include the image to be annotated, the annotation task, and the annotation prompt. The outputs include the generated mask image and annotations.
+
+```py
+from typing import List, Union
+from PIL import Image, ImageDraw
+import torch
+import numpy as np
+
+from diffusers.modular_pipelines import (
+ PipelineState,
+ ModularPipelineBlocks,
+ InputParam,
+ ComponentSpec,
+ OutputParam,
+)
+from transformers import AutoProcessor, Florence2ForConditionalGeneration
+
+
+class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
+
+ @property
+ def expected_components(self):
+ return [
+ ComponentSpec(
+ name="image_annotator",
+ type_hint=Florence2ForConditionalGeneration,
+ pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
+ ),
+ ComponentSpec(
+ name="image_annotator_processor",
+ type_hint=AutoProcessor,
+ pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "image",
+ type_hint=Union[Image.Image, List[Image.Image]],
+ required=True,
+ description="Image(s) to annotate",
+ ),
+ InputParam(
+ "annotation_task",
+ type_hint=Union[str, List[str]],
+ required=True,
+ default="",
+ description="""Annotation Task to perform on the image.
+ Supported Tasks:
+
+
+
+
+
+
+
+
+
+
+ """,
+ ),
+ InputParam(
+ "annotation_prompt",
+ type_hint=Union[str, List[str]],
+ required=True,
+ description="""Annotation Prompt to provide more context to the task.
+ Can be used to detect or segment out specific elements in the image
+ """,
+ ),
+ InputParam(
+ "annotation_output_type",
+ type_hint=str,
+ required=True,
+ default="mask_image",
+ description="""Output type from annotation predictions. Availabe options are
+ mask_image:
+ -black and white mask image for the given image based on the task type
+ mask_overlay:
+ - mask overlayed on the original image
+ bounding_box:
+ - bounding boxes drawn on the original image
+ """,
+ ),
+ InputParam(
+ "annotation_overlay",
+ type_hint=bool,
+ required=True,
+ default=False,
+ description="",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "mask_image",
+ type_hint=Image,
+ description="Inpainting Mask for input Image(s)",
+ ),
+ OutputParam(
+ "annotations",
+ type_hint=dict,
+ description="Annotations Predictions for input Image(s)",
+ ),
+ OutputParam(
+ "image",
+ type_hint=Image,
+ description="Annotated input Image(s)",
+ ),
+ ]
+
+```
+
+Now we implement the `__call__` method, which contains the logic for processing the input image and generating the mask.
+
+```py
+from typing import List, Union
+from PIL import Image, ImageDraw
+import torch
+import numpy as np
+
+from diffusers.modular_pipelines import (
+ PipelineState,
+ ModularPipelineBlocks,
+ InputParam,
+ ComponentSpec,
+ OutputParam,
+)
+from transformers import AutoProcessor, Florence2ForConditionalGeneration
+
+
+class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
+
+ @property
+ def expected_components(self):
+ return [
+ ComponentSpec(
+ name="image_annotator",
+ type_hint=Florence2ForConditionalGeneration,
+ pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
+ ),
+ ComponentSpec(
+ name="image_annotator_processor",
+ type_hint=AutoProcessor,
+ pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "image",
+ type_hint=Union[Image.Image, List[Image.Image]],
+ required=True,
+ description="Image(s) to annotate",
+ ),
+ InputParam(
+ "annotation_task",
+ type_hint=Union[str, List[str]],
+ required=True,
+ default="",
+ description="""Annotation Task to perform on the image.
+ Supported Tasks:
+
+
+
+
+
+
+
+
+
+
+ """,
+ ),
+ InputParam(
+ "annotation_prompt",
+ type_hint=Union[str, List[str]],
+ required=True,
+ description="""Annotation Prompt to provide more context to the task.
+ Can be used to detect or segment out specific elements in the image
+ """,
+ ),
+ InputParam(
+ "annotation_output_type",
+ type_hint=str,
+ required=True,
+ default="mask_image",
+ description="""Output type from annotation predictions. Availabe options are
+ mask_image:
+ -black and white mask image for the given image based on the task type
+ mask_overlay:
+ - mask overlayed on the original image
+ bounding_box:
+ - bounding boxes drawn on the original image
+ """,
+ ),
+ InputParam(
+ "annotation_overlay",
+ type_hint=bool,
+ required=True,
+ default=False,
+ description="",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "mask_image",
+ type_hint=Image,
+ description="Inpainting Mask for input Image(s)",
+ ),
+ OutputParam(
+ "annotations",
+ type_hint=dict,
+ description="Annotations Predictions for input Image(s)",
+ ),
+ OutputParam(
+ "image",
+ type_hint=Image,
+ description="Annotated input Image(s)",
+ ),
+ ]
+
+ def get_annotations(self, components, images, prompts, task):
+ task_prompts = [task + prompt for prompt in prompts]
+
+ inputs = components.image_annotator_processor(
+ text=task_prompts, images=images, return_tensors="pt"
+ ).to(components.image_annotator.device, components.image_annotator.dtype)
+
+ generated_ids = components.image_annotator.generate(
+ input_ids=inputs["input_ids"],
+ pixel_values=inputs["pixel_values"],
+ max_new_tokens=1024,
+ early_stopping=False,
+ do_sample=False,
+ num_beams=3,
+ )
+ annotations = components.image_annotator_processor.batch_decode(
+ generated_ids, skip_special_tokens=False
+ )
+ outputs = []
+ for image, annotation in zip(images, annotations):
+ outputs.append(
+ components.image_annotator_processor.post_process_generation(
+ annotation, task=task, image_size=(image.width, image.height)
+ )
+ )
+ return outputs
+
+ def prepare_mask(self, images, annotations, overlay=False, fill="white"):
+ masks = []
+ for image, annotation in zip(images, annotations):
+ mask_image = image.copy() if overlay else Image.new("L", image.size, 0)
+ draw = ImageDraw.Draw(mask_image)
+
+ for _, _annotation in annotation.items():
+ if "polygons" in _annotation:
+ for polygon in _annotation["polygons"]:
+ polygon = np.array(polygon).reshape(-1, 2)
+ if len(polygon) < 3:
+ continue
+ polygon = polygon.reshape(-1).tolist()
+ draw.polygon(polygon, fill=fill)
+
+ elif "bbox" in _annotation:
+ bbox = _annotation["bbox"]
+ draw.rectangle(bbox, fill="white")
+
+ masks.append(mask_image)
+
+ return masks
+
+ def prepare_bounding_boxes(self, images, annotations):
+ outputs = []
+ for image, annotation in zip(images, annotations):
+ image_copy = image.copy()
+ draw = ImageDraw.Draw(image_copy)
+ for _, _annotation in annotation.items():
+ bbox = _annotation["bbox"]
+ label = _annotation["label"]
+
+ draw.rectangle(bbox, outline="red", width=3)
+ draw.text((bbox[0], bbox[1] - 20), label, fill="red")
+
+ outputs.append(image_copy)
+
+ return outputs
+
+ def prepare_inputs(self, images, prompts):
+ prompts = prompts or ""
+
+ if isinstance(images, Image.Image):
+ images = [images]
+ if isinstance(prompts, str):
+ prompts = [prompts]
+
+ if len(images) != len(prompts):
+ raise ValueError("Number of images and annotation prompts must match.")
+
+ return images, prompts
+
+ @torch.no_grad()
+ def __call__(self, components, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ images, annotation_task_prompt = self.prepare_inputs(
+ block_state.image, block_state.annotation_prompt
+ )
+ task = block_state.annotation_task
+ fill = block_state.fill
+
+ annotations = self.get_annotations(
+ components, images, annotation_task_prompt, task
+ )
+ block_state.annotations = annotations
+ if block_state.annotation_output_type == "mask_image":
+ block_state.mask_image = self.prepare_mask(images, annotations)
+ else:
+ block_state.mask_image = None
+
+ if block_state.annotation_output_type == "mask_overlay":
+ block_state.image = self.prepare_mask(images, annotations, overlay=True, fill=fill)
+
+ elif block_state.annotation_output_type == "bounding_box":
+ block_state.image = self.prepare_bounding_boxes(images, annotations)
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+```
+
+Once we have defined our custom block, we can save it to the Hub, using either the CLI or the [`push_to_hub`] method. This will make it easy to share and reuse our custom block with other pipelines.
+
+
+
+
+```shell
+# In the folder with the `block.py` file, run:
+diffusers-cli custom_block
+```
+
+Then upload the block to the Hub:
+
+```shell
+hf upload . .
+```
+
+
+
+```py
+from block import Florence2ImageAnnotatorBlock
+block = Florence2ImageAnnotatorBlock()
+block.push_to_hub("")
+```
+
+
+
+
+## Using Custom Blocks
+
+Load the custom block with [`~ModularPipelineBlocks.from_pretrained`] and set `trust_remote_code=True`.
+
+```py
+import torch
+from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
+from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
+from diffusers.utils import load_image
+
+# Fetch the Florence2 image annotator block that will create our mask
+image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True)
+
+my_blocks = INPAINT_BLOCKS.copy()
+# insert the annotation block before the image encoding step
+my_blocks.insert("image_annotator", image_annotator_block, 1)
+
+# Create our initial set of inpainting blocks
+blocks = SequentialPipelineBlocks.from_blocks_dict(my_blocks)
+
+repo_id = "diffusers/modular-stable-diffusion-xl-base-1.0"
+pipe = blocks.init_pipeline(repo_id)
+pipe.load_components(torch_dtype=torch.float16, device_map="cuda", trust_remote_code=True)
+
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true")
+image = image.resize((1024, 1024))
+
+prompt = ["A red car"]
+annotation_task = ""
+annotation_prompt = ["the car"]
+
+output = pipe(
+ prompt=prompt,
+ image=image,
+ annotation_task=annotation_task,
+ annotation_prompt=annotation_prompt,
+ annotation_output_type="mask_image",
+ num_inference_steps=35,
+ guidance_scale=7.5,
+ strength=0.95,
+ output="images"
+)
+output[0].save("florence-inpainting.png")
+```
+
+## Editing Custom Blocks
+
+By default, custom blocks are saved in your cache directory. Use the `local_dir` argument to download and edit a custom block in a specific folder.
+
+```py
+import torch
+from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
+from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
+from diffusers.utils import load_image
+
+# Fetch the Florence2 image annotator block that will create our mask
+image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True, local_dir="/my-local-folder")
+```
+
+Any changes made to the block files in this folder will be reflected when you load the block again.
diff --git a/docs/source/en/modular_diffusers/guiders.md b/docs/source/en/modular_diffusers/guiders.md
new file mode 100644
index 000000000000..6abe4fad2736
--- /dev/null
+++ b/docs/source/en/modular_diffusers/guiders.md
@@ -0,0 +1,175 @@
+
+
+# Guiders
+
+[Classifier-free guidance](https://huggingface.co/papers/2207.12598) steers model generation that better match a prompt and is commonly used to improve generation quality, control, and adherence to prompts. There are different types of guidance methods, and in Diffusers, they are known as *guiders*. Like blocks, it is easy to switch and use different guiders for different use cases without rewriting the pipeline.
+
+This guide will show you how to switch guiders, adjust guider parameters, and load and share them to the Hub.
+
+## Switching guiders
+
+[`ClassifierFreeGuidance`] is the default guider and created when a pipeline is initialized with [`~ModularPipelineBlocks.init_pipeline`]. It is created by `from_config` which means it doesn't require loading specifications from a modular repository. A guider won't be listed in `modular_model_index.json`.
+
+Use [`~ModularPipeline.get_component_spec`] to inspect a guider.
+
+```py
+t2i_pipeline.get_component_spec("guider")
+ComponentSpec(name='guider', type_hint=, description=None, config=FrozenDict([('guidance_scale', 7.5), ('guidance_rescale', 0.0), ('use_original_formulation', False), ('start', 0.0), ('stop', 1.0), ('_use_default_values', ['start', 'guidance_rescale', 'stop', 'use_original_formulation'])]), repo=None, subfolder=None, variant=None, revision=None, default_creation_method='from_config')
+```
+
+Switch to a different guider by passing the new guider to [`~ModularPipeline.update_components`].
+
+> [!TIP]
+> Changing guiders will return text letting you know you're changing the guider type.
+> ```bash
+> ModularPipeline.update_components: adding guider with new type: PerturbedAttentionGuidance, previous type: ClassifierFreeGuidance
+> ```
+
+```py
+from diffusers import LayerSkipConfig, PerturbedAttentionGuidance
+
+config = LayerSkipConfig(indices=[2, 9], fqn="mid_block.attentions.0.transformer_blocks", skip_attention=False, skip_attention_scores=True, skip_ff=False)
+guider = PerturbedAttentionGuidance(
+ guidance_scale=5.0, perturbed_guidance_scale=2.5, perturbed_guidance_config=config
+)
+t2i_pipeline.update_components(guider=guider)
+```
+
+Use [`~ModularPipeline.get_component_spec`] again to verify the guider type is different.
+
+```py
+t2i_pipeline.get_component_spec("guider")
+ComponentSpec(name='guider', type_hint=, description=None, config=FrozenDict([('guidance_scale', 5.0), ('perturbed_guidance_scale', 2.5), ('perturbed_guidance_start', 0.01), ('perturbed_guidance_stop', 0.2), ('perturbed_guidance_layers', None), ('perturbed_guidance_config', LayerSkipConfig(indices=[2, 9], fqn='mid_block.attentions.0.transformer_blocks', skip_attention=False, skip_attention_scores=True, skip_ff=False, dropout=1.0)), ('guidance_rescale', 0.0), ('use_original_formulation', False), ('start', 0.0), ('stop', 1.0), ('_use_default_values', ['perturbed_guidance_start', 'use_original_formulation', 'perturbed_guidance_layers', 'stop', 'start', 'guidance_rescale', 'perturbed_guidance_stop']), ('_class_name', 'PerturbedAttentionGuidance'), ('_diffusers_version', '0.35.0.dev0')]), repo=None, subfolder=None, variant=None, revision=None, default_creation_method='from_config')
+```
+
+## Loading custom guiders
+
+Guiders that are already saved on the Hub with a `modular_model_index.json` file are considered a `from_pretrained` component now instead of a `from_config` component.
+
+```json
+{
+ "guider": [
+ null,
+ null,
+ {
+ "repo": "YiYiXu/modular-loader-t2i-guider",
+ "revision": null,
+ "subfolder": "pag_guider",
+ "type_hint": [
+ "diffusers",
+ "PerturbedAttentionGuidance"
+ ],
+ "variant": null
+ }
+ ]
+}
+```
+
+The guider is only created after calling [`~ModularPipeline.load_components`] based on the loading specification in `modular_model_index.json`.
+
+```py
+t2i_pipeline = t2i_blocks.init_pipeline("YiYiXu/modular-doc-guider")
+# not created during init
+assert t2i_pipeline.guider is None
+t2i_pipeline.load_components()
+# loaded as PAG guider
+t2i_pipeline.guider
+```
+
+
+## Changing guider parameters
+
+The guider parameters can be adjusted with either the [`~ComponentSpec.create`] method or with [`~ModularPipeline.update_components`]. The example below changes the `guidance_scale` value.
+
+
+
+
+```py
+guider_spec = t2i_pipeline.get_component_spec("guider")
+guider = guider_spec.create(guidance_scale=10)
+t2i_pipeline.update_components(guider=guider)
+```
+
+
+
+
+```py
+guider_spec = t2i_pipeline.get_component_spec("guider")
+guider_spec.config["guidance_scale"] = 10
+t2i_pipeline.update_components(guider=guider_spec)
+```
+
+
+
+
+## Uploading custom guiders
+
+Call the [`~utils.PushToHubMixin.push_to_hub`] method on a custom guider to share it to the Hub.
+
+```py
+guider.push_to_hub("YiYiXu/modular-loader-t2i-guider", subfolder="pag_guider")
+```
+
+To make this guider available to the pipeline, either modify the `modular_model_index.json` file or use the [`~ModularPipeline.update_components`] method.
+
+
+
+
+Edit the `modular_model_index.json` file and add a loading specification for the guider by pointing to a folder containing the guider config.
+
+```json
+{
+ "guider": [
+ "diffusers",
+ "PerturbedAttentionGuidance",
+ {
+ "repo": "YiYiXu/modular-loader-t2i-guider",
+ "revision": null,
+ "subfolder": "pag_guider",
+ "type_hint": [
+ "diffusers",
+ "PerturbedAttentionGuidance"
+ ],
+ "variant": null
+ }
+ ],
+```
+
+
+
+
+Change the [`~ComponentSpec.default_creation_method`] to `from_pretrained` and use [`~ModularPipeline.update_components`] to update the guider and component specifications as well as the pipeline config.
+
+> [!TIP]
+> Changing the creation method will return text letting you know you're changing the creation type to `from_pretrained`.
+> ```bash
+> ModularPipeline.update_components: changing the default_creation_method of guider from from_config to from_pretrained.
+> ```
+
+```py
+guider_spec = t2i_pipeline.get_component_spec("guider")
+guider_spec.default_creation_method="from_pretrained"
+guider_spec.pretrained_model_name_or_path="YiYiXu/modular-loader-t2i-guider"
+guider_spec.subfolder="pag_guider"
+pag_guider = guider_spec.load()
+t2i_pipeline.update_components(guider=pag_guider)
+```
+
+To make it the default guider for a pipeline, call [`~utils.PushToHubMixin.push_to_hub`]. This is an optional step and not necessary if you are only experimenting locally.
+
+```py
+t2i_pipeline.push_to_hub("YiYiXu/modular-doc-guider")
+```
+
+
+
diff --git a/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md b/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md
new file mode 100644
index 000000000000..a80309de19a6
--- /dev/null
+++ b/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md
@@ -0,0 +1,92 @@
+
+
+# LoopSequentialPipelineBlocks
+
+[`~modular_pipelines.LoopSequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a loop. Data flows circularly, using `inputs` and `intermediate_outputs`, and each block is run iteratively. This is typically used to create a denoising loop which is iterative by default.
+
+This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBlocks`].
+
+## Loop wrapper
+
+[`~modular_pipelines.LoopSequentialPipelineBlocks`], is also known as the *loop wrapper* because it defines the loop structure, iteration variables, and configuration. Within the loop wrapper, you need the following variables.
+
+- `loop_inputs` are user provided values and equivalent to [`~modular_pipelines.ModularPipelineBlocks.inputs`].
+- `loop_intermediate_outputs` are new intermediate variables created by the block and added to the [`~modular_pipelines.PipelineState`]. It is equivalent to [`~modular_pipelines.ModularPipelineBlocks.intermediate_outputs`].
+- `__call__` method defines the loop structure and iteration logic.
+
+```py
+import torch
+from diffusers.modular_pipelines import LoopSequentialPipelineBlocks, ModularPipelineBlocks, InputParam, OutputParam
+
+class LoopWrapper(LoopSequentialPipelineBlocks):
+ model_name = "test"
+ @property
+ def description(self):
+ return "I'm a loop!!"
+ @property
+ def loop_inputs(self):
+ return [InputParam(name="num_steps")]
+ @torch.no_grad()
+ def __call__(self, components, state):
+ block_state = self.get_block_state(state)
+ # Loop structure - can be customized to your needs
+ for i in range(block_state.num_steps):
+ # loop_step executes all registered blocks in sequence
+ components, block_state = self.loop_step(components, block_state, i=i)
+ self.set_block_state(state, block_state)
+ return components, state
+```
+
+The loop wrapper can pass additional arguments, like current iteration index, to the loop blocks.
+
+## Loop blocks
+
+A loop block is a [`~modular_pipelines.ModularPipelineBlocks`], but the `__call__` method behaves differently.
+
+- It recieves the iteration variable from the loop wrapper.
+- It works directly with the [`~modular_pipelines.BlockState`] instead of the [`~modular_pipelines.PipelineState`].
+- It doesn't require retrieving or updating the [`~modular_pipelines.BlockState`].
+
+Loop blocks share the same [`~modular_pipelines.BlockState`] to allow values to accumulate and change for each iteration in the loop.
+
+```py
+class LoopBlock(ModularPipelineBlocks):
+ model_name = "test"
+ @property
+ def inputs(self):
+ return [InputParam(name="x")]
+ @property
+ def intermediate_outputs(self):
+ # outputs produced by this block
+ return [OutputParam(name="x")]
+ @property
+ def description(self):
+ return "I'm a block used inside the `LoopWrapper` class"
+ def __call__(self, components, block_state, i: int):
+ block_state.x += 1
+ return components, block_state
+```
+
+## LoopSequentialPipelineBlocks
+
+Use the [`~modular_pipelines.LoopSequentialPipelineBlocks.from_blocks_dict`] method to add the loop block to the loop wrapper to create [`~modular_pipelines.LoopSequentialPipelineBlocks`].
+
+```py
+loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock})
+```
+
+Add more loop blocks to run within each iteration with [`~modular_pipelines.LoopSequentialPipelineBlocks.from_blocks_dict`]. This allows you to modify the blocks without changing the loop logic itself.
+
+```py
+loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock(), "block2": LoopBlock})
+```
diff --git a/docs/source/en/modular_diffusers/modular_diffusers_states.md b/docs/source/en/modular_diffusers/modular_diffusers_states.md
new file mode 100644
index 000000000000..eb55b524e491
--- /dev/null
+++ b/docs/source/en/modular_diffusers/modular_diffusers_states.md
@@ -0,0 +1,75 @@
+
+
+# States
+
+Blocks rely on the [`~modular_pipelines.PipelineState`] and [`~modular_pipelines.BlockState`] data structures for communicating and sharing data.
+
+| State | Description |
+|-------|-------------|
+| [`~modular_pipelines.PipelineState`] | Maintains the overall data required for a pipeline's execution and allows blocks to read and update its data. |
+| [`~modular_pipelines.BlockState`] | Allows each block to perform its computation with the necessary data from `inputs`|
+
+This guide explains how states work and how they connect blocks.
+
+## PipelineState
+
+The [`~modular_pipelines.PipelineState`] is a global state container for all blocks. It maintains the complete runtime state of the pipeline and provides a structured way for blocks to read from and write to shared data.
+
+There are two dict's in [`~modular_pipelines.PipelineState`] for structuring data.
+
+- The `values` dict is a **mutable** state containing a copy of user provided input values and intermediate output values generated by blocks. If a block modifies an `input`, it will be reflected in the `values` dict after calling `set_block_state`.
+
+```py
+PipelineState(
+ values={
+ 'prompt': 'a cat'
+ 'guidance_scale': 7.0
+ 'num_inference_steps': 25
+ 'prompt_embeds': Tensor(dtype=torch.float32, shape=torch.Size([1, 1, 1, 1]))
+ 'negative_prompt_embeds': None
+ },
+)
+```
+
+## BlockState
+
+The [`~modular_pipelines.BlockState`] is a local view of the relevant variables an individual block needs from [`~modular_pipelines.PipelineState`] for performing it's computations.
+
+Access these variables directly as attributes like `block_state.image`.
+
+```py
+BlockState(
+ image:
+)
+```
+
+When a block's `__call__` method is executed, it retrieves the [`BlockState`] with `self.get_block_state(state)`, performs it's operations, and updates [`~modular_pipelines.PipelineState`] with `self.set_block_state(state, block_state)`.
+
+```py
+def __call__(self, components, state):
+ # retrieve BlockState
+ block_state = self.get_block_state(state)
+
+ # computation logic on inputs
+
+ # update PipelineState
+ self.set_block_state(state, block_state)
+ return components, state
+```
+
+## State interaction
+
+[`~modular_pipelines.PipelineState`] and [`~modular_pipelines.BlockState`] interaction is defined by a block's `inputs`, and `intermediate_outputs`.
+
+- `inputs`, a block can modify an input - like `block_state.image` - and this change can be propagated globally to [`~modular_pipelines.PipelineState`] by calling `set_block_state`.
+- `intermediate_outputs`, is a new variable that a block creates. It is added to the [`~modular_pipelines.PipelineState`]'s `values` dict and is available as for subsequent blocks or accessed by users as a final output from the pipeline.
diff --git a/docs/source/en/modular_diffusers/modular_pipeline.md b/docs/source/en/modular_diffusers/modular_pipeline.md
new file mode 100644
index 000000000000..34cd8f72b5b7
--- /dev/null
+++ b/docs/source/en/modular_diffusers/modular_pipeline.md
@@ -0,0 +1,358 @@
+
+
+# ModularPipeline
+
+[`ModularPipeline`] converts [`~modular_pipelines.ModularPipelineBlocks`]'s into an executable pipeline that loads models and performs the computation steps defined in the block. It is the main interface for running a pipeline and it is very similar to the [`DiffusionPipeline`] API.
+
+The main difference is to include an expected `output` argument in the pipeline.
+
+
+
+
+```py
+import torch
+from diffusers.modular_pipelines import SequentialPipelineBlocks
+from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
+
+blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
+
+modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
+pipeline = blocks.init_pipeline(modular_repo_id)
+
+pipeline.load_components(torch_dtype=torch.float16)
+pipeline.to("cuda")
+
+image = pipeline(prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", output="images")[0]
+image.save("modular_t2i_out.png")
+```
+
+
+
+
+```py
+import torch
+from diffusers.modular_pipelines import SequentialPipelineBlocks
+from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS
+
+blocks = SequentialPipelineBlocks.from_blocks_dict(IMAGE2IMAGE_BLOCKS)
+
+modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
+pipeline = blocks.init_pipeline(modular_repo_id)
+
+pipeline.load_components(torch_dtype=torch.float16)
+pipeline.to("cuda")
+
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
+init_image = load_image(url)
+prompt = "a dog catching a frisbee in the jungle"
+image = pipeline(prompt=prompt, image=init_image, strength=0.8, output="images")[0]
+image.save("modular_i2i_out.png")
+```
+
+
+
+
+```py
+import torch
+from diffusers.modular_pipelines import SequentialPipelineBlocks
+from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
+from diffusers.utils import load_image
+
+blocks = SequentialPipelineBlocks.from_blocks_dict(INPAINT_BLOCKS)
+
+modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
+pipeline = blocks.init_pipeline(modular_repo_id)
+
+pipeline.load_components(torch_dtype=torch.float16)
+pipeline.to("cuda")
+
+img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
+mask_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-inpaint-mask.png"
+
+init_image = load_image(img_url)
+mask_image = load_image(mask_url)
+
+prompt = "A deep sea diver floating"
+image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.85, output="images")[0]
+image.save("moduar_inpaint_out.png")
+```
+
+
+
+
+This guide will show you how to create a [`ModularPipeline`] and manage the components in it.
+
+## Adding blocks
+
+Blocks are [`InsertableDict`] objects that can be inserted at specific positions, providing a flexible way to mix-and-match blocks.
+
+Use [`~modular_pipelines.modular_pipeline_utils.InsertableDict.insert`] on either the block class or `sub_blocks` attribute to add a block.
+
+```py
+# BLOCKS is dict of block classes, you need to add class to it
+BLOCKS.insert("block_name", BlockClass, index)
+# sub_blocks attribute contains instance, add a block instance to the attribute
+t2i_blocks.sub_blocks.insert("block_name", block_instance, index)
+```
+
+Use [`~modular_pipelines.modular_pipeline_utils.InsertableDict.pop`] on either the block class or `sub_blocks` attribute to remove a block.
+
+```py
+# remove a block class from preset
+BLOCKS.pop("text_encoder")
+# split out a block instance on its own
+text_encoder_block = t2i_blocks.sub_blocks.pop("text_encoder")
+```
+
+Swap blocks by setting the existing block to the new block.
+
+```py
+# Replace block class in preset
+BLOCKS["prepare_latents"] = CustomPrepareLatents
+# Replace in sub_blocks attribute using an block instance
+t2i_blocks.sub_blocks["prepare_latents"] = CustomPrepareLatents()
+```
+
+## Creating a pipeline
+
+There are two ways to create a [`ModularPipeline`]. Assemble and create a pipeline from [`ModularPipelineBlocks`] or load an existing pipeline with [`~ModularPipeline.from_pretrained`].
+
+You should also initialize a [`ComponentsManager`] to handle device placement and memory and component management.
+
+> [!TIP]
+> Refer to the [ComponentsManager](./components_manager) doc for more details about how it can help manage components across different workflows.
+
+
+
+
+Use the [`~ModularPipelineBlocks.init_pipeline`] method to create a [`ModularPipeline`] from the component and configuration specifications. This method loads the *specifications* from a `modular_model_index.json` file, but it doesn't load the *models* yet.
+
+```py
+from diffusers import ComponentsManager
+from diffusers.modular_pipelines import SequentialPipelineBlocks
+from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
+
+t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
+
+modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
+components = ComponentsManager()
+t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id, components_manager=components)
+```
+
+
+
+
+The [`~ModularPipeline.from_pretrained`] method creates a [`ModularPipeline`] from a modular repository on the Hub.
+
+```py
+from diffusers import ModularPipeline, ComponentsManager
+
+components = ComponentsManager()
+pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-loader-t2i-0704", components_manager=components)
+```
+
+Add the `trust_remote_code` argument to load a custom [`ModularPipeline`].
+
+```py
+from diffusers import ModularPipeline, ComponentsManager
+
+components = ComponentsManager()
+modular_repo_id = "YiYiXu/modular-diffdiff-0704"
+diffdiff_pipeline = ModularPipeline.from_pretrained(modular_repo_id, trust_remote_code=True, components_manager=components)
+```
+
+
+
+
+## Loading components
+
+A [`ModularPipeline`] doesn't automatically instantiate with components. It only loads the configuration and component specifications. You can load all components with [`~ModularPipeline.load_components`] or only load specific components with [`~ModularPipeline.load_components`].
+
+
+
+
+```py
+import torch
+
+t2i_pipeline.load_components(torch_dtype=torch.float16)
+t2i_pipeline.to("cuda")
+```
+
+
+
+
+The example below only loads the UNet and VAE.
+
+```py
+import torch
+
+t2i_pipeline.load_components(names=["unet", "vae"], torch_dtype=torch.float16)
+```
+
+
+
+
+Print the pipeline to inspect the loaded pretrained components.
+
+```py
+t2i_pipeline
+```
+
+This should match the `modular_model_index.json` file from the modular repository a pipeline is initialized from. If a pipeline doesn't need a component, it won't be included even if it exists in the modular repository.
+
+To modify where components are loaded from, edit the `modular_model_index.json` file in the repository and change it to your desired loading path. The example below loads a UNet from a different repository.
+
+```json
+# original
+"unet": [
+ null, null,
+ {
+ "repo": "stabilityai/stable-diffusion-xl-base-1.0",
+ "subfolder": "unet",
+ "variant": "fp16"
+ }
+]
+
+# modified
+"unet": [
+ null, null,
+ {
+ "repo": "RunDiffusion/Juggernaut-XL-v9",
+ "subfolder": "unet",
+ "variant": "fp16"
+ }
+]
+```
+
+### Component loading status
+
+The pipeline properties below provide more information about which components are loaded.
+
+Use `component_names` to return all expected components.
+
+```py
+t2i_pipeline.component_names
+['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'guider', 'scheduler', 'unet', 'vae', 'image_processor']
+```
+
+Use `null_component_names` to return components that aren't loaded yet. Load these components with [`~ModularPipeline.from_pretrained`].
+
+```py
+t2i_pipeline.null_component_names
+['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler']
+```
+
+Use `pretrained_component_names` to return components that will be loaded from pretrained models.
+
+```py
+t2i_pipeline.pretrained_component_names
+['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler', 'unet', 'vae']
+```
+
+Use `config_component_names` to return components that are created with the default config (not loaded from a modular repository). Components from a config aren't included because they are already initialized during pipeline creation. This is why they aren't listed in `null_component_names`.
+
+```py
+t2i_pipeline.config_component_names
+['guider', 'image_processor']
+```
+
+## Updating components
+
+Components may be updated depending on whether it is a *pretrained component* or a *config component*.
+
+> [!WARNING]
+> A component may change from pretrained to config when updating a component. The component type is initially defined in a block's `expected_components` field.
+
+A pretrained component is updated with [`ComponentSpec`] whereas a config component is updated by eihter passing the object directly or with [`ComponentSpec`].
+
+The [`ComponentSpec`] shows `default_creation_method="from_pretrained"` for a pretrained component shows `default_creation_method="from_config` for a config component.
+
+To update a pretrained component, create a [`ComponentSpec`] with the name of the component and where to load it from. Use the [`~ComponentSpec.load`] method to load the component.
+
+```py
+from diffusers import ComponentSpec, UNet2DConditionModel
+
+unet_spec = ComponentSpec(name="unet",type_hint=UNet2DConditionModel, repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", variant="fp16")
+unet = unet_spec.load(torch_dtype=torch.float16)
+```
+
+The [`~ModularPipeline.update_components`] method replaces the component with a new one.
+
+```py
+t2i_pipeline.update_components(unet=unet2)
+```
+
+When a component is updated, the loading specifications are also updated in the pipeline config.
+
+### Component extraction and modification
+
+When you use [`~ComponentSpec.load`], the new component maintains its loading specifications. This makes it possible to extract the specification and recreate the component.
+
+```py
+spec = ComponentSpec.from_component("unet", unet2)
+spec
+ComponentSpec(name='unet', type_hint=, description=None, config=None, repo='stabilityai/stable-diffusion-xl-base-1.0', subfolder='unet', variant='fp16', revision=None, default_creation_method='from_pretrained')
+unet2_recreated = spec.load(torch_dtype=torch.float16)
+```
+
+The [`~ModularPipeline.get_component_spec`] method gets a copy of the current component specification to modify or update.
+
+```py
+unet_spec = t2i_pipeline.get_component_spec("unet")
+unet_spec
+ComponentSpec(
+ name='unet',
+ type_hint=,
+ pretrained_model_name_or_path='RunDiffusion/Juggernaut-XL-v9',
+ subfolder='unet',
+ variant='fp16',
+ default_creation_method='from_pretrained'
+)
+
+# modify to load from a different repository
+unet_spec.pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
+
+# load component with modified spec
+unet = unet_spec.load(torch_dtype=torch.float16)
+```
+
+## Modular repository
+
+A repository is required if the pipeline blocks use *pretrained components*. The repository supplies loading specifications and metadata.
+
+[`ModularPipeline`] specifically requires *modular repositories* (see [example repository](https://huggingface.co/YiYiXu/modular-diffdiff)) which are more flexible than a typical repository. It contains a `modular_model_index.json` file containing the following 3 elements.
+
+- `library` and `class` shows which library the component was loaded from and it's class. If `null`, the component hasn't been loaded yet.
+- `loading_specs_dict` contains the information required to load the component such as the repository and subfolder it is loaded from.
+
+Unlike standard repositories, a modular repository can fetch components from different repositories based on the `loading_specs_dict`. Components don't need to exist in the same repository.
+
+A modular repository may contain custom code for loading a [`ModularPipeline`]. This allows you to use specialized blocks that aren't native to Diffusers.
+
+```
+modular-diffdiff-0704/
+├── block.py # Custom pipeline blocks implementation
+├── config.json # Pipeline configuration and auto_map
+└── modular_model_index.json # Component loading specifications
+```
+
+The [config.json](https://huggingface.co/YiYiXu/modular-diffdiff-0704/blob/main/config.json) file contains an `auto_map` key that points to where a custom block is defined in `block.py`.
+
+```json
+{
+ "_class_name": "DiffDiffBlocks",
+ "auto_map": {
+ "ModularPipelineBlocks": "block.DiffDiffBlocks"
+ }
+}
+```
diff --git a/docs/source/en/modular_diffusers/overview.md b/docs/source/en/modular_diffusers/overview.md
new file mode 100644
index 000000000000..7d07c4b73434
--- /dev/null
+++ b/docs/source/en/modular_diffusers/overview.md
@@ -0,0 +1,41 @@
+
+
+# Overview
+
+> [!WARNING]
+> Modular Diffusers is under active development and it's API may change.
+
+Modular Diffusers is a unified pipeline system that simplifies your workflow with *pipeline blocks*.
+
+- Blocks are reusable and you only need to create new blocks that are unique to your pipeline.
+- Blocks can be mixed and matched to adapt to or create a pipeline for a specific workflow or multiple workflows.
+
+The Modular Diffusers docs are organized as shown below.
+
+## Quickstart
+
+- A [quickstart](./quickstart) demonstrating how to implement an example workflow with Modular Diffusers.
+
+## ModularPipelineBlocks
+
+- [States](./modular_diffusers_states) explains how data is shared and communicated between blocks and [`ModularPipeline`].
+- [ModularPipelineBlocks](./pipeline_block) is the most basic unit of a [`ModularPipeline`] and this guide shows you how to create one.
+- [SequentialPipelineBlocks](./sequential_pipeline_blocks) is a type of block that chains multiple blocks so they run one after another, passing data along the chain. This guide shows you how to create [`~modular_pipelines.SequentialPipelineBlocks`] and how they connect and work together.
+- [LoopSequentialPipelineBlocks](./loop_sequential_pipeline_blocks) is a type of block that runs a series of blocks in a loop. This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBlocks`].
+- [AutoPipelineBlocks](./auto_pipeline_blocks) is a type of block that automatically chooses which blocks to run based on the input. This guide shows you how to create [`~modular_pipelines.AutoPipelineBlocks`].
+
+## ModularPipeline
+
+- [ModularPipeline](./modular_pipeline) shows you how to create and convert pipeline blocks into an executable [`ModularPipeline`].
+- [ComponentsManager](./components_manager) shows you how to manage and reuse components across multiple pipelines.
+- [Guiders](./guiders) shows you how to use different guidance methods in the pipeline.
\ No newline at end of file
diff --git a/docs/source/en/modular_diffusers/pipeline_block.md b/docs/source/en/modular_diffusers/pipeline_block.md
new file mode 100644
index 000000000000..06c115e1fb52
--- /dev/null
+++ b/docs/source/en/modular_diffusers/pipeline_block.md
@@ -0,0 +1,105 @@
+
+
+# ModularPipelineBlocks
+
+[`~modular_pipelines.ModularPipelineBlocks`] is the basic block for building a [`ModularPipeline`]. It defines what components, inputs/outputs, and computation a block should perform for a specific step in a pipeline. A [`~modular_pipelines.ModularPipelineBlocks`] connects with other blocks, using [state](./modular_diffusers_states), to enable the modular construction of workflows.
+
+A [`~modular_pipelines.ModularPipelineBlocks`] on it's own can't be executed. It is a blueprint for what a step should do in a pipeline. To actually run and execute a pipeline, the [`~modular_pipelines.ModularPipelineBlocks`] needs to be converted into a [`ModularPipeline`].
+
+This guide will show you how to create a [`~modular_pipelines.ModularPipelineBlocks`].
+
+## Inputs and outputs
+
+> [!TIP]
+> Refer to the [States](./modular_diffusers_states) guide if you aren't familiar with how state works in Modular Diffusers.
+
+A [`~modular_pipelines.ModularPipelineBlocks`] requires `inputs`, and `intermediate_outputs`.
+
+- `inputs` are values provided by a user and retrieved from the [`~modular_pipelines.PipelineState`]. This is useful because some workflows resize an image, but the original image is still required. The [`~modular_pipelines.PipelineState`] maintains the original image.
+
+ Use `InputParam` to define `inputs`.
+
+ ```py
+ from diffusers.modular_pipelines import InputParam
+
+ user_inputs = [
+ InputParam(name="image", type_hint="PIL.Image", description="raw input image to process")
+ ]
+ ```
+
+- `intermediate_outputs` are new values created by a block and added to the [`~modular_pipelines.PipelineState`]. The `intermediate_outputs` are available as `inputs` for subsequent blocks or available as the final output from running the pipeline.
+
+ Use `OutputParam` to define `intermediate_outputs`.
+
+ ```py
+ from diffusers.modular_pipelines import OutputParam
+
+ user_intermediate_outputs = [
+ OutputParam(name="image_latents", description="latents representing the image")
+ ]
+ ```
+
+The intermediate inputs and outputs share data to connect blocks. They are accessible at any point, allowing you to track the workflow's progress.
+
+## Computation logic
+
+The computation a block performs is defined in the `__call__` method and it follows a specific structure.
+
+1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs`
+2. Implement the computation logic on the `inputs`.
+3. Update [`~modular_pipelines.PipelineState`] to push changes from the local [`~modular_pipelines.BlockState`] back to the global [`~modular_pipelines.PipelineState`].
+4. Return the components and state which becomes available to the next block.
+
+```py
+def __call__(self, components, state):
+ # Get a local view of the state variables this block needs
+ block_state = self.get_block_state(state)
+
+ # Your computation logic here
+ # block_state contains all your inputs
+ # Access them like: block_state.image, block_state.processed_image
+
+ # Update the pipeline state with your updated block_states
+ self.set_block_state(state, block_state)
+ return components, state
+```
+
+### Components and configs
+
+The components and pipeline-level configs a block needs are specified in [`ComponentSpec`] and [`~modular_pipelines.ConfigSpec`].
+
+- [`ComponentSpec`] contains the expected components used by a block. You need the `name` of the component and ideally a `type_hint` that specifies exactly what the component is.
+- [`~modular_pipelines.ConfigSpec`] contains pipeline-level settings that control behavior across all blocks.
+
+```py
+from diffusers import ComponentSpec, ConfigSpec
+
+expected_components = [
+ ComponentSpec(name="unet", type_hint=UNet2DConditionModel),
+ ComponentSpec(name="scheduler", type_hint=EulerDiscreteScheduler)
+]
+
+expected_config = [
+ ConfigSpec("force_zeros_for_empty_prompt", True)
+]
+```
+
+When the blocks are converted into a pipeline, the components become available to the block as the first argument in `__call__`.
+
+```py
+def __call__(self, components, state):
+ # Access components using dot notation
+ unet = components.unet
+ vae = components.vae
+ scheduler = components.scheduler
+```
diff --git a/docs/source/en/modular_diffusers/quickstart.md b/docs/source/en/modular_diffusers/quickstart.md
new file mode 100644
index 000000000000..32d14d84e243
--- /dev/null
+++ b/docs/source/en/modular_diffusers/quickstart.md
@@ -0,0 +1,344 @@
+
+
+# Quickstart
+
+Modular Diffusers is a framework for quickly building flexible and customizable pipelines. At the core of Modular Diffusers are [`ModularPipelineBlocks`] that can be combined with other blocks to adapt to new workflows. The blocks are converted into a [`ModularPipeline`], a friendly user-facing interface developers can use.
+
+This doc will show you how to implement a [Differential Diffusion](https://differential-diffusion.github.io/) pipeline with the modular framework.
+
+## ModularPipelineBlocks
+
+[`ModularPipelineBlocks`] are *definitions* that specify the components, inputs, outputs, and computation logic for a single step in a pipeline. There are four types of blocks.
+
+- [`ModularPipelineBlocks`] is the most basic block for a single step.
+- [`SequentialPipelineBlocks`] is a multi-block that composes other blocks linearly. The outputs of one block are the inputs to the next block.
+- [`LoopSequentialPipelineBlocks`] is a multi-block that runs iteratively and is designed for iterative workflows.
+- [`AutoPipelineBlocks`] is a collection of blocks for different workflows and it selects which block to run based on the input. It is designed to conveniently package multiple workflows into a single pipeline.
+
+[Differential Diffusion](https://differential-diffusion.github.io/) is an image-to-image workflow. Start with the `IMAGE2IMAGE_BLOCKS` preset, a collection of `ModularPipelineBlocks` for image-to-image generation.
+
+```py
+from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS
+IMAGE2IMAGE_BLOCKS = InsertableDict([
+ ("text_encoder", StableDiffusionXLTextEncoderStep),
+ ("image_encoder", StableDiffusionXLVaeEncoderStep),
+ ("input", StableDiffusionXLInputStep),
+ ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
+ ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
+ ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
+ ("denoise", StableDiffusionXLDenoiseStep),
+ ("decode", StableDiffusionXLDecodeStep)
+])
+```
+
+## Pipeline and block states
+
+Modular Diffusers uses *state* to communicate data between blocks. There are two types of states.
+
+- [`PipelineState`] is a global state that can be used to track all inputs and outputs across all blocks.
+- [`BlockState`] is a local view of relevant variables from [`PipelineState`] for an individual block.
+
+## Customizing blocks
+
+[Differential Diffusion](https://differential-diffusion.github.io/) differs from standard image-to-image in its `prepare_latents` and `denoise` blocks. All the other blocks can be reused, but you'll need to modify these two.
+
+Create placeholder `ModularPipelineBlocks` for `prepare_latents` and `denoise` by copying and modifying the existing ones.
+
+Print the `denoise` block to see that it is composed of [`LoopSequentialPipelineBlocks`] with three sub-blocks, `before_denoiser`, `denoiser`, and `after_denoiser`. Only the `before_denoiser` sub-block needs to be modified to prepare the latent input for the denoiser based on the change map.
+
+```py
+denoise_blocks = IMAGE2IMAGE_BLOCKS["denoise"]()
+print(denoise_blocks)
+```
+
+Replace the `StableDiffusionXLLoopBeforeDenoiser` sub-block with the new `SDXLDiffDiffLoopBeforeDenoiser` block.
+
+```py
+# Copy existing blocks as placeholders
+class SDXLDiffDiffPrepareLatentsStep(ModularPipelineBlocks):
+ """Copied from StableDiffusionXLImg2ImgPrepareLatentsStep - will modify later"""
+ # ... same implementation as StableDiffusionXLImg2ImgPrepareLatentsStep
+
+class SDXLDiffDiffDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
+ block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+```
+
+### prepare_latents
+
+The `prepare_latents` block requires the following changes.
+
+- a processor to process the change map
+- a new `inputs` to accept the user-provided change map, `timestep` for precomputing all the latents and `num_inference_steps` to create the mask for updating the image regions
+- update the computation in the `__call__` method for processing the change map and creating the masks, and storing it in the [`BlockState`]
+
+```diff
+class SDXLDiffDiffPrepareLatentsStep(ModularPipelineBlocks):
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("vae", AutoencoderKL),
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
++ ComponentSpec("mask_processor", VaeImageProcessor, config=FrozenDict({"do_normalize": False, "do_convert_grayscale": True}))
+ ]
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("generator"),
++ InputParam("diffdiff_map", required=True),
+- InputParam("latent_timestep", required=True, type_hint=torch.Tensor),
++ InputParam("timesteps", type_hint=torch.Tensor),
++ InputParam("num_inference_steps", type_hint=int),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
++ OutputParam("original_latents", type_hint=torch.Tensor),
++ OutputParam("diffdiff_masks", type_hint=torch.Tensor),
+ ]
+ def __call__(self, components, state: PipelineState):
+ # ... existing logic ...
++ # Process change map and create masks
++ diffdiff_map = components.mask_processor.preprocess(block_state.diffdiff_map, height=latent_height, width=latent_width)
++ thresholds = torch.arange(block_state.num_inference_steps, dtype=diffdiff_map.dtype) / block_state.num_inference_steps
++ block_state.diffdiff_masks = diffdiff_map > (thresholds + (block_state.denoising_start or 0))
++ block_state.original_latents = block_state.latents
+```
+
+### denoise
+
+The `before_denoiser` sub-block requires the following changes.
+
+- a new `inputs` to accept a `denoising_start` parameter, `original_latents` and `diffdiff_masks` from the `prepare_latents` block
+- update the computation in the `__call__` method for applying Differential Diffusion
+
+```diff
+class SDXLDiffDiffLoopBeforeDenoiser(ModularPipelineBlocks):
+ @property
+ def description(self) -> str:
+ return (
+ "Step within the denoising loop for differential diffusion that prepare the latent input for the denoiser"
+ )
+
+ @property
+ def inputs(self) -> List[str]:
+ return [
+ InputParam("latents", required=True, type_hint=torch.Tensor),
++ InputParam("denoising_start"),
++ InputParam("original_latents", type_hint=torch.Tensor),
++ InputParam("diffdiff_masks", type_hint=torch.Tensor),
+ ]
+
+ def __call__(self, components, block_state, i, t):
++ # Apply differential diffusion logic
++ if i == 0 and block_state.denoising_start is None:
++ block_state.latents = block_state.original_latents[:1]
++ else:
++ block_state.mask = block_state.diffdiff_masks[i].unsqueeze(0).unsqueeze(1)
++ block_state.latents = block_state.original_latents[i] * block_state.mask + block_state.latents * (1 - block_state.mask)
+
+ # ... rest of existing logic ...
+```
+
+## Assembling the blocks
+
+You should have all the blocks you need at this point to create a [`ModularPipeline`].
+
+Copy the existing `IMAGE2IMAGE_BLOCKS` preset and for the `set_timesteps` block, use the `set_timesteps` from the `TEXT2IMAGE_BLOCKS` because Differential Diffusion doesn't require a `strength` parameter.
+
+Set the `prepare_latents` and `denoise` blocks to the `SDXLDiffDiffPrepareLatentsStep` and `SDXLDiffDiffDenoiseStep` blocks you just modified.
+
+Call [`SequentialPipelineBlocks.from_blocks_dict`] on the blocks to create a `SequentialPipelineBlocks`.
+
+```py
+DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
+DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
+DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
+DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseStep
+
+dd_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_BLOCKS)
+print(dd_blocks)
+```
+
+## ModularPipeline
+
+Convert the [`SequentialPipelineBlocks`] into a [`ModularPipeline`] with the [`ModularPipeline.init_pipeline`] method. This initializes the expected components to load from a `modular_model_index.json` file. Explicitly load the components by calling [`ModularPipeline.load_components`].
+
+It is a good idea to initialize the [`ComponentManager`] with the pipeline to help manage the different components. Once you call [`~ModularPipeline.load_components`], the components are registered to the [`ComponentManager`] and can be shared between workflows. The example below uses the `collection` argument to assign the components a `"diffdiff"` label for better organization.
+
+```py
+from diffusers.modular_pipelines import ComponentsManager
+
+components = ComponentManager()
+
+dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", components_manager=components, collection="diffdiff")
+dd_pipeline.load_componenets(torch_dtype=torch.float16)
+dd_pipeline.to("cuda")
+```
+
+## Adding workflows
+
+Other workflows can be added to the [`ModularPipeline`] to support additional features without rewriting the entire pipeline from scratch.
+
+This section demonstrates how to add an IP-Adapter or ControlNet.
+
+### IP-Adapter
+
+Stable Diffusion XL already has a preset IP-Adapter block that you can use and doesn't require any changes to the existing Differential Diffusion pipeline.
+
+```py
+from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLAutoIPAdapterStep
+
+ip_adapter_block = StableDiffusionXLAutoIPAdapterStep()
+```
+
+Use the [`sub_blocks.insert`] method to insert it into the [`ModularPipeline`]. The example below inserts the `ip_adapter_block` at position `0`. Print the pipeline to see that the `ip_adapter_block` is added and it requires an `ip_adapter_image`. This also added two components to the pipeline, the `image_encoder` and `feature_extractor`.
+
+```py
+dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0)
+```
+
+Call [`~ModularPipeline.init_pipeline`] to initialize a [`ModularPipeline`] and use [`~ModularPipeline.load_components`] to load the model components. Load and set the IP-Adapter to run the pipeline.
+
+```py
+dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
+dd_pipeline.load_components(torch_dtype=torch.float16)
+dd_pipeline.loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
+dd_pipeline.loader.set_ip_adapter_scale(0.6)
+dd_pipeline = dd_pipeline.to(device)
+
+ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_orange.jpeg")
+image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
+mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
+
+prompt = "a green pear"
+negative_prompt = "blurry"
+generator = torch.Generator(device=device).manual_seed(42)
+
+image = dd_pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_inference_steps=25,
+ generator=generator,
+ ip_adapter_image=ip_adapter_image,
+ diffdiff_map=mask,
+ image=image,
+ output="images"
+)[0]
+```
+
+### ControlNet
+
+Stable Diffusion XL already has a preset ControlNet block that can readily be used.
+
+```py
+from diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks import StableDiffusionXLAutoControlNetInputStep
+
+control_input_block = StableDiffusionXLAutoControlNetInputStep()
+```
+
+However, it requires modifying the `denoise` block because that's where the ControlNet injects the control information into the UNet.
+
+Modify the `denoise` block by replacing the `StableDiffusionXLLoopDenoiser` sub-block with the `StableDiffusionXLControlNetLoopDenoiser`.
+
+```py
+class SDXLDiffDiffControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
+ block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+controlnet_denoise_block = SDXLDiffDiffControlNetDenoiseStep()
+```
+
+Insert the `controlnet_input` block and replace the `denoise` block with the new `controlnet_denoise_block`. Initialize a [`ModularPipeline`] and [`~ModularPipeline.load_components`] into it.
+
+```py
+dd_blocks.sub_blocks.insert("controlnet_input", control_input_block, 7)
+dd_blocks.sub_blocks["denoise"] = controlnet_denoise_block
+
+dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
+dd_pipeline.load_components(torch_dtype=torch.float16)
+dd_pipeline = dd_pipeline.to(device)
+
+control_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_tomato_canny.jpeg")
+image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
+mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
+
+prompt = "a green pear"
+negative_prompt = "blurry"
+generator = torch.Generator(device=device).manual_seed(42)
+
+image = dd_pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_inference_steps=25,
+ generator=generator,
+ control_image=control_image,
+ controlnet_conditioning_scale=0.5,
+ diffdiff_map=mask,
+ image=image,
+ output="images"
+)[0]
+```
+
+### AutoPipelineBlocks
+
+The Differential Diffusion, IP-Adapter, and ControlNet workflows can be bundled into a single [`ModularPipeline`] by using [`AutoPipelineBlocks`]. This allows automatically selecting which sub-blocks to run based on the inputs like `control_image` or `ip_adapter_image`. If none of these inputs are passed, then it defaults to the Differential Diffusion.
+
+Use `block_trigger_inputs` to only run the `SDXLDiffDiffControlNetDenoiseStep` block if a `control_image` input is provided. Otherwise, the `SDXLDiffDiffDenoiseStep` is used.
+
+```py
+class SDXLDiffDiffAutoDenoiseStep(AutoPipelineBlocks):
+ block_classes = [SDXLDiffDiffControlNetDenoiseStep, SDXLDiffDiffDenoiseStep]
+ block_names = ["controlnet_denoise", "denoise"]
+ block_trigger_inputs = ["controlnet_cond", None]
+```
+
+Add the `ip_adapter` and `controlnet_input` blocks.
+
+```py
+DIFFDIFF_AUTO_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
+DIFFDIFF_AUTO_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
+DIFFDIFF_AUTO_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
+DIFFDIFF_AUTO_BLOCKS["denoise"] = SDXLDiffDiffAutoDenoiseStep
+DIFFDIFF_AUTO_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0)
+DIFFDIFF_AUTO_BLOCKS.insert("controlnet_input",StableDiffusionXLControlNetAutoInput, 7)
+```
+
+Call [`SequentialPipelineBlocks.from_blocks_dict`] to create a [`SequentialPipelineBlocks`] and create a [`ModularPipeline`] and load in the model components to run.
+
+```py
+dd_auto_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_AUTO_BLOCKS)
+dd_pipeline = dd_auto_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
+dd_pipeline.load_components(torch_dtype=torch.float16)
+```
+
+## Share
+
+Add your [`ModularPipeline`] to the Hub with [`~ModularPipeline.save_pretrained`] and set `push_to_hub` argument to `True`.
+
+```py
+dd_pipeline.save_pretrained("YiYiXu/test_modular_doc", push_to_hub=True)
+```
+
+Other users can load the [`ModularPipeline`] with [`~ModularPipeline.from_pretrained`].
+
+```py
+import torch
+from diffusers.modular_pipelines import ModularPipeline, ComponentsManager
+
+components = ComponentsManager()
+
+diffdiff_pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-diffdiff-0704", trust_remote_code=True, components_manager=components, collection="diffdiff")
+diffdiff_pipeline.load_components(torch_dtype=torch.float16)
+```
diff --git a/docs/source/en/modular_diffusers/sequential_pipeline_blocks.md b/docs/source/en/modular_diffusers/sequential_pipeline_blocks.md
new file mode 100644
index 000000000000..f1549a26b86f
--- /dev/null
+++ b/docs/source/en/modular_diffusers/sequential_pipeline_blocks.md
@@ -0,0 +1,113 @@
+
+
+# SequentialPipelineBlocks
+
+[`~modular_pipelines.SequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a sequence. Data flows linearly from one block to the next using `inputs` and `intermediate_outputs`. Each block in [`~modular_pipelines.SequentialPipelineBlocks`] usually represents a step in the pipeline, and by combining them, you gradually build a pipeline.
+
+This guide shows you how to connect two blocks into a [`~modular_pipelines.SequentialPipelineBlocks`].
+
+Create two [`~modular_pipelines.ModularPipelineBlocks`]. The first block, `InputBlock`, outputs a `batch_size` value and the second block, `ImageEncoderBlock` uses `batch_size` as `inputs`.
+
+
+
+
+```py
+from diffusers.modular_pipelines import ModularPipelineBlocks, InputParam, OutputParam
+
+class InputBlock(ModularPipelineBlocks):
+
+ @property
+ def inputs(self):
+ return [
+ InputParam(name="prompt", type_hint=list, description="list of text prompts"),
+ InputParam(name="num_images_per_prompt", type_hint=int, description="number of images per prompt"),
+ ]
+
+ @property
+ def intermediate_outputs(self):
+ return [
+ OutputParam(name="batch_size", description="calculated batch size"),
+ ]
+
+ @property
+ def description(self):
+ return "A block that determines batch_size based on the number of prompts and num_images_per_prompt argument."
+
+ def __call__(self, components, state):
+ block_state = self.get_block_state(state)
+ batch_size = len(block_state.prompt)
+ block_state.batch_size = batch_size * block_state.num_images_per_prompt
+ self.set_block_state(state, block_state)
+ return components, state
+```
+
+
+
+
+```py
+import torch
+from diffusers.modular_pipelines import ModularPipelineBlocks, InputParam, OutputParam
+
+class ImageEncoderBlock(ModularPipelineBlocks):
+
+ @property
+ def inputs(self):
+ return [
+ InputParam(name="image", type_hint="PIL.Image", description="raw input image to process"),
+ InputParam(name="batch_size", type_hint=int),
+ ]
+
+ @property
+ def intermediate_outputs(self):
+ return [
+ OutputParam(name="image_latents", description="latents representing the image"),
+ ]
+
+ @property
+ def description(self):
+ return "Encode raw image into its latent presentation"
+
+ def __call__(self, components, state):
+ block_state = self.get_block_state(state)
+ # Simulate processing the image
+ # This will change the state of the image from a PIL image to a tensor for all blocks
+ block_state.image = torch.randn(1, 3, 512, 512)
+ block_state.batch_size = block_state.batch_size * 2
+ block_state.image_latents = torch.randn(1, 4, 64, 64)
+ self.set_block_state(state, block_state)
+ return components, state
+```
+
+
+
+
+Connect the two blocks by defining an [`InsertableDict`] to map the block names to the block instances. Blocks are executed in the order they're registered in `blocks_dict`.
+
+Use [`~modular_pipelines.SequentialPipelineBlocks.from_blocks_dict`] to create a [`~modular_pipelines.SequentialPipelineBlocks`].
+
+```py
+from diffusers.modular_pipelines import SequentialPipelineBlocks, InsertableDict
+
+blocks_dict = InsertableDict()
+blocks_dict["input"] = input_block
+blocks_dict["image_encoder"] = image_encoder_block
+
+blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict)
+```
+
+Inspect the sub-blocks in [`~modular_pipelines.SequentialPipelineBlocks`] by calling `blocks`, and for more details about the inputs and outputs, access the `docs` attribute.
+
+```py
+print(blocks)
+print(blocks.doc)
+```
diff --git a/docs/source/en/optimization/attention_backends.md b/docs/source/en/optimization/attention_backends.md
new file mode 100644
index 000000000000..f3ff4781c6ec
--- /dev/null
+++ b/docs/source/en/optimization/attention_backends.md
@@ -0,0 +1,159 @@
+
+
+# Attention backends
+
+> [!NOTE]
+> The attention dispatcher is an experimental feature. Please open an issue if you have any feedback or encounter any problems.
+
+Diffusers provides several optimized attention algorithms that are more memory and computationally efficient through it's *attention dispatcher*. The dispatcher acts as a router for managing and switching between different attention implementations and provides a unified interface for interacting with them.
+
+Refer to the table below for an overview of the available attention families and to the [Available backends](#available-backends) section for a more complete list.
+
+| attention family | main feature |
+|---|---|
+| FlashAttention | minimizes memory reads/writes through tiling and recomputation |
+| AI Tensor Engine for ROCm | FlashAttention implementation optimized for AMD ROCm accelerators |
+| SageAttention | quantizes attention to int8 |
+| PyTorch native | built-in PyTorch implementation using [scaled_dot_product_attention](./fp16#scaled-dot-product-attention) |
+| xFormers | memory-efficient attention with support for various attention kernels |
+
+This guide will show you how to set and use the different attention backends.
+
+## set_attention_backend
+
+The [`~ModelMixin.set_attention_backend`] method iterates through all the modules in the model and sets the appropriate attention backend to use. The attention backend setting persists until [`~ModelMixin.reset_attention_backend`] is called.
+
+The example below demonstrates how to enable the `_flash_3_hub` implementation for FlashAttention-3 from the [`kernels`](https://github.com/huggingface/kernels) library, which allows you to instantly use optimized compute kernels from the Hub without requiring any setup.
+
+> [!NOTE]
+> FlashAttention-3 is not supported for non-Hopper architectures, in which case, use FlashAttention with `set_attention_backend("flash")`.
+
+```py
+import torch
+from diffusers import QwenImagePipeline
+
+pipeline = QwenImagePipeline.from_pretrained(
+ "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
+)
+pipeline.transformer.set_attention_backend("_flash_3_hub")
+
+prompt = """
+cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+pipeline(prompt).images[0]
+```
+
+To restore the default attention backend, call [`~ModelMixin.reset_attention_backend`].
+
+```py
+pipeline.transformer.reset_attention_backend()
+```
+
+## attention_backend context manager
+
+The [attention_backend](https://github.com/huggingface/diffusers/blob/5e181eddfe7e44c1444a2511b0d8e21d177850a0/src/diffusers/models/attention_dispatch.py#L225) context manager temporarily sets an attention backend for a model within the context. Outside the context, the default attention (PyTorch's native scaled dot product attention) is used. This is useful if you want to use different backends for different parts of a pipeline or if you want to test the different backends.
+
+```py
+import torch
+from diffusers import QwenImagePipeline
+
+pipeline = QwenImagePipeline.from_pretrained(
+ "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
+)
+prompt = """
+cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+
+with attention_backend("_flash_3_hub"):
+ image = pipeline(prompt).images[0]
+```
+
+> [!TIP]
+> Most attention backends support `torch.compile` without graph breaks and can be used to further speed up inference.
+
+## Checks
+
+The attention dispatcher includes debugging checks that catch common errors before they cause problems.
+
+1. Device checks verify that query, key, and value tensors live on the same device.
+2. Data type checks confirm tensors have matching dtypes and use either bfloat16 or float16.
+3. Shape checks validate tensor dimensions and prevent mixing attention masks with causal flags.
+
+Enable these checks by setting the `DIFFUSERS_ATTN_CHECKS` environment variable. Checks add overhead to every attention operation, so they're disabled by default.
+
+```bash
+export DIFFUSERS_ATTN_CHECKS=yes
+```
+
+The checks are run now before every attention operation.
+
+```py
+import torch
+
+query = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
+key = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
+value = torch.randn(1, 10, 8, 64, dtype=torch.bfloat16, device="cuda")
+
+try:
+ with attention_backend("flash"):
+ output = dispatch_attention_fn(query, key, value)
+ print("✓ Flash Attention works with checks enabled")
+except Exception as e:
+ print(f"✗ Flash Attention failed: {e}")
+```
+
+You can also configure the registry directly.
+
+```py
+from diffusers.models.attention_dispatch import _AttentionBackendRegistry
+
+_AttentionBackendRegistry._checks_enabled = True
+```
+
+## Available backends
+
+Refer to the table below for a complete list of available attention backends and their variants.
+
+
+Expand
+
+| Backend Name | Family | Description |
+|--------------|--------|-------------|
+| `native` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Default backend using PyTorch's scaled_dot_product_attention |
+| `flex` | [FlexAttention](https://docs.pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention) | PyTorch FlexAttention implementation |
+| `_native_cudnn` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | CuDNN-optimized attention |
+| `_native_efficient` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Memory-efficient attention |
+| `_native_flash` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | PyTorch's FlashAttention |
+| `_native_math` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Math-based attention (fallback) |
+| `_native_npu` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | NPU-optimized attention |
+| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
+| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
+| `flash_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 from kernels |
+| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
+| `flash_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention from kernels |
+| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
+| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
+| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
+| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
+| `_flash_3_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 from kernels |
+| `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
+| `sage_hub` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) from kernels |
+| `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |
+| `_sage_qk_int8_pv_fp8_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (CUDA) |
+| `_sage_qk_int8_pv_fp8_cuda_sm90` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (SM90) |
+| `_sage_qk_int8_pv_fp16_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP16 PV (CUDA) |
+| `_sage_qk_int8_pv_fp16_triton` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP16 PV (Triton) |
+| `xformers` | [xFormers](https://github.com/facebookresearch/xformers) | Memory-efficient attention |
+
+
diff --git a/docs/source/en/optimization/cache.md b/docs/source/en/optimization/cache.md
new file mode 100644
index 000000000000..6397c7d4cd2e
--- /dev/null
+++ b/docs/source/en/optimization/cache.md
@@ -0,0 +1,100 @@
+
+
+# Caching
+
+Caching accelerates inference by storing and reusing intermediate outputs of different layers, such as attention and feedforward layers, instead of performing the entire computation at each inference step. It significantly improves generation speed at the expense of more memory and doesn't require additional training.
+
+This guide shows you how to use the caching methods supported in Diffusers.
+
+## Pyramid Attention Broadcast
+
+[Pyramid Attention Broadcast (PAB)](https://huggingface.co/papers/2408.12588) is based on the observation that attention outputs aren't that different between successive timesteps of the generation process. The attention differences are smallest in the cross attention layers and are generally cached over a longer timestep range. This is followed by temporal attention and spatial attention layers.
+
+> [!TIP]
+> Not all video models have three types of attention (cross, temporal, and spatial)!
+
+PAB can be combined with other techniques like sequence parallelism and classifier-free guidance parallelism (data parallelism) for near real-time video generation.
+
+Set up and pass a [`PyramidAttentionBroadcastConfig`] to a pipeline's transformer to enable it. The `spatial_attention_block_skip_range` controls how often to skip attention calculations in the spatial attention blocks and the `spatial_attention_timestep_skip_range` is the range of timesteps to skip. Take care to choose an appropriate range because a smaller interval can lead to slower inference speeds and a larger interval can result in lower generation quality.
+
+```python
+import torch
+from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
+
+pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+pipeline.to("cuda")
+
+config = PyramidAttentionBroadcastConfig(
+ spatial_attention_block_skip_range=2,
+ spatial_attention_timestep_skip_range=(100, 800),
+ current_timestep_callback=lambda: pipe.current_timestep,
+)
+pipeline.transformer.enable_cache(config)
+```
+
+## FasterCache
+
+[FasterCache](https://huggingface.co/papers/2410.19355) caches and reuses attention features similar to [PAB](#pyramid-attention-broadcast) since output differences are small for each successive timestep.
+
+This method may also choose to skip the unconditional branch prediction, when using classifier-free guidance for sampling (common in most base models), and estimate it from the conditional branch prediction if there is significant redundancy in the predicted latent outputs between successive timesteps.
+
+Set up and pass a [`FasterCacheConfig`] to a pipeline's transformer to enable it.
+
+```python
+import torch
+from diffusers import CogVideoXPipeline, FasterCacheConfig
+
+pipe line= CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+pipeline.to("cuda")
+
+config = FasterCacheConfig(
+ spatial_attention_block_skip_range=2,
+ spatial_attention_timestep_skip_range=(-1, 681),
+ current_timestep_callback=lambda: pipe.current_timestep,
+ attention_weight_callback=lambda _: 0.3,
+ unconditional_batch_skip_range=5,
+ unconditional_batch_timestep_skip_range=(-1, 781),
+ tensor_format="BFCHW",
+)
+pipeline.transformer.enable_cache(config)
+```
+
+## TaylorSeer Cache
+
+[TaylorSeer Cache](https://huggingface.co/papers/2403.06923) accelerates diffusion inference by using Taylor series expansions to approximate and cache intermediate activations across denoising steps. The method predicts future outputs based on past computations, reusing them at specified intervals to reduce redundant calculations.
+
+This caching mechanism delivers strong results with minimal additional memory overhead. For detailed performance analysis, see [our findings here](https://github.com/huggingface/diffusers/pull/12648#issuecomment-3610615080).
+
+To enable TaylorSeer Cache, create a [`TaylorSeerCacheConfig`] and pass it to your pipeline's transformer:
+
+- `cache_interval`: Number of steps to reuse cached outputs before performing a full forward pass
+- `disable_cache_before_step`: Initial steps that use full computations to gather data for approximations
+- `max_order`: Approximation accuracy (in theory, higher values improve quality but increase memory usage but we recommend it should be set to `1`)
+
+```python
+import torch
+from diffusers import FluxPipeline, TaylorSeerCacheConfig
+
+pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ torch_dtype=torch.bfloat16,
+)
+pipe.to("cuda")
+
+config = TaylorSeerCacheConfig(
+ cache_interval=5,
+ max_order=1,
+ disable_cache_before_step=10,
+ taylor_factors_dtype=torch.bfloat16,
+)
+pipe.transformer.enable_cache(config)
+```
\ No newline at end of file
diff --git a/docs/source/en/optimization/cache_dit.md b/docs/source/en/optimization/cache_dit.md
new file mode 100644
index 000000000000..126142321249
--- /dev/null
+++ b/docs/source/en/optimization/cache_dit.md
@@ -0,0 +1,270 @@
+## CacheDiT
+
+CacheDiT is a unified, flexible, and training-free cache acceleration framework designed to support nearly all Diffusers' DiT-based pipelines. It provides a unified cache API that supports automatic block adapter, DBCache, and more.
+
+To learn more, refer to the [CacheDiT](https://github.com/vipshop/cache-dit) repository.
+
+Install a stable release of CacheDiT from PyPI or you can install the latest version from GitHub.
+
+
+
+
+```bash
+pip3 install -U cache-dit
+```
+
+
+
+
+```bash
+pip3 install git+https://github.com/vipshop/cache-dit.git
+```
+
+
+
+
+Run the command below to view supported DiT pipelines.
+
+```python
+>>> import cache_dit
+>>> cache_dit.supported_pipelines()
+(30, ['Flux*', 'Mochi*', 'CogVideoX*', 'Wan*', 'HunyuanVideo*', 'QwenImage*', 'LTX*', 'Allegro*',
+'CogView3Plus*', 'CogView4*', 'Cosmos*', 'EasyAnimate*', 'SkyReelsV2*', 'StableDiffusion3*',
+'ConsisID*', 'DiT*', 'Amused*', 'Bria*', 'Lumina*', 'OmniGen*', 'PixArt*', 'Sana*', 'StableAudio*',
+'VisualCloze*', 'AuraFlow*', 'Chroma*', 'ShapE*', 'HiDream*', 'HunyuanDiT*', 'HunyuanDiTPAG*'])
+```
+
+For a complete benchmark, please refer to [Benchmarks](https://github.com/vipshop/cache-dit/blob/main/bench/).
+
+
+## Unified Cache API
+
+CacheDiT works by matching specific input/output patterns as shown below.
+
+
+
+Call the `enable_cache()` function on a pipeline to enable cache acceleration. This function is the entry point to many of CacheDiT's features.
+
+```python
+import cache_dit
+from diffusers import DiffusionPipeline
+
+# Can be any diffusion pipeline
+pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")
+
+# One-line code with default cache options.
+cache_dit.enable_cache(pipe)
+
+# Just call the pipe as normal.
+output = pipe(...)
+
+# Disable cache and run original pipe.
+cache_dit.disable_cache(pipe)
+```
+
+## Automatic Block Adapter
+
+For custom or modified pipelines or transformers not included in Diffusers, use the `BlockAdapter` in `auto` mode or via manual configuration. Please check the [BlockAdapter](https://github.com/vipshop/cache-dit/blob/main/docs/User_Guide.md#automatic-block-adapter) docs for more details. Refer to [Qwen-Image w/ BlockAdapter](https://github.com/vipshop/cache-dit/blob/main/examples/adapter/run_qwen_image_adapter.py) as an example.
+
+
+```python
+from cache_dit import ForwardPattern, BlockAdapter
+
+# Use 🔥BlockAdapter with `auto` mode.
+cache_dit.enable_cache(
+ BlockAdapter(
+ # Any DiffusionPipeline, Qwen-Image, etc.
+ pipe=pipe, auto=True,
+ # Check `📚Forward Pattern Matching` documentation and hack the code of
+ # of Qwen-Image, you will find that it has satisfied `FORWARD_PATTERN_1`.
+ forward_pattern=ForwardPattern.Pattern_1,
+ ),
+)
+
+# Or, manually setup transformer configurations.
+cache_dit.enable_cache(
+ BlockAdapter(
+ pipe=pipe, # Qwen-Image, etc.
+ transformer=pipe.transformer,
+ blocks=pipe.transformer.transformer_blocks,
+ forward_pattern=ForwardPattern.Pattern_1,
+ ),
+)
+```
+
+Sometimes, a Transformer class will contain more than one transformer `blocks`. For example, FLUX.1 (HiDream, Chroma, etc) contains `transformer_blocks` and `single_transformer_blocks` (with different forward patterns). The BlockAdapter is able to detect this hybrid pattern type as well.
+Refer to [FLUX.1](https://github.com/vipshop/cache-dit/blob/main/examples/adapter/run_flux_adapter.py) as an example.
+
+```python
+# For diffusers <= 0.34.0, FLUX.1 transformer_blocks and
+# single_transformer_blocks have different forward patterns.
+cache_dit.enable_cache(
+ BlockAdapter(
+ pipe=pipe, # FLUX.1, etc.
+ transformer=pipe.transformer,
+ blocks=[
+ pipe.transformer.transformer_blocks,
+ pipe.transformer.single_transformer_blocks,
+ ],
+ forward_pattern=[
+ ForwardPattern.Pattern_1,
+ ForwardPattern.Pattern_3,
+ ],
+ ),
+)
+```
+
+This also works if there is more than one transformer (namely `transformer` and `transformer_2`) in its structure. Refer to [Wan 2.2 MoE](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_wan_2.2.py) as an example.
+
+## Patch Functor
+
+For any pattern not included in CacheDiT, use the Patch Functor to convert the pattern into a known pattern. You need to subclass the Patch Functor and may also need to fuse the operations within the blocks for loop into block `forward`. After implementing a Patch Functor, set the `patch_functor` property in `BlockAdapter`.
+
+
+
+Some Patch Functors are already provided in CacheDiT, [HiDreamPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/patch_functors/functor_hidream.py), [ChromaPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/patch_functors/functor_chroma.py), etc.
+
+```python
+@BlockAdapterRegistry.register("HiDream")
+def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
+ from diffusers import HiDreamImageTransformer2DModel
+ from cache_dit.cache_factory.patch_functors import HiDreamPatchFunctor
+
+ assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
+ return BlockAdapter(
+ pipe=pipe,
+ transformer=pipe.transformer,
+ blocks=[
+ pipe.transformer.double_stream_blocks,
+ pipe.transformer.single_stream_blocks,
+ ],
+ forward_pattern=[
+ ForwardPattern.Pattern_0,
+ ForwardPattern.Pattern_3,
+ ],
+ # NOTE: Setup your custom patch functor here.
+ patch_functor=HiDreamPatchFunctor(),
+ **kwargs,
+ )
+```
+
+Finally, you can call the `cache_dit.summary()` function on a pipeline after its completed inference to get the cache acceleration details.
+
+```python
+stats = cache_dit.summary(pipe)
+```
+
+```python
+⚡️Cache Steps and Residual Diffs Statistics: QwenImagePipeline
+
+| Cache Steps | Diffs Min | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Max |
+|-------------|-----------|-----------|-----------|-----------|-----------|-----------|
+| 23 | 0.045 | 0.084 | 0.114 | 0.147 | 0.241 | 0.297 |
+```
+
+## DBCache: Dual Block Cache
+
+
+
+DBCache (Dual Block Caching) supports different configurations of compute blocks (F8B12, etc.) to enable a balanced trade-off between performance and precision.
+- Fn_compute_blocks: Specifies that DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
+- Bn_compute_blocks: Further fuses approximate information in the **last n** Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.
+
+
+```python
+import cache_dit
+from diffusers import FluxPipeline
+
+pipe_or_adapter = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ torch_dtype=torch.bfloat16,
+).to("cuda")
+
+# Default options, F8B0, 8 warmup steps, and unlimited cached
+# steps for good balance between performance and precision
+cache_dit.enable_cache(pipe_or_adapter)
+
+# Custom options, F8B8, higher precision
+from cache_dit import BasicCacheConfig
+
+cache_dit.enable_cache(
+ pipe_or_adapter,
+ cache_config=BasicCacheConfig(
+ max_warmup_steps=8, # steps do not cache
+ max_cached_steps=-1, # -1 means no limit
+ Fn_compute_blocks=8, # Fn, F8, etc.
+ Bn_compute_blocks=8, # Bn, B8, etc.
+ residual_diff_threshold=0.12,
+ ),
+)
+```
+Check the [DBCache](https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md) and [User Guide](https://github.com/vipshop/cache-dit/blob/main/docs/User_Guide.md#dbcache) docs for more design details.
+
+## TaylorSeer Calibrator
+
+The [TaylorSeers](https://huggingface.co/papers/2503.06923) algorithm further improves the precision of DBCache in cases where the cached steps are large (Hybrid TaylorSeer + DBCache). At timesteps with significant intervals, the feature similarity in diffusion models decreases substantially, significantly harming the generation quality.
+
+TaylorSeer employs a differential method to approximate the higher-order derivatives of features and predict features in future timesteps with Taylor series expansion. The TaylorSeer implemented in CacheDiT supports both hidden states and residual cache types. F_pred can be a residual cache or a hidden-state cache.
+
+```python
+from cache_dit import BasicCacheConfig, TaylorSeerCalibratorConfig
+
+cache_dit.enable_cache(
+ pipe_or_adapter,
+ # Basic DBCache w/ FnBn configurations
+ cache_config=BasicCacheConfig(
+ max_warmup_steps=8, # steps do not cache
+ max_cached_steps=-1, # -1 means no limit
+ Fn_compute_blocks=8, # Fn, F8, etc.
+ Bn_compute_blocks=8, # Bn, B8, etc.
+ residual_diff_threshold=0.12,
+ ),
+ # Then, you can use the TaylorSeer Calibrator to approximate
+ # the values in cached steps, taylorseer_order default is 1.
+ calibrator_config=TaylorSeerCalibratorConfig(
+ taylorseer_order=1,
+ ),
+)
+```
+
+> [!TIP]
+> The `Bn_compute_blocks` parameter of DBCache can be set to `0` if you use TaylorSeer as the calibrator for approximate hidden states. DBCache's `Bn_compute_blocks` also acts as a calibrator, so you can choose either `Bn_compute_blocks` > 0 or TaylorSeer. We recommend using the configuration scheme of TaylorSeer + DBCache FnB0.
+
+## Hybrid Cache CFG
+
+CacheDiT supports caching for CFG (classifier-free guidance). For models that fuse CFG and non-CFG into a single forward step, or models that do not include CFG in the forward step, please set `enable_separate_cfg` parameter to `False (default, None)`. Otherwise, set it to `True`.
+
+```python
+from cache_dit import BasicCacheConfig
+
+cache_dit.enable_cache(
+ pipe_or_adapter,
+ cache_config=BasicCacheConfig(
+ ...,
+ # For example, set it as True for Wan 2.1, Qwen-Image
+ # and set it as False for FLUX.1, HunyuanVideo, etc.
+ enable_separate_cfg=True,
+ ),
+)
+```
+
+## torch.compile
+
+CacheDiT is designed to work with torch.compile for even better performance. Call `torch.compile` after enabling the cache.
+
+
+```python
+cache_dit.enable_cache(pipe)
+
+# Compile the Transformer module
+pipe.transformer = torch.compile(pipe.transformer)
+```
+
+If you're using CacheDiT with dynamic input shapes, consider increasing the `recompile_limit` of `torch._dynamo`. Otherwise, the `recompile_limit` error may be triggered, causing the module to fall back to eager mode.
+
+```python
+torch._dynamo.config.recompile_limit = 96 # default is 8
+torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
+```
+
+Please check [perf.py](https://github.com/vipshop/cache-dit/blob/main/bench/perf.py) for more details.
diff --git a/docs/source/en/optimization/coreml.md b/docs/source/en/optimization/coreml.md
index d090ef0ed3ba..71da1e3dc1fe 100644
--- a/docs/source/en/optimization/coreml.md
+++ b/docs/source/en/optimization/coreml.md
@@ -1,4 +1,4 @@
-
-# Speed up inference
+# Accelerate inference
-There are several ways to optimize Diffusers for inference speed, such as reducing the computational burden by lowering the data precision or using a lightweight distilled model. There are also memory-efficient attention implementations, [xFormers](xformers) and [scaled dot product attention](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) in PyTorch 2.0, that reduce memory usage which also indirectly speeds up inference. Different speed optimizations can be stacked together to get the fastest inference times.
+Diffusion models are slow at inference because generation is an iterative process where noise is gradually refined into an image or video over a certain number of "steps". To speedup this process, you can try experimenting with different [schedulers](../api/schedulers/overview), reduce the precision of the model weights for faster computations, use more memory-efficient attention mechanisms, and more.
-> [!TIP]
-> Optimizing for inference speed or reduced memory usage can lead to improved performance in the other category, so you should try to optimize for both whenever you can. This guide focuses on inference speed, but you can learn more about lowering memory usage in the [Reduce memory usage](memory) guide.
+Combine and use these techniques together to make inference faster than using any single technique on its own.
+
+This guide will go over how to accelerate inference.
+
+## Model data type
+
+The precision and data type of the model weights affect inference speed because a higher precision requires more memory to load and more time to perform the computations. PyTorch loads model weights in float32 or full precision by default, so changing the data type is a simple way to quickly get faster inference.
+
+
+
+
+bfloat16 is similar to float16 but it is more robust to numerical errors. Hardware support for bfloat16 varies, but most modern GPUs are capable of supporting bfloat16.
+
+```py
+import torch
+from diffusers import StableDiffusionXLPipeline
-The inference times below are obtained from generating a single 512x512 image from the prompt "a photo of an astronaut riding a horse on mars" with 50 DDIM steps on a NVIDIA A100.
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
+).to("cuda")
-| setup | latency | speed-up |
-|----------|---------|----------|
-| baseline | 5.27s | x1 |
-| tf32 | 4.14s | x1.27 |
-| fp16 | 3.51s | x1.50 |
-| combined | 3.41s | x1.54 |
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+pipeline(prompt, num_inference_steps=30).images[0]
+```
-## TensorFloat-32
+
+
-On Ampere and later CUDA devices, matrix multiplications and convolutions can use the [TensorFloat-32 (tf32)](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) mode for faster, but slightly less accurate computations. By default, PyTorch enables tf32 mode for convolutions but not matrix multiplications. Unless your network requires full float32 precision, we recommend enabling tf32 for matrix multiplications. It can significantly speed up computations with typically negligible loss in numerical accuracy.
+float16 is similar to bfloat16 but may be more prone to numerical errors.
-```python
+```py
import torch
+from diffusers import StableDiffusionXLPipeline
+
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+).to("cuda")
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+pipeline(prompt, num_inference_steps=30).images[0]
+```
+
+
+
+
+[TensorFloat-32 (tf32)](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) mode is supported on NVIDIA Ampere GPUs and it computes the convolution and matrix multiplication operations in tf32. Storage and other operations are kept in float32. This enables significantly faster computations when combined with bfloat16 or float16.
+
+PyTorch only enables tf32 mode for convolutions by default and you'll need to explicitly enable it for matrix multiplications.
+
+```py
+import torch
+from diffusers import StableDiffusionXLPipeline
torch.backends.cuda.matmul.allow_tf32 = True
+
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
+).to("cuda")
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+pipeline(prompt, num_inference_steps=30).images[0]
```
-Learn more about tf32 in the [Mixed precision training](https://huggingface.co/docs/transformers/en/perf_train_gpu_one#tf32) guide.
+Refer to the [mixed precision training](https://huggingface.co/docs/transformers/en/perf_train_gpu_one#mixed-precision) docs for more details.
+
+
+
-## Half-precision weights
+## Scaled dot product attention
-To save GPU memory and get more speed, set `torch_dtype=torch.float16` to load and run the model weights directly with half-precision weights.
+> [!TIP]
+> Memory-efficient attention optimizes for inference speed *and* [memory usage](./memory#memory-efficient-attention)!
-```Python
+[Scaled dot product attention (SDPA)](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) implements several attention backends, [FlashAttention](https://github.com/Dao-AILab/flash-attention), [xFormers](https://github.com/facebookresearch/xformers), and a native C++ implementation. It automatically selects the most optimal backend for your hardware.
+
+SDPA is enabled by default if you're using PyTorch >= 2.0 and no additional changes are required to your code. You could try experimenting with other attention backends though if you'd like to choose your own. The example below uses the [torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) context manager to enable efficient attention.
+
+```py
+from torch.nn.attention import SDPBackend, sdpa_kernel
import torch
-from diffusers import DiffusionPipeline
+from diffusers import StableDiffusionXLPipeline
-pipe = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5",
- torch_dtype=torch.float16,
- use_safetensors=True,
-)
-pipe = pipe.to("cuda")
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
+).to("cuda")
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+
+with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
+ image = pipeline(prompt, num_inference_steps=30).images[0]
```
-> [!WARNING]
-> Don't use [torch.autocast](https://pytorch.org/docs/stable/amp.html#torch.autocast) in any of the pipelines as it can lead to black images and is always slower than pure float16 precision.
+## torch.compile
+
+[torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) accelerates inference by compiling PyTorch code and operations into optimized kernels. Diffusers typically compiles the more compute-intensive models like the UNet, transformer, or VAE.
+
+Enable the following compiler settings for maximum speed (refer to the [full list](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py) for more options).
+
+```py
+import torch
+from diffusers import StableDiffusionXLPipeline
-## Distilled model
+torch._inductor.config.conv_1x1_as_mm = True
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.epilogue_fusion = False
+torch._inductor.config.coordinate_descent_check_all_directions = True
+```
-You could also use a distilled Stable Diffusion model and autoencoder to speed up inference. During distillation, many of the UNet's residual and attention blocks are shed to reduce the model size by 51% and improve latency on CPU/GPU by 43%. The distilled model is faster and uses less memory while generating images of comparable quality to the full Stable Diffusion model.
+Load and compile the UNet and VAE. There are several different modes you can choose from, but `"max-autotune"` optimizes for the fastest speed by compiling to a CUDA graph. CUDA graphs effectively reduces the overhead by launching multiple GPU operations through a single CPU operation.
> [!TIP]
-> Read the [Open-sourcing Knowledge Distillation Code and Weights of SD-Small and SD-Tiny](https://huggingface.co/blog/sd_distillation) blog post to learn more about how knowledge distillation training works to produce a faster, smaller, and cheaper generative model.
+> With PyTorch 2.3.1, you can control the caching behavior of torch.compile. This is particularly beneficial for compilation modes like `"max-autotune"` which performs a grid-search over several compilation flags to find the optimal configuration. Learn more in the [Compile Time Caching in torch.compile](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html) tutorial.
-The inference times below are obtained from generating 4 images from the prompt "a photo of an astronaut riding a horse on mars" with 25 PNDM steps on a NVIDIA A100. Each generation is repeated 3 times with the distilled Stable Diffusion v1.4 model by [Nota AI](https://hf.co/nota-ai).
+Changing the memory layout to [channels_last](./memory#torchchannels_last) also optimizes memory and inference speed.
+
+```py
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+).to("cuda")
+pipeline.unet.to(memory_format=torch.channels_last)
+pipeline.vae.to(memory_format=torch.channels_last)
+pipeline.unet = torch.compile(
+ pipeline.unet, mode="max-autotune", fullgraph=True
+)
+pipeline.vae.decode = torch.compile(
+ pipeline.vae.decode,
+ mode="max-autotune",
+ fullgraph=True
+)
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+pipeline(prompt, num_inference_steps=30).images[0]
+```
+
+Compilation is slow the first time, but once compiled, it is significantly faster. Try to only use the compiled pipeline on the same type of inference operations. Calling the compiled pipeline on a different image size retriggers compilation which is slow and inefficient.
+
+### Dynamic shape compilation
+
+> [!TIP]
+> Make sure to always use the nightly version of PyTorch for better support.
-| setup | latency | speed-up |
-|------------------------------|---------|----------|
-| baseline | 6.37s | x1 |
-| distilled | 4.18s | x1.52 |
-| distilled + tiny autoencoder | 3.83s | x1.66 |
+`torch.compile` keeps track of input shapes and conditions, and if these are different, it recompiles the model. For example, if a model is compiled on a 1024x1024 resolution image and used on an image with a different resolution, it triggers recompilation.
-Let's load the distilled Stable Diffusion model and compare it against the original Stable Diffusion model.
+To avoid recompilation, add `dynamic=True` to try and generate a more dynamic kernel to avoid recompilation when conditions change.
+
+```diff
++ torch.fx.experimental._config.use_duck_shape = False
++ pipeline.unet = torch.compile(
+ pipeline.unet, fullgraph=True, dynamic=True
+)
+```
+
+Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic variable to represent input sizes that are the same. For more details, check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
+
+Not all models may benefit from dynamic compilation out of the box and may require changes. Refer to this [PR](https://github.com/huggingface/diffusers/pull/11297/) that improved the [`AuraFlowPipeline`] implementation to benefit from dynamic compilation.
+
+Feel free to open an issue if dynamic compilation doesn't work as expected for a Diffusers model.
+
+### Regional compilation
+
+[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by only compiling the *small and frequently-repeated block(s)* of a model - typically a transformer layer - and enables reusing compiled artifacts for every subsequent occurrence.
+For many diffusion architectures, this delivers the same runtime speedups as full-graph compilation and reduces compile time by 8–10x.
+
+Use the [`~ModelMixin.compile_repeated_blocks`] method, a helper that wraps `torch.compile`, on any component such as the transformer model as shown below.
```py
-from diffusers import StableDiffusionPipeline
+# pip install -U diffusers
import torch
+from diffusers import StableDiffusionXLPipeline
-distilled = StableDiffusionPipeline.from_pretrained(
- "nota-ai/bk-sdm-small", torch_dtype=torch.float16, use_safetensors=True,
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16,
).to("cuda")
-prompt = "a golden vase with different flowers"
-generator = torch.manual_seed(2023)
-image = distilled("a golden vase with different flowers", num_inference_steps=25, generator=generator).images[0]
-image
+
+# compile only the repeated transformer layers inside the UNet
+pipeline.unet.compile_repeated_blocks(fullgraph=True)
```
-
-
-
-
original Stable Diffusion
-
-
-
-
distilled Stable Diffusion
-
-
+To enable regional compilation for a new model, add a `_repeated_blocks` attribute to a model class containing the class names (as strings) of the blocks you want to compile.
+
+```py
+class MyUNet(ModelMixin):
+ _repeated_blocks = ("Transformer2DModel",) # ← compiled by default
+```
-### Tiny AutoEncoder
+> [!TIP]
+> For more regional compilation examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
-To speed inference up even more, replace the autoencoder with a [distilled version](https://huggingface.co/sayakpaul/taesdxl-diffusers) of it.
+There is also a [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method in [Accelerate](https://huggingface.co/docs/accelerate/index) that automatically selects candidate blocks in a model to compile. The remaining graph is compiled separately. This is useful for quick experiments because there aren't as many options for you to set which blocks to compile or adjust compilation flags.
```py
+# pip install -U accelerate
import torch
-from diffusers import AutoencoderTiny, StableDiffusionPipeline
+from diffusers import StableDiffusionXLPipeline
+from accelerate.utils import compile_regions
-distilled = StableDiffusionPipeline.from_pretrained(
- "nota-ai/bk-sdm-small", torch_dtype=torch.float16, use_safetensors=True,
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
-distilled.vae = AutoencoderTiny.from_pretrained(
- "sayakpaul/taesd-diffusers", torch_dtype=torch.float16, use_safetensors=True,
+pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
+```
+
+[`~ModelMixin.compile_repeated_blocks`] is intentionally explicit. List the blocks to repeat in `_repeated_blocks` and the helper only compiles those blocks. It offers predictable behavior and easy reasoning about cache reuse in one line of code.
+
+### Graph breaks
+
+It is important to specify `fullgraph=True` in torch.compile to ensure there are no graph breaks in the underlying model. This allows you to take advantage of torch.compile without any performance degradation. For the UNet and VAE, this changes how you access the return variables.
+
+```diff
+- latents = unet(
+- latents, timestep=timestep, encoder_hidden_states=prompt_embeds
+-).sample
+
++ latents = unet(
++ latents, timestep=timestep, encoder_hidden_states=prompt_embeds, return_dict=False
++)[0]
+```
+
+### GPU sync
+
+The `step()` function is [called](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L1228) on the scheduler each time after the denoiser makes a prediction, and the `sigmas` variable is [indexed](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/schedulers/scheduling_euler_discrete.py#L476). When placed on the GPU, it introduces latency because of the communication sync between the CPU and GPU. It becomes more evident when the denoiser has already been compiled.
+
+In general, the `sigmas` should [stay on the CPU](https://github.com/huggingface/diffusers/blob/35a969d297cba69110d175ee79c59312b9f49e1e/src/diffusers/schedulers/scheduling_euler_discrete.py#L240) to avoid the communication sync and latency.
+
+> [!TIP]
+> Refer to the [torch.compile and Diffusers: A Hands-On Guide to Peak Performance](https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/) blog post for maximizing performance with `torch.compile` for diffusion models.
+
+### Benchmarks
+
+Refer to the [diffusers/benchmarks](https://huggingface.co/datasets/diffusers/benchmarks) dataset to see inference latency and memory usage data for compiled pipelines.
+
+The [diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao#benchmarking-results) repository also contains benchmarking results for compiled versions of Flux and CogVideoX.
+
+## Dynamic quantization
+
+[Dynamic quantization](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html) improves inference speed by reducing precision to enable faster math operations. This particular type of quantization determines how to scale the activations based on the data at runtime rather than using a fixed scaling factor. As a result, the scaling factor is more accurately aligned with the data.
+
+The example below applies [dynamic int8 quantization](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html) to the UNet and VAE with the [torchao](../quantization/torchao) library.
+
+> [!TIP]
+> Refer to our [torchao](../quantization/torchao) docs to learn more about how to use the Diffusers torchao integration.
+
+Configure the compiler tags for maximum speed.
+
+```py
+import torch
+from torchao import apply_dynamic_quant
+from diffusers import StableDiffusionXLPipeline
+
+torch._inductor.config.conv_1x1_as_mm = True
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.epilogue_fusion = False
+torch._inductor.config.coordinate_descent_check_all_directions = True
+torch._inductor.config.force_fuse_int_mm_with_mul = True
+torch._inductor.config.use_mixed_mm = True
+```
+
+Filter out some linear layers in the UNet and VAE which don't benefit from dynamic quantization with the [dynamic_quant_filter_fn](https://github.com/huggingface/diffusion-fast/blob/0f169640b1db106fe6a479f78c1ed3bfaeba3386/utils/pipeline_utils.py#L16).
+
+```py
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")
-prompt = "a golden vase with different flowers"
-generator = torch.manual_seed(2023)
-image = distilled("a golden vase with different flowers", num_inference_steps=25, generator=generator).images[0]
-image
+apply_dynamic_quant(pipeline.unet, dynamic_quant_filter_fn)
+apply_dynamic_quant(pipeline.vae, dynamic_quant_filter_fn)
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+pipeline(prompt, num_inference_steps=30).images[0]
```
-
-
-
-
distilled Stable Diffusion + Tiny AutoEncoder
-
-
+## Fused projection matrices
+
+> [!WARNING]
+> The [fuse_qkv_projections](https://github.com/huggingface/diffusers/blob/58431f102cf39c3c8a569f32d71b2ea8caa461e1/src/diffusers/pipelines/pipeline_utils.py#L2034) method is experimental and support is limited to mostly Stable Diffusion pipelines. Take a look at this [PR](https://github.com/huggingface/diffusers/pull/6179) to learn more about how to enable it for other pipelines
+
+An input is projected into three subspaces, represented by the projection matrices Q, K, and V, in an attention block. These projections are typically calculated separately, but you can horizontally combine these into a single matrix and perform the projection in a single step. It increases the size of the matrix multiplications of the input projections and also improves the impact of quantization.
+
+```py
+pipeline.fuse_qkv_projections()
+```
+
+## Resources
+
+- Read the [Presenting Flux Fast: Making Flux go brrr on H100s](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/) blog post to learn more about how you can combine all of these optimizations with [TorchInductor](https://docs.pytorch.org/docs/stable/torch.compiler.html) and [AOTInductor](https://docs.pytorch.org/docs/stable/torch.compiler_aot_inductor.html) for a ~2.5x speedup using recipes from [flux-fast](https://github.com/huggingface/flux-fast).
-More tiny autoencoder models for other Stable Diffusion models, like Stable Diffusion 3, are available from [madebyollin](https://huggingface.co/madebyollin).
\ No newline at end of file
+ These recipes support AMD hardware and [Flux.1 Kontext Dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev).
+- Read the [torch.compile and Diffusers: A Hands-On Guide to Peak Performance](https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/) blog post
+to maximize performance when using `torch.compile`.
\ No newline at end of file
diff --git a/docs/source/en/optimization/habana.md b/docs/source/en/optimization/habana.md
index 86a0cf0ba019..1e5563ae101e 100644
--- a/docs/source/en/optimization/habana.md
+++ b/docs/source/en/optimization/habana.md
@@ -1,4 +1,4 @@
-
-# Habana Gaudi
+# Intel Gaudi
-🤗 Diffusers is compatible with Habana Gaudi through 🤗 [Optimum](https://huggingface.co/docs/optimum/habana/usage_guides/stable_diffusion). Follow the [installation](https://docs.habana.ai/en/latest/Installation_Guide/index.html) guide to install the SynapseAI and Gaudi drivers, and then install Optimum Habana:
+The Intel Gaudi AI accelerator family includes [Intel Gaudi 1](https://habana.ai/products/gaudi/), [Intel Gaudi 2](https://habana.ai/products/gaudi2/), and [Intel Gaudi 3](https://habana.ai/products/gaudi3/). Each server is equipped with 8 devices, known as Habana Processing Units (HPUs), providing 128GB of memory on Gaudi 3, 96GB on Gaudi 2, and 32GB on the first-gen Gaudi. For more details on the underlying hardware architecture, check out the [Gaudi Architecture](https://docs.habana.ai/en/latest/Gaudi_Overview/Gaudi_Architecture.html) overview.
-```bash
-python -m pip install --upgrade-strategy eager optimum[habana]
-```
-
-To generate images with Stable Diffusion 1 and 2 on Gaudi, you need to instantiate two instances:
-
-- [`~optimum.habana.diffusers.GaudiStableDiffusionPipeline`], a pipeline for text-to-image generation.
-- [`~optimum.habana.diffusers.GaudiDDIMScheduler`], a Gaudi-optimized scheduler.
-
-When you initialize the pipeline, you have to specify `use_habana=True` to deploy it on HPUs and to get the fastest possible generation, you should enable **HPU graphs** with `use_hpu_graphs=True`.
+Diffusers pipelines can take advantage of HPU acceleration, even if a pipeline hasn't been added to [Optimum for Intel Gaudi](https://huggingface.co/docs/optimum/main/en/habana/index) yet, with the [GPU Migration Toolkit](https://docs.habana.ai/en/latest/PyTorch/PyTorch_Model_Porting/GPU_Migration_Toolkit/GPU_Migration_Toolkit.html).
-Finally, specify a [`~optimum.habana.GaudiConfig`] which can be downloaded from the [Habana](https://huggingface.co/Habana) organization on the Hub.
-
-```python
-from optimum.habana import GaudiConfig
-from optimum.habana.diffusers import GaudiDDIMScheduler, GaudiStableDiffusionPipeline
-
-model_name = "stabilityai/stable-diffusion-2-base"
-scheduler = GaudiDDIMScheduler.from_pretrained(model_name, subfolder="scheduler")
-pipeline = GaudiStableDiffusionPipeline.from_pretrained(
- model_name,
- scheduler=scheduler,
- use_habana=True,
- use_hpu_graphs=True,
- gaudi_config="Habana/stable-diffusion-2",
-)
-```
+Call `.to("hpu")` on your pipeline to move it to a HPU device as shown below for Flux:
+```py
+import torch
+from diffusers import DiffusionPipeline
-Now you can call the pipeline to generate images by batches from one or several prompts:
+pipeline = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
+pipeline.to("hpu")
-```python
-outputs = pipeline(
- prompt=[
- "High quality photo of an astronaut riding a horse in space",
- "Face of a yellow cat, high resolution, sitting on a park bench",
- ],
- num_images_per_prompt=10,
- batch_size=4,
-)
+image = pipeline("An image of a squirrel in Picasso style").images[0]
```
-For more information, check out 🤗 Optimum Habana's [documentation](https://huggingface.co/docs/optimum/habana/usage_guides/stable_diffusion) and the [example](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion) provided in the official GitHub repository.
-
-## Benchmark
-
-We benchmarked Habana's first-generation Gaudi and Gaudi2 with the [Habana/stable-diffusion](https://huggingface.co/Habana/stable-diffusion) and [Habana/stable-diffusion-2](https://huggingface.co/Habana/stable-diffusion-2) Gaudi configurations (mixed precision bf16/fp32) to demonstrate their performance.
-
-For [Stable Diffusion v1.5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) on 512x512 images:
-
-| | Latency (batch size = 1) | Throughput |
-| ---------------------- |:------------------------:|:---------------------------:|
-| first-generation Gaudi | 3.80s | 0.308 images/s (batch size = 8) |
-| Gaudi2 | 1.33s | 1.081 images/s (batch size = 8) |
-
-For [Stable Diffusion v2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1) on 768x768 images:
-
-| | Latency (batch size = 1) | Throughput |
-| ---------------------- |:------------------------:|:-------------------------------:|
-| first-generation Gaudi | 10.2s | 0.108 images/s (batch size = 4) |
-| Gaudi2 | 3.17s | 0.379 images/s (batch size = 8) |
+> [!TIP]
+> For Gaudi-optimized diffusion pipeline implementations, we recommend using [Optimum for Intel Gaudi](https://huggingface.co/docs/optimum/main/en/habana/index).
diff --git a/docs/source/en/optimization/memory.md b/docs/source/en/optimization/memory.md
index fd72957471c0..611e07ec7655 100644
--- a/docs/source/en/optimization/memory.md
+++ b/docs/source/en/optimization/memory.md
@@ -1,4 +1,4 @@
-
+
+# Compiling and offloading quantized models
+
+Optimizing models often involves trade-offs between [inference speed](./fp16) and [memory-usage](./memory). For instance, while [caching](./cache) can boost inference speed, it also increases memory consumption since it needs to store the outputs of intermediate attention layers. A more balanced optimization strategy combines quantizing a model, [torch.compile](./fp16#torchcompile) and various [offloading methods](./memory#offloading).
+
+> [!TIP]
+> Check the [torch.compile](./fp16#torchcompile) guide to learn more about compilation and how they can be applied here. For example, regional compilation can significantly reduce compilation time without giving up any speedups.
+
+For image generation, combining quantization and [model offloading](./memory#model-offloading) can often give the best trade-off between quality, speed, and memory. Group offloading is not as effective for image generation because it is usually not possible to *fully* overlap data transfer if the compute kernel finishes faster. This results in some communication overhead between the CPU and GPU.
+
+For video generation, combining quantization and [group-offloading](./memory#group-offloading) tends to be better because video models are more compute-bound.
+
+The table below provides a comparison of optimization strategy combinations and their impact on latency and memory-usage for Flux.
+
+| combination | latency (s) | memory-usage (GB) |
+|---|---|---|
+| quantization | 32.602 | 14.9453 |
+| quantization, torch.compile | 25.847 | 14.9448 |
+| quantization, torch.compile, model CPU offloading | 32.312 | 12.2369 |
+
+These results are benchmarked on Flux with a RTX 4090. The transformer and text_encoder components are quantized. Refer to the benchmarking script if you're interested in evaluating your own model.
+
+This guide will show you how to compile and offload a quantized model with [bitsandbytes](../quantization/bitsandbytes#torchcompile). Make sure you are using [PyTorch nightly](https://pytorch.org/get-started/locally/) and the latest version of bitsandbytes.
+
+```bash
+pip install -U bitsandbytes
+```
+
+## Quantization and torch.compile
+
+Start by [quantizing](../quantization/overview) a model to reduce the memory required for storage and [compiling](./fp16#torchcompile) it to accelerate inference.
+
+Configure the [Dynamo](https://docs.pytorch.org/docs/stable/torch.compiler_dynamo_overview.html) `capture_dynamic_output_shape_ops = True` to handle dynamic outputs when compiling bitsandbytes models.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+from diffusers.quantizers import PipelineQuantizationConfig
+
+torch._dynamo.config.capture_dynamic_output_shape_ops = True
+
+# quantize
+pipeline_quant_config = PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_4bit",
+ quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
+ components_to_quantize=["transformer", "text_encoder_2"],
+)
+pipeline = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ quantization_config=pipeline_quant_config,
+ torch_dtype=torch.bfloat16,
+).to("cuda")
+
+# compile
+pipeline.transformer.to(memory_format=torch.channels_last)
+pipeline.transformer.compile(mode="max-autotune", fullgraph=True)
+pipeline("""
+ cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+ highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+).images[0]
+```
+
+## Quantization, torch.compile, and offloading
+
+In addition to quantization and torch.compile, try offloading if you need to reduce memory-usage further. Offloading moves various layers or model components from the CPU to the GPU as needed for computations.
+
+Configure the [Dynamo](https://docs.pytorch.org/docs/stable/torch.compiler_dynamo_overview.html) `cache_size_limit` during offloading to avoid excessive recompilation and set `capture_dynamic_output_shape_ops = True` to handle dynamic outputs when compiling bitsandbytes models.
+
+
+
+
+[Model CPU offloading](./memory#model-offloading) moves an individual pipeline component, like the transformer model, to the GPU when it is needed for computation. Otherwise, it is offloaded to the CPU.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+from diffusers.quantizers import PipelineQuantizationConfig
+
+torch._dynamo.config.cache_size_limit = 1000
+torch._dynamo.config.capture_dynamic_output_shape_ops = True
+
+# quantize
+pipeline_quant_config = PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_4bit",
+ quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
+ components_to_quantize=["transformer", "text_encoder_2"],
+)
+pipeline = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ quantization_config=pipeline_quant_config,
+ torch_dtype=torch.bfloat16,
+).to("cuda")
+
+# model CPU offloading
+pipeline.enable_model_cpu_offload()
+
+# compile
+pipeline.transformer.compile()
+pipeline(
+ "cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain"
+).images[0]
+```
+
+
+
+
+[Group offloading](./memory#group-offloading) moves the internal layers of an individual pipeline component, like the transformer model, to the GPU for computation and offloads it when it's not required. At the same time, it uses the [CUDA stream](./memory#cuda-stream) feature to prefetch the next layer for execution.
+
+By overlapping computation and data transfer, it is faster than model CPU offloading while also saving memory.
+
+```py
+# pip install ftfy
+import torch
+from diffusers import AutoModel, DiffusionPipeline
+from diffusers.hooks import apply_group_offloading
+from diffusers.utils import export_to_video
+from diffusers.quantizers import PipelineQuantizationConfig
+from transformers import UMT5EncoderModel
+
+torch._dynamo.config.cache_size_limit = 1000
+torch._dynamo.config.capture_dynamic_output_shape_ops = True
+
+# quantize
+pipeline_quant_config = PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_4bit",
+ quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
+ components_to_quantize=["transformer", "text_encoder"],
+)
+
+text_encoder = UMT5EncoderModel.from_pretrained(
+ "Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="text_encoder", torch_dtype=torch.bfloat16
+)
+pipeline = DiffusionPipeline.from_pretrained(
+ "Wan-AI/Wan2.1-T2V-14B-Diffusers",
+ quantization_config=pipeline_quant_config,
+ torch_dtype=torch.bfloat16,
+).to("cuda")
+
+# group offloading
+onload_device = torch.device("cuda")
+offload_device = torch.device("cpu")
+
+pipeline.transformer.enable_group_offload(
+ onload_device=onload_device,
+ offload_device=offload_device,
+ offload_type="leaf_level",
+ use_stream=True,
+ non_blocking=True
+)
+pipeline.vae.enable_group_offload(
+ onload_device=onload_device,
+ offload_device=offload_device,
+ offload_type="leaf_level",
+ use_stream=True,
+ non_blocking=True
+)
+apply_group_offloading(
+ pipeline.text_encoder,
+ onload_device=onload_device,
+ offload_type="leaf_level",
+ use_stream=True,
+ non_blocking=True
+)
+
+# compile
+pipeline.transformer.compile()
+
+prompt = """
+The camera rushes from far to near in a low-angle shot,
+revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
+for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
+Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
+shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
+"""
+negative_prompt = """
+Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
+low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
+misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
+"""
+
+output = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_frames=81,
+ guidance_scale=5.0,
+).frames[0]
+export_to_video(output, "output.mp4", fps=16)
+```
+
+
+
\ No newline at end of file
diff --git a/docs/source/en/optimization/tome.md b/docs/source/en/optimization/tome.md
index 3e574efbfe1b..ab368c9ccbb9 100644
--- a/docs/source/en/optimization/tome.md
+++ b/docs/source/en/optimization/tome.md
@@ -1,4 +1,4 @@
-
-
-# PyTorch 2.0
-
-🤗 Diffusers supports the latest optimizations from [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/) which include:
-
-1. A memory-efficient attention implementation, scaled dot product attention, without requiring any extra dependencies such as xFormers.
-2. [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), a just-in-time (JIT) compiler to provide an extra performance boost when individual models are compiled.
-
-Both of these optimizations require PyTorch 2.0 or later and 🤗 Diffusers > 0.13.0.
-
-```bash
-pip install --upgrade torch diffusers
-```
-
-## Scaled dot product attention
-
-[`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) (SDPA) is an optimized and memory-efficient attention (similar to xFormers) that automatically enables several other optimizations depending on the model inputs and GPU type. SDPA is enabled by default if you're using PyTorch 2.0 and the latest version of 🤗 Diffusers, so you don't need to add anything to your code.
-
-However, if you want to explicitly enable it, you can set a [`DiffusionPipeline`] to use [`~models.attention_processor.AttnProcessor2_0`]:
-
-```diff
- import torch
- from diffusers import DiffusionPipeline
-+ from diffusers.models.attention_processor import AttnProcessor2_0
-
- pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
-+ pipe.unet.set_attn_processor(AttnProcessor2_0())
-
- prompt = "a photo of an astronaut riding a horse on mars"
- image = pipe(prompt).images[0]
-```
-
-SDPA should be as fast and memory efficient as `xFormers`; check the [benchmark](#benchmark) for more details.
-
-In some cases - such as making the pipeline more deterministic or converting it to other formats - it may be helpful to use the vanilla attention processor, [`~models.attention_processor.AttnProcessor`]. To revert to [`~models.attention_processor.AttnProcessor`], call the [`~UNet2DConditionModel.set_default_attn_processor`] function on the pipeline:
-
-```diff
- import torch
- from diffusers import DiffusionPipeline
-
- pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
-+ pipe.unet.set_default_attn_processor()
-
- prompt = "a photo of an astronaut riding a horse on mars"
- image = pipe(prompt).images[0]
-```
-
-## torch.compile
-
-The `torch.compile` function can often provide an additional speed-up to your PyTorch code. In 🤗 Diffusers, it is usually best to wrap the UNet with `torch.compile` because it does most of the heavy lifting in the pipeline.
-
-```python
-from diffusers import DiffusionPipeline
-import torch
-
-pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
-pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-images = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images[0]
-```
-
-Depending on GPU type, `torch.compile` can provide an *additional speed-up* of **5-300x** on top of SDPA! If you're using more recent GPU architectures such as Ampere (A100, 3090), Ada (4090), and Hopper (H100), `torch.compile` is able to squeeze even more performance out of these GPUs.
-
-Compilation requires some time to complete, so it is best suited for situations where you prepare your pipeline once and then perform the same type of inference operations multiple times. For example, calling the compiled pipeline on a different image size triggers compilation again which can be expensive.
-
-For more information and different options about `torch.compile`, refer to the [`torch_compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) tutorial.
-
-> [!TIP]
-> Learn more about other ways PyTorch 2.0 can help optimize your model in the [Accelerate inference of text-to-image diffusion models](../tutorials/fast_diffusion) tutorial.
-
-## Benchmark
-
-We conducted a comprehensive benchmark with PyTorch 2.0's efficient attention implementation and `torch.compile` across different GPUs and batch sizes for five of our most used pipelines. The code is benchmarked on 🤗 Diffusers v0.17.0.dev0 to optimize `torch.compile` usage (see [here](https://github.com/huggingface/diffusers/pull/3313) for more details).
-
-Expand the dropdown below to find the code used to benchmark each pipeline:
-
-
-
-### Stable Diffusion text-to-image
-
-```python
-from diffusers import DiffusionPipeline
-import torch
-
-path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
-
-run_compile = True # Set True / False
-
-pipe = DiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16, use_safetensors=True)
-pipe = pipe.to("cuda")
-pipe.unet.to(memory_format=torch.channels_last)
-
-if run_compile:
- print("Run torch compile")
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-
-prompt = "ghibli style, a fantasy landscape with castles"
-
-for _ in range(3):
- images = pipe(prompt=prompt).images
-```
-
-### Stable Diffusion image-to-image
-
-```python
-from diffusers import StableDiffusionImg2ImgPipeline
-from diffusers.utils import load_image
-import torch
-
-url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
-
-init_image = load_image(url)
-init_image = init_image.resize((512, 512))
-
-path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
-
-run_compile = True # Set True / False
-
-pipe = StableDiffusionImg2ImgPipeline.from_pretrained(path, torch_dtype=torch.float16, use_safetensors=True)
-pipe = pipe.to("cuda")
-pipe.unet.to(memory_format=torch.channels_last)
-
-if run_compile:
- print("Run torch compile")
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-
-prompt = "ghibli style, a fantasy landscape with castles"
-
-for _ in range(3):
- image = pipe(prompt=prompt, image=init_image).images[0]
-```
-
-### Stable Diffusion inpainting
-
-```python
-from diffusers import StableDiffusionInpaintPipeline
-from diffusers.utils import load_image
-import torch
-
-img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
-mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
-
-init_image = load_image(img_url).resize((512, 512))
-mask_image = load_image(mask_url).resize((512, 512))
-
-path = "runwayml/stable-diffusion-inpainting"
-
-run_compile = True # Set True / False
-
-pipe = StableDiffusionInpaintPipeline.from_pretrained(path, torch_dtype=torch.float16, use_safetensors=True)
-pipe = pipe.to("cuda")
-pipe.unet.to(memory_format=torch.channels_last)
-
-if run_compile:
- print("Run torch compile")
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-
-prompt = "ghibli style, a fantasy landscape with castles"
-
-for _ in range(3):
- image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
-```
-
-### ControlNet
-
-```python
-from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
-from diffusers.utils import load_image
-import torch
-
-url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
-
-init_image = load_image(url)
-init_image = init_image.resize((512, 512))
-
-path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
-
-run_compile = True # Set True / False
-controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16, use_safetensors=True)
-pipe = StableDiffusionControlNetPipeline.from_pretrained(
- path, controlnet=controlnet, torch_dtype=torch.float16, use_safetensors=True
-)
-
-pipe = pipe.to("cuda")
-pipe.unet.to(memory_format=torch.channels_last)
-pipe.controlnet.to(memory_format=torch.channels_last)
-
-if run_compile:
- print("Run torch compile")
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
- pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)
-
-prompt = "ghibli style, a fantasy landscape with castles"
-
-for _ in range(3):
- image = pipe(prompt=prompt, image=init_image).images[0]
-```
-
-### DeepFloyd IF text-to-image + upscaling
-
-```python
-from diffusers import DiffusionPipeline
-import torch
-
-run_compile = True # Set True / False
-
-pipe_1 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-M-v1.0", variant="fp16", text_encoder=None, torch_dtype=torch.float16, use_safetensors=True)
-pipe_1.to("cuda")
-pipe_2 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-II-M-v1.0", variant="fp16", text_encoder=None, torch_dtype=torch.float16, use_safetensors=True)
-pipe_2.to("cuda")
-pipe_3 = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16, use_safetensors=True)
-pipe_3.to("cuda")
-
-
-pipe_1.unet.to(memory_format=torch.channels_last)
-pipe_2.unet.to(memory_format=torch.channels_last)
-pipe_3.unet.to(memory_format=torch.channels_last)
-
-if run_compile:
- pipe_1.unet = torch.compile(pipe_1.unet, mode="reduce-overhead", fullgraph=True)
- pipe_2.unet = torch.compile(pipe_2.unet, mode="reduce-overhead", fullgraph=True)
- pipe_3.unet = torch.compile(pipe_3.unet, mode="reduce-overhead", fullgraph=True)
-
-prompt = "the blue hulk"
-
-prompt_embeds = torch.randn((1, 2, 4096), dtype=torch.float16)
-neg_prompt_embeds = torch.randn((1, 2, 4096), dtype=torch.float16)
-
-for _ in range(3):
- image_1 = pipe_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=neg_prompt_embeds, output_type="pt").images
- image_2 = pipe_2(image=image_1, prompt_embeds=prompt_embeds, negative_prompt_embeds=neg_prompt_embeds, output_type="pt").images
- image_3 = pipe_3(prompt=prompt, image=image_1, noise_level=100).images
-```
-
-
-The graph below highlights the relative speed-ups for the [`StableDiffusionPipeline`] across five GPU families with PyTorch 2.0 and `torch.compile` enabled. The benchmarks for the following graphs are measured in *number of iterations/second*.
-
-
-
-To give you an even better idea of how this speed-up holds for the other pipelines, consider the following
-graph for an A100 with PyTorch 2.0 and `torch.compile`:
-
-
-
-In the following tables, we report our findings in terms of the *number of iterations/second*.
-
-### A100 (batch size: 1)
-
-| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
-|:---:|:---:|:---:|:---:|:---:|
-| SD - txt2img | 21.66 | 23.13 | 44.03 | 49.74 |
-| SD - img2img | 21.81 | 22.40 | 43.92 | 46.32 |
-| SD - inpaint | 22.24 | 23.23 | 43.76 | 49.25 |
-| SD - controlnet | 15.02 | 15.82 | 32.13 | 36.08 |
-| IF | 20.21 / 13.84 / 24.00 | 20.12 / 13.70 / 24.03 | ❌ | 97.34 / 27.23 / 111.66 |
-| SDXL - txt2img | 8.64 | 9.9 | - | - |
-
-### A100 (batch size: 4)
-
-| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
-|:---:|:---:|:---:|:---:|:---:|
-| SD - txt2img | 11.6 | 13.12 | 14.62 | 17.27 |
-| SD - img2img | 11.47 | 13.06 | 14.66 | 17.25 |
-| SD - inpaint | 11.67 | 13.31 | 14.88 | 17.48 |
-| SD - controlnet | 8.28 | 9.38 | 10.51 | 12.41 |
-| IF | 25.02 | 18.04 | ❌ | 48.47 |
-| SDXL - txt2img | 2.44 | 2.74 | - | - |
-
-### A100 (batch size: 16)
-
-| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
-|:---:|:---:|:---:|:---:|:---:|
-| SD - txt2img | 3.04 | 3.6 | 3.83 | 4.68 |
-| SD - img2img | 2.98 | 3.58 | 3.83 | 4.67 |
-| SD - inpaint | 3.04 | 3.66 | 3.9 | 4.76 |
-| SD - controlnet | 2.15 | 2.58 | 2.74 | 3.35 |
-| IF | 8.78 | 9.82 | ❌ | 16.77 |
-| SDXL - txt2img | 0.64 | 0.72 | - | - |
-
-### V100 (batch size: 1)
-
-| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
-|:---:|:---:|:---:|:---:|:---:|
-| SD - txt2img | 18.99 | 19.14 | 20.95 | 22.17 |
-| SD - img2img | 18.56 | 19.18 | 20.95 | 22.11 |
-| SD - inpaint | 19.14 | 19.06 | 21.08 | 22.20 |
-| SD - controlnet | 13.48 | 13.93 | 15.18 | 15.88 |
-| IF | 20.01 / 9.08 / 23.34 | 19.79 / 8.98 / 24.10 | ❌ | 55.75 / 11.57 / 57.67 |
-
-### V100 (batch size: 4)
-
-| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
-|:---:|:---:|:---:|:---:|:---:|
-| SD - txt2img | 5.96 | 5.89 | 6.83 | 6.86 |
-| SD - img2img | 5.90 | 5.91 | 6.81 | 6.82 |
-| SD - inpaint | 5.99 | 6.03 | 6.93 | 6.95 |
-| SD - controlnet | 4.26 | 4.29 | 4.92 | 4.93 |
-| IF | 15.41 | 14.76 | ❌ | 22.95 |
-
-### V100 (batch size: 16)
-
-| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
-|:---:|:---:|:---:|:---:|:---:|
-| SD - txt2img | 1.66 | 1.66 | 1.92 | 1.90 |
-| SD - img2img | 1.65 | 1.65 | 1.91 | 1.89 |
-| SD - inpaint | 1.69 | 1.69 | 1.95 | 1.93 |
-| SD - controlnet | 1.19 | 1.19 | OOM after warmup | 1.36 |
-| IF | 5.43 | 5.29 | ❌ | 7.06 |
-
-### T4 (batch size: 1)
-
-| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
-|:---:|:---:|:---:|:---:|:---:|
-| SD - txt2img | 6.9 | 6.95 | 7.3 | 7.56 |
-| SD - img2img | 6.84 | 6.99 | 7.04 | 7.55 |
-| SD - inpaint | 6.91 | 6.7 | 7.01 | 7.37 |
-| SD - controlnet | 4.89 | 4.86 | 5.35 | 5.48 |
-| IF | 17.42 / 2.47 / 18.52 | 16.96 / 2.45 / 18.69 | ❌ | 24.63 / 2.47 / 23.39 |
-| SDXL - txt2img | 1.15 | 1.16 | - | - |
-
-### T4 (batch size: 4)
-
-| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
-|:---:|:---:|:---:|:---:|:---:|
-| SD - txt2img | 1.79 | 1.79 | 2.03 | 1.99 |
-| SD - img2img | 1.77 | 1.77 | 2.05 | 2.04 |
-| SD - inpaint | 1.81 | 1.82 | 2.09 | 2.09 |
-| SD - controlnet | 1.34 | 1.27 | 1.47 | 1.46 |
-| IF | 5.79 | 5.61 | ❌ | 7.39 |
-| SDXL - txt2img | 0.288 | 0.289 | - | - |
-
-### T4 (batch size: 16)
-
-| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
-|:---:|:---:|:---:|:---:|:---:|
-| SD - txt2img | 2.34s | 2.30s | OOM after 2nd iteration | 1.99s |
-| SD - img2img | 2.35s | 2.31s | OOM after warmup | 2.00s |
-| SD - inpaint | 2.30s | 2.26s | OOM after 2nd iteration | 1.95s |
-| SD - controlnet | OOM after 2nd iteration | OOM after 2nd iteration | OOM after warmup | OOM after warmup |
-| IF * | 1.44 | 1.44 | ❌ | 1.94 |
-| SDXL - txt2img | OOM | OOM | - | - |
-
-### RTX 3090 (batch size: 1)
-
-| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
-|:---:|:---:|:---:|:---:|:---:|
-| SD - txt2img | 22.56 | 22.84 | 23.84 | 25.69 |
-| SD - img2img | 22.25 | 22.61 | 24.1 | 25.83 |
-| SD - inpaint | 22.22 | 22.54 | 24.26 | 26.02 |
-| SD - controlnet | 16.03 | 16.33 | 17.38 | 18.56 |
-| IF | 27.08 / 9.07 / 31.23 | 26.75 / 8.92 / 31.47 | ❌ | 68.08 / 11.16 / 65.29 |
-
-### RTX 3090 (batch size: 4)
-
-| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
-|:---:|:---:|:---:|:---:|:---:|
-| SD - txt2img | 6.46 | 6.35 | 7.29 | 7.3 |
-| SD - img2img | 6.33 | 6.27 | 7.31 | 7.26 |
-| SD - inpaint | 6.47 | 6.4 | 7.44 | 7.39 |
-| SD - controlnet | 4.59 | 4.54 | 5.27 | 5.26 |
-| IF | 16.81 | 16.62 | ❌ | 21.57 |
-
-### RTX 3090 (batch size: 16)
-
-| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
-|:---:|:---:|:---:|:---:|:---:|
-| SD - txt2img | 1.7 | 1.69 | 1.93 | 1.91 |
-| SD - img2img | 1.68 | 1.67 | 1.93 | 1.9 |
-| SD - inpaint | 1.72 | 1.71 | 1.97 | 1.94 |
-| SD - controlnet | 1.23 | 1.22 | 1.4 | 1.38 |
-| IF | 5.01 | 5.00 | ❌ | 6.33 |
-
-### RTX 4090 (batch size: 1)
-
-| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
-|:---:|:---:|:---:|:---:|:---:|
-| SD - txt2img | 40.5 | 41.89 | 44.65 | 49.81 |
-| SD - img2img | 40.39 | 41.95 | 44.46 | 49.8 |
-| SD - inpaint | 40.51 | 41.88 | 44.58 | 49.72 |
-| SD - controlnet | 29.27 | 30.29 | 32.26 | 36.03 |
-| IF | 69.71 / 18.78 / 85.49 | 69.13 / 18.80 / 85.56 | ❌ | 124.60 / 26.37 / 138.79 |
-| SDXL - txt2img | 6.8 | 8.18 | - | - |
-
-### RTX 4090 (batch size: 4)
-
-| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
-|:---:|:---:|:---:|:---:|:---:|
-| SD - txt2img | 12.62 | 12.84 | 15.32 | 15.59 |
-| SD - img2img | 12.61 | 12,.79 | 15.35 | 15.66 |
-| SD - inpaint | 12.65 | 12.81 | 15.3 | 15.58 |
-| SD - controlnet | 9.1 | 9.25 | 11.03 | 11.22 |
-| IF | 31.88 | 31.14 | ❌ | 43.92 |
-| SDXL - txt2img | 2.19 | 2.35 | - | - |
-
-### RTX 4090 (batch size: 16)
-
-| **Pipeline** | **torch 2.0 - no compile** | **torch nightly - no compile** | **torch 2.0 - compile** | **torch nightly - compile** |
-|:---:|:---:|:---:|:---:|:---:|
-| SD - txt2img | 3.17 | 3.2 | 3.84 | 3.85 |
-| SD - img2img | 3.16 | 3.2 | 3.84 | 3.85 |
-| SD - inpaint | 3.17 | 3.2 | 3.85 | 3.85 |
-| SD - controlnet | 2.23 | 2.3 | 2.7 | 2.75 |
-| IF | 9.26 | 9.2 | ❌ | 13.31 |
-| SDXL - txt2img | 0.52 | 0.53 | - | - |
-
-## Notes
-
-* Follow this [PR](https://github.com/huggingface/diffusers/pull/3313) for more details on the environment used for conducting the benchmarks.
-* For the DeepFloyd IF pipeline where batch sizes > 1, we only used a batch size of > 1 in the first IF pipeline for text-to-image generation and NOT for upscaling. That means the two upscaling pipelines received a batch size of 1.
-
-*Thanks to [Horace He](https://github.com/Chillee) from the PyTorch team for their support in improving our support of `torch.compile()` in Diffusers.*
diff --git a/docs/source/en/optimization/xdit.md b/docs/source/en/optimization/xdit.md
index 33ff8dc255d0..ecf45635684a 100644
--- a/docs/source/en/optimization/xdit.md
+++ b/docs/source/en/optimization/xdit.md
@@ -2,7 +2,7 @@
[xDiT](https://github.com/xdit-project/xDiT) is an inference engine designed for the large scale parallel deployment of Diffusion Transformers (DiTs). xDiT provides a suite of efficient parallel approaches for Diffusion Models, as well as GPU kernel accelerations.
-There are four parallel methods supported in xDiT, including [Unified Sequence Parallelism](https://arxiv.org/abs/2405.07719), [PipeFusion](https://arxiv.org/abs/2405.14430), CFG parallelism and data parallelism. The four parallel methods in xDiT can be configured in a hybrid manner, optimizing communication patterns to best suit the underlying network hardware.
+There are four parallel methods supported in xDiT, including [Unified Sequence Parallelism](https://huggingface.co/papers/2405.07719), [PipeFusion](https://huggingface.co/papers/2405.14430), CFG parallelism and data parallelism. The four parallel methods in xDiT can be configured in a hybrid manner, optimizing communication patterns to best suit the underlying network hardware.
Optimization orthogonal to parallelization focuses on accelerating single GPU performance. In addition to utilizing well-known Attention optimization libraries, we leverage compilation acceleration technologies such as torch.compile and onediff.
@@ -116,6 +116,6 @@ More detailed performance metric can be found on our [github page](https://githu
[xDiT-project](https://github.com/xdit-project/xDiT)
-[USP: A Unified Sequence Parallelism Approach for Long Context Generative AI](https://arxiv.org/abs/2405.07719)
+[USP: A Unified Sequence Parallelism Approach for Long Context Generative AI](https://huggingface.co/papers/2405.07719)
-[PipeFusion: Displaced Patch Pipeline Parallelism for Inference of Diffusion Transformer Models](https://arxiv.org/abs/2405.14430)
\ No newline at end of file
+[PipeFusion: Displaced Patch Pipeline Parallelism for Inference of Diffusion Transformer Models](https://huggingface.co/papers/2405.14430)
\ No newline at end of file
diff --git a/docs/source/en/optimization/xformers.md b/docs/source/en/optimization/xformers.md
index 4ef0da9e890d..523e81559547 100644
--- a/docs/source/en/optimization/xformers.md
+++ b/docs/source/en/optimization/xformers.md
@@ -1,4 +1,4 @@
-
+
+# NVIDIA ModelOpt
+
+[NVIDIA-ModelOpt](https://github.com/NVIDIA/Model-Optimizer) is a unified library of state-of-the-art model optimization techniques like quantization, pruning, distillation, speculative decoding, etc. It compresses deep learning models for downstream deployment frameworks like TensorRT-LLM or TensorRT to optimize inference speed.
+
+Before you begin, make sure you have nvidia_modelopt installed.
+
+```bash
+pip install -U "nvidia_modelopt[hf]"
+```
+
+Quantize a model by passing [`NVIDIAModelOptConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
+
+The example below only quantizes the weights to FP8.
+
+```python
+import torch
+from diffusers import AutoModel, SanaPipeline, NVIDIAModelOptConfig
+
+model_id = "Efficient-Large-Model/Sana_600M_1024px_diffusers"
+dtype = torch.bfloat16
+
+quantization_config = NVIDIAModelOptConfig(quant_type="FP8", quant_method="modelopt")
+transformer = AutoModel.from_pretrained(
+ model_id,
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=dtype,
+)
+pipe = SanaPipeline.from_pretrained(
+ model_id,
+ transformer=transformer,
+ torch_dtype=dtype,
+)
+pipe.to("cuda")
+
+print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
+
+prompt = "A cat holding a sign that says hello world"
+image = pipe(
+ prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
+).images[0]
+image.save("output.png")
+```
+
+> **Note:**
+>
+> The quantization methods in NVIDIA-ModelOpt are designed to reduce the memory footprint of model weights using various QAT (Quantization-Aware Training) and PTQ (Post-Training Quantization) techniques while maintaining model performance. However, the actual performance gain during inference depends on the deployment framework (e.g., TRT-LLM, TensorRT) and the specific hardware configuration.
+>
+> More details can be found [here](https://github.com/NVIDIA/Model-Optimizer/tree/main/examples).
+
+## NVIDIAModelOptConfig
+
+The `NVIDIAModelOptConfig` class accepts three parameters:
+- `quant_type`: A string value mentioning one of the quantization types below.
+- `modules_to_not_convert`: A list of module full/partial module names for which quantization should not be performed. For example, to not perform any quantization of the [`SD3Transformer2DModel`]'s pos_embed projection blocks, one would specify: `modules_to_not_convert=["pos_embed.proj.weight"]`.
+- `disable_conv_quantization`: A boolean value which when set to `True` disables quantization for all convolutional layers in the model. This is useful as channel and block quantization generally don't work well with convolutional layers (used with INT4, NF4, NVFP4). If you want to disable quantization for specific convolutional layers, use `modules_to_not_convert` instead.
+- `algorithm`: The algorithm to use for determining scale, defaults to `"max"`. You can check modelopt documentation for more algorithms and details.
+- `forward_loop`: The forward loop function to use for calibrating activation during quantization. If not provided, it relies on static scale values computed using the weights only.
+- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method which will be invoked based on `quant_type`.
+
+## Supported quantization types
+
+ModelOpt supports weight-only, channel and block quantization int8, fp8, int4, nf4, and nvfp4. The quantization methods are designed to reduce the memory footprint of the model weights while maintaining the performance of the model during inference.
+
+Weight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory requirements from model weights but retains the memory peaks for activation computation.
+
+The quantization methods supported are as follows:
+
+| **Quantization Type** | **Supported Schemes** | **Required Kwargs** | **Additional Notes** |
+|-----------------------|-----------------------|---------------------|----------------------|
+| **INT8** | `int8 weight only`, `int8 channel quantization`, `int8 block quantization` | `quant_type`, `quant_type + channel_quantize`, `quant_type + channel_quantize + block_quantize` |
+| **FP8** | `fp8 weight only`, `fp8 channel quantization`, `fp8 block quantization` | `quant_type`, `quant_type + channel_quantize`, `quant_type + channel_quantize + block_quantize` |
+| **INT4** | `int4 weight only`, `int4 block quantization` | `quant_type`, `quant_type + channel_quantize + block_quantize` | `channel_quantize = -1 is only supported for now`|
+| **NF4** | `nf4 weight only`, `nf4 double block quantization` | `quant_type`, `quant_type + channel_quantize + block_quantize + scale_channel_quantize` + `scale_block_quantize` | `channel_quantize = -1 and scale_channel_quantize = -1 are only supported for now` |
+| **NVFP4** | `nvfp4 weight only`, `nvfp4 block quantization` | `quant_type`, `quant_type + channel_quantize + block_quantize` | `channel_quantize = -1 is only supported for now`|
+
+
+Refer to the [official modelopt documentation](https://nvidia.github.io/Model-Optimizer/) for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
+
+## Serializing and Deserializing quantized models
+
+To serialize a quantized model in a given dtype, first load the model with the desired quantization dtype and then save it using the [`~ModelMixin.save_pretrained`] method.
+
+```python
+import torch
+from diffusers import AutoModel, NVIDIAModelOptConfig
+from modelopt.torch.opt import enable_huggingface_checkpointing
+
+enable_huggingface_checkpointing()
+
+model_id = "Efficient-Large-Model/Sana_600M_1024px_diffusers"
+quant_config_fp8 = {"quant_type": "FP8", "quant_method": "modelopt"}
+quant_config_fp8 = NVIDIAModelOptConfig(**quant_config_fp8)
+model = AutoModel.from_pretrained(
+ model_id,
+ subfolder="transformer",
+ quantization_config=quant_config_fp8,
+ torch_dtype=torch.bfloat16,
+)
+model.save_pretrained('path/to/sana_fp8', safe_serialization=False)
+```
+
+To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method.
+
+```python
+import torch
+from diffusers import AutoModel, NVIDIAModelOptConfig, SanaPipeline
+from modelopt.torch.opt import enable_huggingface_checkpointing
+
+enable_huggingface_checkpointing()
+
+quantization_config = NVIDIAModelOptConfig(quant_type="FP8", quant_method="modelopt")
+transformer = AutoModel.from_pretrained(
+ "path/to/sana_fp8",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.bfloat16,
+)
+pipe = SanaPipeline.from_pretrained(
+ "Efficient-Large-Model/Sana_600M_1024px_diffusers",
+ transformer=transformer,
+ torch_dtype=torch.bfloat16,
+)
+pipe.to("cuda")
+prompt = "A cat holding a sign that says hello world"
+image = pipe(
+ prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
+).images[0]
+image.save("output.png")
+```
diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md
index 93323f86c7fc..38abeeac6d4d 100644
--- a/docs/source/en/quantization/overview.md
+++ b/docs/source/en/quantization/overview.md
@@ -1,4 +1,4 @@
-
-# Quantization
+# Getting started
-Quantization techniques focus on representing data with less information while also trying to not lose too much accuracy. This often means converting a data type to represent the same information with fewer bits. For example, if your model weights are stored as 32-bit floating points and they're quantized to 16-bit floating points, this halves the model size which makes it easier to store and reduces memory-usage. Lower precision can also speedup inference because it takes less time to perform calculations with fewer bits.
+Quantization focuses on representing data with fewer bits while also trying to preserve the precision of the original data. This often means converting a data type to represent the same information with fewer bits. For example, if your model weights are stored as 32-bit floating points and they're quantized to 16-bit floating points, this halves the model size which makes it easier to store and reduces memory usage. Lower precision can also speedup inference because it takes less time to perform calculations with fewer bits.
-
+Diffusers supports multiple quantization backends to make large diffusion models like [Flux](../api/pipelines/flux) more accessible. This guide shows how to use the [`~quantizers.PipelineQuantizationConfig`] class to quantize a pipeline during its initialization from a pretrained or non-quantized checkpoint.
-Interested in adding a new quantization method to Diffusers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method.
+## Pipeline-level quantization
-
+There are two ways to use [`~quantizers.PipelineQuantizationConfig`] depending on how much customization you want to apply to the quantization configuration.
-
+- for basic use cases, define the `quant_backend`, `quant_kwargs`, and `components_to_quantize` arguments
+- for granular quantization control, define a `quant_mapping` that provides the quantization configuration for individual model components
-If you are new to the quantization field, we recommend you to check out these beginner-friendly courses about quantization in collaboration with DeepLearning.AI:
+### Basic quantization
-* [Quantization Fundamentals with Hugging Face](https://www.deeplearning.ai/short-courses/quantization-fundamentals-with-hugging-face/)
-* [Quantization in Depth](https://www.deeplearning.ai/short-courses/quantization-in-depth/)
+Initialize [`~quantizers.PipelineQuantizationConfig`] with the following parameters.
-
+- `quant_backend` specifies which quantization backend to use. Currently supported backends include: `bitsandbytes_4bit`, `bitsandbytes_8bit`, `gguf`, `quanto`, and `torchao`.
+- `quant_kwargs` specifies the quantization arguments to use.
-## When to use what?
+> [!TIP]
+> These `quant_kwargs` arguments are different for each backend. Refer to the [Quantization API](../api/quantization) docs to view the arguments for each backend.
-Diffusers currently supports the following quantization methods.
-- [BitsandBytes](./bitsandbytes)
-- [TorchAO](./torchao)
-- [GGUF](./gguf)
-- [Quanto](./quanto.md)
+- `components_to_quantize` specifies which component(s) of the pipeline to quantize. Typically, you should quantize the most compute intensive components like the transformer. The text encoder is another component to consider quantizing if a pipeline has more than one such as [`FluxPipeline`]. The example below quantizes the T5 text encoder in [`FluxPipeline`] while keeping the CLIP model intact.
-[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
+ `components_to_quantize` accepts either a list for multiple models or a string for a single model.
+
+The example below loads the bitsandbytes backend with the following arguments from [`~quantizers.quantization_config.BitsAndBytesConfig`], `load_in_4bit`, `bnb_4bit_quant_type`, and `bnb_4bit_compute_dtype`.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+from diffusers.quantizers import PipelineQuantizationConfig
+
+pipeline_quant_config = PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_4bit",
+ quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
+ components_to_quantize=["transformer", "text_encoder_2"],
+)
+```
+
+Pass the `pipeline_quant_config` to [`~DiffusionPipeline.from_pretrained`] to quantize the pipeline.
+
+```py
+pipe = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ quantization_config=pipeline_quant_config,
+ torch_dtype=torch.bfloat16,
+).to("cuda")
+
+image = pipe("photo of a cute dog").images[0]
+```
+
+
+### Advanced quantization
+
+The `quant_mapping` argument provides more options for how to quantize each individual component in a pipeline, like combining different quantization backends.
+
+Initialize [`~quantizers.PipelineQuantizationConfig`] and pass a `quant_mapping` to it. The `quant_mapping` allows you to specify the quantization options for each component in the pipeline such as the transformer and text encoder.
+
+The example below uses two quantization backends, [`~quantizers.quantization_config.QuantoConfig`] and [`transformers.BitsAndBytesConfig`], for the transformer and text encoder.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
+from diffusers.quantizers.quantization_config import QuantoConfig
+from diffusers.quantizers import PipelineQuantizationConfig
+from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
+
+pipeline_quant_config = PipelineQuantizationConfig(
+ quant_mapping={
+ "transformer": QuantoConfig(weights_dtype="int8"),
+ "text_encoder_2": TransformersBitsAndBytesConfig(
+ load_in_4bit=True, compute_dtype=torch.bfloat16
+ ),
+ }
+)
+```
+
+There is a separate bitsandbytes backend in [Transformers](https://huggingface.co/docs/transformers/main_classes/quantization#transformers.BitsAndBytesConfig). You need to import and use [`transformers.BitsAndBytesConfig`] for components that come from Transformers. For example, `text_encoder_2` in [`FluxPipeline`] is a [`~transformers.T5EncoderModel`] from Transformers so you need to use [`transformers.BitsAndBytesConfig`] instead of [`diffusers.BitsAndBytesConfig`].
+
+> [!TIP]
+> Use the [basic quantization](#basic-quantization) method above if you don't want to manage these distinct imports or aren't sure where each pipeline component comes from.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
+from diffusers.quantizers import PipelineQuantizationConfig
+from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
+
+pipeline_quant_config = PipelineQuantizationConfig(
+ quant_mapping={
+ "transformer": DiffusersBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
+ "text_encoder_2": TransformersBitsAndBytesConfig(
+ load_in_4bit=True, compute_dtype=torch.bfloat16
+ ),
+ }
+)
+```
+
+Pass the `pipeline_quant_config` to [`~DiffusionPipeline.from_pretrained`] to quantize the pipeline.
+
+```py
+pipe = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ quantization_config=pipeline_quant_config,
+ torch_dtype=torch.bfloat16,
+).to("cuda")
+
+image = pipe("photo of a cute dog").images[0]
+```
+
+## Resources
+
+Check out the resources below to learn more about quantization.
+
+- If you are new to quantization, we recommend checking out the following beginner-friendly courses in collaboration with DeepLearning.AI.
+
+ - [Quantization Fundamentals with Hugging Face](https://www.deeplearning.ai/short-courses/quantization-fundamentals-with-hugging-face/)
+ - [Quantization in Depth](https://www.deeplearning.ai/short-courses/quantization-in-depth/)
+
+- Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) if you're interested in adding a new quantization method.
+
+- The Transformers quantization [Overview](https://huggingface.co/docs/transformers/quantization/overview#when-to-use-what) provides an overview of the pros and cons of different quantization backends.
+
+- Read the [Exploring Quantization Backends in Diffusers](https://huggingface.co/blog/diffusers-quantization) blog post for a brief introduction to each quantization backend, how to choose a backend, and combining quantization with other memory optimizations.
diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md
index 19a8970fa9df..18cc109e0785 100644
--- a/docs/source/en/quantization/torchao.md
+++ b/docs/source/en/quantization/torchao.md
@@ -1,4 +1,4 @@
-
# torchao
-[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), FullyShardedDataParallel (FSDP), and more.
+[torchao](https://github.com/pytorch/ao) provides high-performance dtypes and optimizations based on quantization and sparsity for inference and training PyTorch models. It is supported for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
-Before you begin, make sure you have Pytorch 2.5+ and TorchAO installed.
+Make sure Pytorch 2.5+ and torchao are installed with the command below.
```bash
-pip install -U torch torchao
+uv pip install -U torch torchao
```
+Each quantization dtype is available as a separate instance of a [AOBaseConfig](https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize) class. This provides more flexible configuration options by exposing more available arguments.
-Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
+Pass the `AOBaseConfig` of a quantization dtype, like [Int4WeightOnlyConfig](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int4WeightOnlyConfig) to [`TorchAoConfig`] in [`~ModelMixin.from_pretrained`].
-The example below only quantizes the weights to int8.
-
-```python
+```py
import torch
-from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
+from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
+from torchao.quantization import Int8WeightOnlyConfig
-model_id = "black-forest-labs/FLUX.1-dev"
-dtype = torch.bfloat16
-
-quantization_config = TorchAoConfig("int8wo")
-transformer = FluxTransformer2DModel.from_pretrained(
- model_id,
- subfolder="transformer",
- quantization_config=quantization_config,
- torch_dtype=dtype,
+pipeline_quant_config = PipelineQuantizationConfig(
+ quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128)))}
)
-pipe = FluxPipeline.from_pretrained(
- model_id,
- transformer=transformer,
- torch_dtype=dtype,
+pipeline = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ quantzation_config=pipeline_quant_config,
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
)
-pipe.to("cuda")
+```
-# Without quantization: ~31.447 GB
-# With quantization: ~20.40 GB
-print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
+For simple use cases, you could also provide a string identifier in [`TorchAo`] as shown below.
-prompt = "A cat holding a sign that says hello world"
-image = pipe(
- prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
-).images[0]
-image.save("output.png")
+```py
+import torch
+from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
+
+pipeline_quant_config = PipelineQuantizationConfig(
+ quant_mapping={"transformer": TorchAoConfig("int8wo")}
+)
+pipeline = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ quantzation_config=pipeline_quant_config,
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
+)
```
-TorchAO is fully compatible with [torch.compile](./optimization/torch2.0#torchcompile), setting it apart from other quantization methods. This makes it easy to speed up inference with just one line of code.
+## torch.compile
+
+torchao supports [torch.compile](../optimization/fp16#torchcompile) which can speed up inference with one line of code.
```python
-# In the above code, add the following after initializing the transformer
-transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
+import torch
+from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
+from torchao.quantization import Int4WeightOnlyConfig
+
+pipeline_quant_config = PipelineQuantizationConfig(
+ quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128)))}
+)
+pipeline = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ quantzation_config=pipeline_quant_config,
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
+)
+
+pipeline.transformer.compile(transformer, mode="max-autotune", fullgraph=True)
```
-For speed and memory benchmarks on Flux and CogVideoX, please refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find some torchao [benchmarks](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) numbers for various hardware.
+Refer to this [table](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450) for inference speed and memory usage benchmarks with Flux and CogVideoX. More benchmarks on various hardware are also available in the torchao [repository](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks).
+
+> [!TIP]
+> The FP8 post-training quantization schemes in torchao are effective for GPUs with compute capability of at least 8.9 (RTX-4090, Hopper, etc.). FP8 often provides the best speed, memory, and quality trade-off when generating images and videos. We recommend combining FP8 and torch.compile if your GPU is compatible.
-torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components. Diffusers will also expose an autoquant configuration option in the future.
+## autoquant
-The `TorchAoConfig` class accepts three parameters:
-- `quant_type`: A string value mentioning one of the quantization types below.
-- `modules_to_not_convert`: A list of module full/partial module names for which quantization should not be performed. For example, to not perform any quantization of the [`FluxTransformer2DModel`]'s first block, one would specify: `modules_to_not_convert=["single_transformer_blocks.0"]`.
-- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method which will be invoked based on `quant_type`.
+torchao provides [autoquant](https://docs.pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) an automatic quantization API. Autoquantization chooses the best quantization strategy by comparing the performance of each strategy on chosen input types and shapes. This is only supported in Diffusers for individual models at the moment.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+from torchao.quantization import autoquant
+
+# Load the pipeline
+pipeline = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-schnell",
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
+)
+
+transformer = autoquant(pipeline.transformer)
+```
## Supported quantization types
@@ -85,13 +115,13 @@ The quantization methods supported are as follows:
| **Category** | **Full Function Names** | **Shorthands** |
|--------------|-------------------------|----------------|
| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` |
-| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8_e4m3_tensor`, `float8_e4m3_row` |
+| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8dq_e4m3_tensor`, `float8dq_e4m3_row` |
| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB` where `X` is the number of bits (1-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` |
| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |
Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.
-Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
+Refer to the [official torchao documentation](https://docs.pytorch.org/ao/stable/index.html) for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
## Serializing and Deserializing quantized models
@@ -99,10 +129,10 @@ To serialize a quantized model in a given dtype, first load the model with the d
```python
import torch
-from diffusers import FluxTransformer2DModel, TorchAoConfig
+from diffusers import AutoModel, TorchAoConfig
quantization_config = TorchAoConfig("int8wo")
-transformer = FluxTransformer2DModel.from_pretrained(
+transformer = AutoModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
quantization_config=quantization_config,
@@ -115,9 +145,9 @@ To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] me
```python
import torch
-from diffusers import FluxPipeline, FluxTransformer2DModel
+from diffusers import FluxPipeline, AutoModel
-transformer = FluxTransformer2DModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False)
+transformer = AutoModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False)
pipe = FluxPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")
@@ -131,10 +161,10 @@ If you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, c
```python
import torch
from accelerate import init_empty_weights
-from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
+from diffusers import FluxPipeline, AutoModel, TorchAoConfig
# Serialize the model
-transformer = FluxTransformer2DModel.from_pretrained(
+transformer = AutoModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
subfolder="transformer",
quantization_config=TorchAoConfig("uint4wo"),
@@ -146,11 +176,14 @@ transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, m
# Load the model
state_dict = torch.load("/path/to/flux_uint4wo/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu")
with init_empty_weights():
- transformer = FluxTransformer2DModel.from_config("/path/to/flux_uint4wo/config.json")
+ transformer = AutoModel.from_config("/path/to/flux_uint4wo/config.json")
transformer.load_state_dict(state_dict, strict=True, assign=True)
```
+> [!TIP]
+> The [`AutoModel`] API is supported for PyTorch >= 2.6 as shown in the examples below.
+
## Resources
-- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
+- [TorchAO Quantization API](https://docs.pytorch.org/ao/stable/index.html)
- [Diffusers-TorchAO examples](https://github.com/sayakpaul/diffusers-torchao)
diff --git a/docs/source/en/quicktour.md b/docs/source/en/quicktour.md
index 2d9f7fe3736a..1ccc8eeadcc2 100644
--- a/docs/source/en/quicktour.md
+++ b/docs/source/en/quicktour.md
@@ -1,4 +1,4 @@
-
-[[open-in-colab]]
+# Quickstart
-# Quicktour
+Diffusers is a library for developers and researchers that provides an easy inference API for generating images, videos and audio, as well as the building blocks for implementing new workflows.
-Diffusion models are trained to denoise random Gaussian noise step-by-step to generate a sample of interest, such as an image or audio. This has sparked a tremendous amount of interest in generative AI, and you have probably seen examples of diffusion generated images on the internet. 🧨 Diffusers is a library aimed at making diffusion models widely accessible to everyone.
+Diffusers provides many optimizations out-of-the-box that makes it possible to load and run large models on setups with limited memory or to accelerate inference.
-Whether you're a developer or an everyday user, this quicktour will introduce you to 🧨 Diffusers and help you get up and generating quickly! There are three main components of the library to know about:
+This Quickstart will give you an overview of Diffusers and get you up and generating quickly.
-* The [`DiffusionPipeline`] is a high-level end-to-end class designed to rapidly generate samples from pretrained diffusion models for inference.
-* Popular pretrained [model](./api/models) architectures and modules that can be used as building blocks for creating diffusion systems.
-* Many different [schedulers](./api/schedulers/overview) - algorithms that control how noise is added for training, and how to generate denoised images during inference.
+> [!TIP]
+> Before you begin, make sure you have a Hugging Face [account](https://huggingface.co/join) in order to use gated models like [Flux](https://huggingface.co/black-forest-labs/FLUX.1-dev).
-The quicktour will show you how to use the [`DiffusionPipeline`] for inference, and then walk you through how to combine a model and scheduler to replicate what's happening inside the [`DiffusionPipeline`].
-
-
-
-The quicktour is a simplified version of the introductory 🧨 Diffusers [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb) to help you get started quickly. If you want to learn more about 🧨 Diffusers' goal, design philosophy, and additional details about its core API, check out the notebook!
-
-
-
-Before you begin, make sure you have all the necessary libraries installed:
-
-```py
-# uncomment to install the necessary libraries in Colab
-#!pip install --upgrade diffusers accelerate transformers
-```
-
-- [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) speeds up model loading for inference and training.
-- [🤗 Transformers](https://huggingface.co/docs/transformers/index) is required to run the most popular diffusion models, such as [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview).
+Follow the [Installation](./installation) guide to install Diffusers if it's not already installed.
## DiffusionPipeline
-The [`DiffusionPipeline`] is the easiest way to use a pretrained diffusion system for inference. It is an end-to-end system containing the model and the scheduler. You can use the [`DiffusionPipeline`] out-of-the-box for many tasks. Take a look at the table below for some supported tasks, and for a complete list of supported tasks, check out the [🧨 Diffusers Summary](./api/pipelines/overview#diffusers-summary) table.
+A diffusion model combines multiple components to generate outputs in any modality based on an input, such as a text description, image or both.
-| **Task** | **Description** | **Pipeline**
-|------------------------------|--------------------------------------------------------------------------------------------------------------|-----------------|
-| Unconditional Image Generation | generate an image from Gaussian noise | [unconditional_image_generation](./using-diffusers/unconditional_image_generation) |
-| Text-Guided Image Generation | generate an image given a text prompt | [conditional_image_generation](./using-diffusers/conditional_image_generation) |
-| Text-Guided Image-to-Image Translation | adapt an image guided by a text prompt | [img2img](./using-diffusers/img2img) |
-| Text-Guided Image-Inpainting | fill the masked part of an image given the image, the mask and a text prompt | [inpaint](./using-diffusers/inpaint) |
-| Text-Guided Depth-to-Image Translation | adapt parts of an image guided by a text prompt while preserving structure via depth estimation | [depth2img](./using-diffusers/depth2img) |
+For a standard text-to-image model:
-Start by creating an instance of a [`DiffusionPipeline`] and specify which pipeline checkpoint you would like to download.
-You can use the [`DiffusionPipeline`] for any [checkpoint](https://huggingface.co/models?library=diffusers&sort=downloads) stored on the Hugging Face Hub.
-In this quicktour, you'll load the [`stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) checkpoint for text-to-image generation.
+1. A text encoder turns a prompt into embeddings that guide the denoising process. Some models have more than one text encoder.
+2. A scheduler contains the algorithmic specifics for gradually denoising initial random noise into clean outputs. Different schedulers affect generation speed and quality.
+3. A UNet or diffusion transformer (DiT) is the workhorse of a diffusion model.
-
+ At each step, it performs the denoising predictions, such as how much noise to remove or the general direction in which to steer the noise to generate better quality outputs.
-For [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion) models, please carefully read the [license](https://huggingface.co/spaces/CompVis/stable-diffusion-license) first before running the model. 🧨 Diffusers implements a [`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) to prevent offensive or harmful content, but the model's improved image generation capabilities can still produce potentially harmful content.
+ The UNet or DiT repeats this loop for a set amount of steps to generate the final output.
+
+4. A variational autoencoder (VAE) encodes and decodes pixels to a spatially compressed latent-space. *Latents* are compressed representations of an image and are more efficient to work with. The UNet or DiT operates on latents, and the clean latents at the end are decoded back into images.
-
+The [`DiffusionPipeline`] packages all these components into a single class for inference. There are several arguments in [`~DiffusionPipeline.__call__`] you can change, such as `num_inference_steps`, that affect the diffusion process. Try different values and arguments to see how they change generation quality or speed.
-Load the model with the [`~DiffusionPipeline.from_pretrained`] method:
+Load a model with [`~DiffusionPipeline.from_pretrained`] and describe what you'd like to generate. The example below uses the default argument values.
-```python
->>> from diffusers import DiffusionPipeline
+
+
->>> pipeline = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", use_safetensors=True)
-```
-
-The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components. You'll see that the Stable Diffusion pipeline is composed of the [`UNet2DConditionModel`] and [`PNDMScheduler`] among other things:
+Use `.images[0]` to access the generated image output.
```py
->>> pipeline
-StableDiffusionPipeline {
- "_class_name": "StableDiffusionPipeline",
- "_diffusers_version": "0.21.4",
- ...,
- "scheduler": [
- "diffusers",
- "PNDMScheduler"
- ],
- ...,
- "unet": [
- "diffusers",
- "UNet2DConditionModel"
- ],
- "vae": [
- "diffusers",
- "AutoencoderKL"
- ]
-}
-```
-
-We strongly recommend running the pipeline on a GPU because the model consists of roughly 1.4 billion parameters.
-You can move the generator object to a GPU, just like you would in PyTorch:
+import torch
+from diffusers import DiffusionPipeline
-```python
->>> pipeline.to("cuda")
-```
-
-Now you can pass a text prompt to the `pipeline` to generate an image, and then access the denoised image. By default, the image output is wrapped in a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object.
+pipeline = DiffusionPipeline.from_pretrained(
+ "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
+)
-```python
->>> image = pipeline("An image of a squirrel in Picasso style").images[0]
->>> image
+prompt = """
+cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+pipeline(prompt).images[0]
```
-
-
-
+
+
-Save the image by calling `save`:
-
-```python
->>> image.save("image_of_squirrel_painting.png")
-```
-
-### Local pipeline
-
-You can also use the pipeline locally. The only difference is you need to download the weights first:
-
-```bash
-!git lfs install
-!git clone https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5
-```
-
-Then load the saved weights into the pipeline:
-
-```python
->>> pipeline = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5", use_safetensors=True)
-```
-
-Now, you can run the pipeline as you would in the section above.
-
-### Swapping schedulers
-
-Different schedulers come with different denoising speeds and quality trade-offs. The best way to find out which one works best for you is to try them out! One of the main features of 🧨 Diffusers is to allow you to easily switch between schedulers. For example, to replace the default [`PNDMScheduler`] with the [`EulerDiscreteScheduler`], load it with the [`~diffusers.ConfigMixin.from_config`] method:
+Use `.frames[0]` to access the generated video output and [`~utils.export_to_video`] to save the video.
```py
->>> from diffusers import EulerDiscreteScheduler
-
->>> pipeline = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", use_safetensors=True)
->>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
-```
-
-Try generating an image with the new scheduler and see if you notice a difference!
-
-In the next section, you'll take a closer look at the components - the model and scheduler - that make up the [`DiffusionPipeline`] and learn how to use these components to generate an image of a cat.
-
-## Models
-
-Most models take a noisy sample, and at each timestep it predicts the *noise residual* (other models learn to predict the previous sample directly or the velocity or [`v-prediction`](https://github.com/huggingface/diffusers/blob/5e5ce13e2f89ac45a0066cb3f369462a3cf1d9ef/src/diffusers/schedulers/scheduling_ddim.py#L110)), the difference between a less noisy image and the input image. You can mix and match models to create other diffusion systems.
-
-Models are initiated with the [`~ModelMixin.from_pretrained`] method which also locally caches the model weights so it is faster the next time you load the model. For the quicktour, you'll load the [`UNet2DModel`], a basic unconditional image generation model with a checkpoint trained on cat images:
+import torch
+from diffusers import AutoencoderKLWan, DiffusionPipeline
+from diffusers.quantizers import PipelineQuantizationConfig
+from diffusers.utils import export_to_video
+
+vae = AutoencoderKLWan.from_pretrained(
+ "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
+ subfolder="vae",
+ torch_dtype=torch.float32
+)
+pipeline = DiffusionPipeline.from_pretrained(
+ "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
+ vae=vae
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
+)
+
+prompt = """
+Cinematic video of a sleek cat lounging on a colorful inflatable in a crystal-clear turquoise pool in Palm Springs,
+sipping a salt-rimmed margarita through a straw. Golden-hour sunlight glows over mid-century modern homes and swaying palms.
+Shot in rich Sony a7S III: with moody, glamorous color grading, subtle lens flares, and soft vintage film grain.
+Ripples shimmer as a warm desert breeze stirs the water, blending luxury and playful charm in an epic, gorgeously composed frame.
+"""
+video = pipeline(prompt=prompt, num_frames=81, num_inference_steps=40).frames[0]
+export_to_video(video, "output.mp4", fps=16)
+```
+
+
+
+
+## LoRA
+
+Adapters insert a small number of trainable parameters to the original base model. Only the inserted parameters are fine-tuned while the rest of the model weights remain frozen. This makes it fast and cheap to fine-tune a model on a new style. Among adapters, [LoRA's](./tutorials/using_peft_for_inference) are the most popular.
+
+Add a LoRA to a pipeline with the [`~loaders.QwenImageLoraLoaderMixin.load_lora_weights`] method. Some LoRA's require a special word to trigger it, such as `Realism`, in the example below. Check a LoRA's model card to see if it requires a trigger word.
```py
->>> from diffusers import UNet2DModel
-
->>> repo_id = "google/ddpm-cat-256"
->>> model = UNet2DModel.from_pretrained(repo_id, use_safetensors=True)
-```
+import torch
+from diffusers import DiffusionPipeline
-To access the model parameters, call `model.config`:
+pipeline = DiffusionPipeline.from_pretrained(
+ "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
+)
+pipeline.load_lora_weights(
+ "flymy-ai/qwen-image-realism-lora",
+)
-```py
->>> model.config
+prompt = """
+super Realism cinematic film still of a cat sipping a margarita in a pool in Palm Springs in the style of umempart, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+pipeline(prompt).images[0]
```
-The model configuration is a 🧊 frozen 🧊 dictionary, which means those parameters can't be changed after the model is created. This is intentional and ensures that the parameters used to define the model architecture at the start remain the same, while other parameters can still be adjusted during inference.
-
-Some of the most important parameters are:
+Check out the [LoRA](./tutorials/using_peft_for_inference) docs or Adapters section to learn more.
-* `sample_size`: the height and width dimension of the input sample.
-* `in_channels`: the number of input channels of the input sample.
-* `down_block_types` and `up_block_types`: the type of down- and upsampling blocks used to create the UNet architecture.
-* `block_out_channels`: the number of output channels of the downsampling blocks; also used in reverse order for the number of input channels of the upsampling blocks.
-* `layers_per_block`: the number of ResNet blocks present in each UNet block.
+## Quantization
-To use the model for inference, create the image shape with random Gaussian noise. It should have a `batch` axis because the model can receive multiple random noises, a `channel` axis corresponding to the number of input channels, and a `sample_size` axis for the height and width of the image:
+[Quantization](./quantization/overview) stores data in fewer bits to reduce memory usage. It may also speed up inference because it takes less time to perform calculations with fewer bits.
-```py
->>> import torch
+Diffusers provides several quantization backends and picking one depends on your use case. For example, [bitsandbytes](./quantization/bitsandbytes) and [torchao](./quantization/torchao) are both simple and easy to use for inference, but torchao supports more [quantization types](./quantization/torchao#supported-quantization-types) like fp8.
->>> torch.manual_seed(0)
-
->>> noisy_sample = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
->>> noisy_sample.shape
-torch.Size([1, 3, 256, 256])
-```
-
-For inference, pass the noisy image and a `timestep` to the model. The `timestep` indicates how noisy the input image is, with more noise at the beginning and less at the end. This helps the model determine its position in the diffusion process, whether it is closer to the start or the end. Use the `sample` method to get the model output:
+Configure [`PipelineQuantizationConfig`] with the backend to use, the specific arguments (refer to the [API](./api/quantization) reference for available arguments) for that backend, and which components to quantize. The example below quantizes the model to 4-bits and only uses 14.93GB of memory.
```py
->>> with torch.no_grad():
-... noisy_residual = model(sample=noisy_sample, timestep=2).sample
-```
-
-To generate actual examples though, you'll need a scheduler to guide the denoising process. In the next section, you'll learn how to couple a model with a scheduler.
-
-## Schedulers
-
-Schedulers manage going from a noisy sample to a less noisy sample given the model output - in this case, it is the `noisy_residual`.
-
-
+import torch
+from diffusers import DiffusionPipeline
+from diffusers.quantizers import PipelineQuantizationConfig
-🧨 Diffusers is a toolbox for building diffusion systems. While the [`DiffusionPipeline`] is a convenient way to get started with a pre-built diffusion system, you can also choose your own model and scheduler components separately to build a custom diffusion system.
+quant_config = PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_4bit",
+ quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
+ components_to_quantize=["transformer", "text_encoder"],
+)
+pipeline = DiffusionPipeline.from_pretrained(
+ "Qwen/Qwen-Image",
+ torch_dtype=torch.bfloat16,
+ quantization_config=quant_config,
+ device_map="cuda"
+)
-
-
-For the quicktour, you'll instantiate the [`DDPMScheduler`] with its [`~diffusers.ConfigMixin.from_config`] method:
-
-```py
->>> from diffusers import DDPMScheduler
-
->>> scheduler = DDPMScheduler.from_pretrained(repo_id)
->>> scheduler
-DDPMScheduler {
- "_class_name": "DDPMScheduler",
- "_diffusers_version": "0.21.4",
- "beta_end": 0.02,
- "beta_schedule": "linear",
- "beta_start": 0.0001,
- "clip_sample": true,
- "clip_sample_range": 1.0,
- "dynamic_thresholding_ratio": 0.995,
- "num_train_timesteps": 1000,
- "prediction_type": "epsilon",
- "sample_max_value": 1.0,
- "steps_offset": 0,
- "thresholding": false,
- "timestep_spacing": "leading",
- "trained_betas": null,
- "variance_type": "fixed_small"
-}
+prompt = """
+cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+pipeline(prompt).images[0]
+print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
```
-
-
-💡 Unlike a model, a scheduler does not have trainable weights and is parameter-free!
+Take a look at the [Quantization](./quantization/overview) section for more details.
-
+## Optimizations
-Some of the most important parameters are:
+> [!TIP]
+> Optimization is dependent on hardware specs such as memory. Use this [Space](https://huggingface.co/spaces/diffusers/optimized-diffusers-code) to generate code examples that include all of Diffusers' available memory and speed optimization techniques for any model you're using.
-* `num_train_timesteps`: the length of the denoising process or, in other words, the number of timesteps required to process random Gaussian noise into a data sample.
-* `beta_schedule`: the type of noise schedule to use for inference and training.
-* `beta_start` and `beta_end`: the start and end noise values for the noise schedule.
+Modern diffusion models are very large and have billions of parameters. The iterative denoising process is also computationally intensive and slow. Diffusers provides techniques for reducing memory usage and boosting inference speed. These techniques can be combined with quantization to optimize for both memory usage and inference speed.
-To predict a slightly less noisy image, pass the following to the scheduler's [`~diffusers.DDPMScheduler.step`] method: model output, `timestep`, and current `sample`.
+### Memory usage
-```py
->>> less_noisy_sample = scheduler.step(model_output=noisy_residual, timestep=2, sample=noisy_sample).prev_sample
->>> less_noisy_sample.shape
-torch.Size([1, 3, 256, 256])
-```
+The text encoders and UNet or DiT can use up as much as ~30GB of memory, exceeding the amount available on many free-tier or consumer GPUs.
-The `less_noisy_sample` can be passed to the next `timestep` where it'll get even less noisy! Let's bring it all together now and visualize the entire denoising process.
+Offloading stores weights that aren't currently used on the CPU and only moves them to the GPU when they're needed. There are a few offloading types and the example below uses [model offloading](./optimization/memory#model-offloading). This moves an entire model, like a text encoder or transformer, to the CPU when it isn't actively being used.
-First, create a function that postprocesses and displays the denoised image as a `PIL.Image`:
+Call [`~DiffusionPipeline.enable_model_cpu_offload`] to activate it. By combining quantization and offloading, the following example only requires ~12.54GB of memory.
```py
->>> import PIL.Image
->>> import numpy as np
+import torch
+from diffusers import DiffusionPipeline
+from diffusers.quantizers import PipelineQuantizationConfig
+quant_config = PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_4bit",
+ quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
+ components_to_quantize=["transformer", "text_encoder"],
+)
+pipeline = DiffusionPipeline.from_pretrained(
+ "Qwen/Qwen-Image",
+ torch_dtype=torch.bfloat16,
+ quantization_config=quant_config,
+ device_map="cuda"
+)
+pipeline.enable_model_cpu_offload()
->>> def display_sample(sample, i):
-... image_processed = sample.cpu().permute(0, 2, 3, 1)
-... image_processed = (image_processed + 1.0) * 127.5
-... image_processed = image_processed.numpy().astype(np.uint8)
-
-... image_pil = PIL.Image.fromarray(image_processed[0])
-... display(f"Image at step {i}")
-... display(image_pil)
+prompt = """
+cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+pipeline(prompt).images[0]
+print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
```
-To speed up the denoising process, move the input and model to a GPU:
+Refer to the [Reduce memory usage](./optimization/memory) docs to learn more about other memory reducing techniques.
-```py
->>> model.to("cuda")
->>> noisy_sample = noisy_sample.to("cuda")
-```
+### Inference speed
-Now create a denoising loop that predicts the residual of the less noisy sample, and computes the less noisy sample with the scheduler:
+The denoising loop performs a lot of computations and can be slow. Methods like [torch.compile](./optimization/fp16#torchcompile) increases inference speed by compiling the computations into an optimized kernel. Compilation is slow for the first generation but successive generations should be much faster.
-```py
->>> import tqdm
+The example below uses [regional compilation](./optimization/fp16#regional-compilation) to only compile small regions of a model. It reduces cold-start latency while also providing a runtime speed up.
->>> sample = noisy_sample
+Call [`~ModelMixin.compile_repeated_blocks`] on the model to activate it.
->>> for i, t in enumerate(tqdm.tqdm(scheduler.timesteps)):
-... # 1. predict noise residual
-... with torch.no_grad():
-... residual = model(sample, t).sample
+```py
+import torch
+from diffusers import DiffusionPipeline
-... # 2. compute less noisy image and set x_t -> x_t-1
-... sample = scheduler.step(residual, t, sample).prev_sample
+pipeline = DiffusionPipeline.from_pretrained(
+ "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
+)
-... # 3. optionally look at image
-... if (i + 1) % 50 == 0:
-... display_sample(sample, i + 1)
+pipeline.transformer.compile_repeated_blocks(
+ fullgraph=True,
+)
+prompt = """
+cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+pipeline(prompt).images[0]
```
-Sit back and watch as a cat is generated from nothing but noise! 😻
-
-
-
-
-
-## Next steps
-
-Hopefully, you generated some cool images with 🧨 Diffusers in this quicktour! For your next steps, you can:
-
-* Train or finetune a model to generate your own images in the [training](./tutorials/basic_training) tutorial.
-* See example official and community [training or finetuning scripts](https://github.com/huggingface/diffusers/tree/main/examples#-diffusers-examples) for a variety of use cases.
-* Learn more about loading, accessing, changing, and comparing schedulers in the [Using different Schedulers](./using-diffusers/schedulers) guide.
-* Explore prompt engineering, speed and memory optimizations, and tips and tricks for generating higher-quality images with the [Stable Diffusion](./stable_diffusion) guide.
-* Dive deeper into speeding up 🧨 Diffusers with guides on [optimized PyTorch on a GPU](./optimization/fp16), and inference guides for running [Stable Diffusion on Apple Silicon (M1/M2)](./optimization/mps) and [ONNX Runtime](./optimization/onnx).
+Check out the [Accelerate inference](./optimization/fp16) or [Caching](./optimization/cache) docs for more methods that speed up inference.
\ No newline at end of file
diff --git a/docs/source/en/stable_diffusion.md b/docs/source/en/stable_diffusion.md
index fc20d259f5f7..93e399d3db88 100644
--- a/docs/source/en/stable_diffusion.md
+++ b/docs/source/en/stable_diffusion.md
@@ -1,4 +1,4 @@
-
-# Effective and efficient diffusion
-
[[open-in-colab]]
-Getting the [`DiffusionPipeline`] to generate images in a certain style or include what you want can be tricky. Often times, you have to run the [`DiffusionPipeline`] several times before you end up with an image you're happy with. But generating something out of nothing is a computationally intensive process, especially if you're running inference over and over again.
-
-This is why it's important to get the most *computational* (speed) and *memory* (GPU vRAM) efficiency from the pipeline to reduce the time between inference cycles so you can iterate faster.
-
-This tutorial walks you through how to generate faster and better with the [`DiffusionPipeline`].
-
-Begin by loading the [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) model:
-
-```python
-from diffusers import DiffusionPipeline
-
-model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
-pipeline = DiffusionPipeline.from_pretrained(model_id, use_safetensors=True)
-```
-
-The example prompt you'll use is a portrait of an old warrior chief, but feel free to use your own prompt:
-
-```python
-prompt = "portrait photo of a old warrior chief"
-```
+# Basic performance
-## Speed
+Diffusion is a random process that is computationally demanding. You may need to run the [`DiffusionPipeline`] several times before getting a desired output. That's why it's important to carefully balance generation speed and memory usage in order to iterate faster,
-
+This guide recommends some basic performance tips for using the [`DiffusionPipeline`]. Refer to the Inference Optimization section docs such as [Accelerate inference](./optimization/fp16) or [Reduce memory usage](./optimization/memory) for more detailed performance guides.
-💡 If you don't have access to a GPU, you can use one for free from a GPU provider like [Colab](https://colab.research.google.com/)!
+## Memory usage
-
-
-One of the simplest ways to speed up inference is to place the pipeline on a GPU the same way you would with any PyTorch module:
-
-```python
-pipeline = pipeline.to("cuda")
-```
+Reducing the amount of memory used indirectly speeds up generation and can help a model fit on device.
-To make sure you can use the same image and improve on it, use a [`Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) and set a seed for [reproducibility](./using-diffusers/reusing_seeds):
+The [`~DiffusionPipeline.enable_model_cpu_offload`] method moves a model to the CPU when it is not in use to save GPU memory.
-```python
+```py
import torch
+from diffusers import DiffusionPipeline
-generator = torch.Generator("cuda").manual_seed(0)
-```
-
-Now you can generate an image:
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
+)
+pipeline.enable_model_cpu_offload()
-```python
-image = pipeline(prompt, generator=generator).images[0]
-image
+prompt = """
+cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+pipeline(prompt).images[0]
+print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
```
-
-
-
+## Inference speed
-This process took ~30 seconds on a T4 GPU (it might be faster if your allocated GPU is better than a T4). By default, the [`DiffusionPipeline`] runs inference with full `float32` precision for 50 inference steps. You can speed this up by switching to a lower precision like `float16` or running fewer inference steps.
+Denoising is the most computationally demanding process during diffusion. Methods that optimizes this process accelerates inference speed. Try the following methods for a speed up.
-Let's start by loading the model in `float16` and generate an image:
+- Add `device_map="cuda"` to place the pipeline on a GPU. Placing a model on an accelerator, like a GPU, increases speed because it performs computations in parallel.
+- Set `torch_dtype=torch.bfloat16` to execute the pipeline in half-precision. Reducing the data type precision increases speed because it takes less time to perform computations in a lower precision.
-```python
+```py
import torch
+import time
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
-pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, use_safetensors=True)
-pipeline = pipeline.to("cuda")
-generator = torch.Generator("cuda").manual_seed(0)
-image = pipeline(prompt, generator=generator).images[0]
-image
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.bfloat16,
+ device_map="cuda
+)
```
-
-
-
-
-This time, it only took ~11 seconds to generate the image, which is almost 3x faster than before!
-
-
-
-💡 We strongly suggest always running your pipelines in `float16`, and so far, we've rarely seen any degradation in output quality.
-
-
-
-Another option is to reduce the number of inference steps. Choosing a more efficient scheduler could help decrease the number of steps without sacrificing output quality. You can find which schedulers are compatible with the current model in the [`DiffusionPipeline`] by calling the `compatibles` method:
-
-```python
-pipeline.scheduler.compatibles
-[
- diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
- diffusers.schedulers.scheduling_unipc_multistep.UniPCMultistepScheduler,
- diffusers.schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteScheduler,
- diffusers.schedulers.scheduling_deis_multistep.DEISMultistepScheduler,
- diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
- diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
- diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
- diffusers.schedulers.scheduling_dpmsolver_singlestep.DPMSolverSinglestepScheduler,
- diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteScheduler,
- diffusers.utils.dummy_torch_and_torchsde_objects.DPMSolverSDEScheduler,
- diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler,
- diffusers.schedulers.scheduling_pndm.PNDMScheduler,
- diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler,
- diffusers.schedulers.scheduling_ddim.DDIMScheduler,
-]
-```
-
-The Stable Diffusion model uses the [`PNDMScheduler`] by default which usually requires ~50 inference steps, but more performant schedulers like [`DPMSolverMultistepScheduler`], require only ~20 or 25 inference steps. Use the [`~ConfigMixin.from_config`] method to load a new scheduler:
-
-```python
-from diffusers import DPMSolverMultistepScheduler
+- Use a faster scheduler, such as [`DPMSolverMultistepScheduler`], which only requires ~20-25 steps.
+- Set `num_inference_steps` to a lower value. Reducing the number of inference steps reduces the overall number of computations. However, this can result in lower generation quality.
+```py
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
-```
-
-Now set the `num_inference_steps` to 20:
-
-```python
-generator = torch.Generator("cuda").manual_seed(0)
-image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
-image
-```
-
-
-
-
-
-Great, you've managed to cut the inference time to just 4 seconds! ⚡️
-
-## Memory
-
-The other key to improving pipeline performance is consuming less memory, which indirectly implies more speed, since you're often trying to maximize the number of images generated per second. The easiest way to see how many images you can generate at once is to try out different batch sizes until you get an `OutOfMemoryError` (OOM).
-Create a function that'll generate a batch of images from a list of prompts and `Generators`. Make sure to assign each `Generator` a seed so you can reuse it if it produces a good result.
+prompt = """
+cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
-```python
-def get_inputs(batch_size=1):
- generator = [torch.Generator("cuda").manual_seed(i) for i in range(batch_size)]
- prompts = batch_size * [prompt]
- num_inference_steps = 20
+start_time = time.perf_counter()
+image = pipeline(prompt).images[0]
+end_time = time.perf_counter()
- return {"prompt": prompts, "generator": generator, "num_inference_steps": num_inference_steps}
+print(f"Image generation took {end_time - start_time:.3f} seconds")
```
-Start with `batch_size=4` and see how much memory you've consumed:
+## Generation quality
-```python
-from diffusers.utils import make_image_grid
+Many modern diffusion models deliver high-quality images out-of-the-box. However, you can still improve generation quality by trying the following.
-images = pipeline(**get_inputs(batch_size=4)).images
-make_image_grid(images, 2, 2)
-```
-
-Unless you have a GPU with more vRAM, the code above probably returned an `OOM` error! Most of the memory is taken up by the cross-attention layers. Instead of running this operation in a batch, you can run it sequentially to save a significant amount of memory. All you have to do is configure the pipeline to use the [`~DiffusionPipeline.enable_attention_slicing`] function:
-
-```python
-pipeline.enable_attention_slicing()
-```
-
-Now try increasing the `batch_size` to 8!
-
-```python
-images = pipeline(**get_inputs(batch_size=8)).images
-make_image_grid(images, rows=2, cols=4)
-```
-
-
-
-
-
-Whereas before you couldn't even generate a batch of 4 images, now you can generate a batch of 8 images at ~3.5 seconds per image! This is probably the fastest you can go on a T4 GPU without sacrificing quality.
-
-## Quality
-
-In the last two sections, you learned how to optimize the speed of your pipeline by using `fp16`, reducing the number of inference steps by using a more performant scheduler, and enabling attention slicing to reduce memory consumption. Now you're going to focus on how to improve the quality of generated images.
-
-### Better checkpoints
-
-The most obvious step is to use better checkpoints. The Stable Diffusion model is a good starting point, and since its official launch, several improved versions have also been released. However, using a newer version doesn't automatically mean you'll get better results. You'll still have to experiment with different checkpoints yourself, and do a little research (such as using [negative prompts](https://minimaxir.com/2022/11/stable-diffusion-negative-prompt/)) to get the best results.
-
-As the field grows, there are more and more high-quality checkpoints finetuned to produce certain styles. Try exploring the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) and [Diffusers Gallery](https://huggingface.co/spaces/huggingface-projects/diffusers-gallery) to find one you're interested in!
+- Try a more detailed and descriptive prompt. Include details such as the image medium, subject, style, and aesthetic. A negative prompt may also help by guiding a model away from undesirable features by using words like low quality or blurry.
-### Better pipeline components
+ ```py
+ import torch
+ from diffusers import DiffusionPipeline
-You can also try replacing the current pipeline components with a newer version. Let's try loading the latest [autoencoder](https://huggingface.co/stabilityai/stable-diffusion-2-1/tree/main/vae) from Stability AI into the pipeline, and generate some images:
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
+ )
-```python
-from diffusers import AutoencoderKL
+ prompt = """
+ cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+ highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+ """
+ negative_prompt = "low quality, blurry, ugly, poor details"
+ pipeline(prompt, negative_prompt=negative_prompt).images[0]
+ ```
-vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda")
-pipeline.vae = vae
-images = pipeline(**get_inputs(batch_size=8)).images
-make_image_grid(images, rows=2, cols=4)
-```
-
-
-
-
-
-### Better prompt engineering
-
-The text prompt you use to generate an image is super important, so much so that it is called *prompt engineering*. Some considerations to keep during prompt engineering are:
-
-- How is the image or similar images of the one I want to generate stored on the internet?
-- What additional detail can I give that steers the model towards the style I want?
-
-With this in mind, let's improve the prompt to include color and higher quality details:
-
-```python
-prompt += ", tribal panther make up, blue on red, side profile, looking away, serious eyes"
-prompt += " 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta"
-```
-
-Generate a batch of images with the new prompt:
+ For more details about creating better prompts, take a look at the [Prompt techniques](./using-diffusers/weighted_prompts) doc.
-```python
-images = pipeline(**get_inputs(batch_size=8)).images
-make_image_grid(images, rows=2, cols=4)
-```
+- Try a different scheduler, like [`HeunDiscreteScheduler`] or [`LMSDiscreteScheduler`], that gives up generation speed for quality.
-
-
-
+ ```py
+ import torch
+ from diffusers import DiffusionPipeline, HeunDiscreteScheduler
-Pretty impressive! Let's tweak the second image - corresponding to the `Generator` with a seed of `1` - a bit more by adding some text about the age of the subject:
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
+ )
+ pipeline.scheduler = HeunDiscreteScheduler.from_config(pipeline.scheduler.config)
-```python
-prompts = [
- "portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
- "portrait photo of an old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
- "portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
- "portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
-]
-
-generator = [torch.Generator("cuda").manual_seed(1) for _ in range(len(prompts))]
-images = pipeline(prompt=prompts, generator=generator, num_inference_steps=25).images
-make_image_grid(images, 2, 2)
-```
-
-
-
-
+ prompt = """
+ cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+ highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+ """
+ negative_prompt = "low quality, blurry, ugly, poor details"
+ pipeline(prompt, negative_prompt=negative_prompt).images[0]
+ ```
## Next steps
-In this tutorial, you learned how to optimize a [`DiffusionPipeline`] for computational and memory efficiency as well as improving the quality of generated outputs. If you're interested in making your pipeline even faster, take a look at the following resources:
-
-- Learn how [PyTorch 2.0](./optimization/torch2.0) and [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) can yield 5 - 300% faster inference speed. On an A100 GPU, inference can be up to 50% faster!
-- If you can't use PyTorch 2, we recommend you install [xFormers](./optimization/xformers). Its memory-efficient attention mechanism works great with PyTorch 1.13.1 for faster speed and reduced memory consumption.
-- Other optimization techniques, such as model offloading, are covered in [this guide](./optimization/fp16).
+Diffusers offers more advanced and powerful optimizations such as [group-offloading](./optimization/memory#group-offloading) and [regional compilation](./optimization/fp16#regional-compilation). To learn more about how to maximize performance, take a look at the Inference Optimization section.
\ No newline at end of file
diff --git a/docs/source/en/training/adapt_a_model.md b/docs/source/en/training/adapt_a_model.md
index e6a088675a34..9b7efd2abfd8 100644
--- a/docs/source/en/training/adapt_a_model.md
+++ b/docs/source/en/training/adapt_a_model.md
@@ -16,12 +16,12 @@ pipeline.unet.config["in_channels"]
4
```
-Inpainting requires 9 channels in the input sample. You can check this value in a pretrained inpainting model like [`runwayml/stable-diffusion-inpainting`](https://huggingface.co/runwayml/stable-diffusion-inpainting):
+Inpainting requires 9 channels in the input sample. You can check this value in a pretrained inpainting model like [`stable-diffusion-v1-5/stable-diffusion-inpainting`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting):
```py
from diffusers import StableDiffusionPipeline
-pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", use_safetensors=True)
+pipeline = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-inpainting", use_safetensors=True)
pipeline.unet.config["in_channels"]
9
```
@@ -31,10 +31,10 @@ To adapt your text-to-image model for inpainting, you'll need to change the numb
Initialize a [`UNet2DConditionModel`] with the pretrained text-to-image model weights, and change `in_channels` to 9. Changing the number of `in_channels` means you need to set `ignore_mismatched_sizes=True` and `low_cpu_mem_usage=False` to avoid a size mismatch error because the shape is different now.
```py
-from diffusers import UNet2DConditionModel
+from diffusers import AutoModel
model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
-unet = UNet2DConditionModel.from_pretrained(
+unet = AutoModel.from_pretrained(
model_id,
subfolder="unet",
in_channels=9,
diff --git a/docs/source/en/training/cogvideox.md b/docs/source/en/training/cogvideox.md
index 657e58bfd5eb..d0700c4da763 100644
--- a/docs/source/en/training/cogvideox.md
+++ b/docs/source/en/training/cogvideox.md
@@ -1,4 +1,4 @@
-
-
-# Accelerate inference of text-to-image diffusion models
-
-Diffusion models are slower than their GAN counterparts because of the iterative and sequential reverse diffusion process. There are several techniques that can address this limitation such as progressive timestep distillation ([LCM LoRA](../using-diffusers/inference_with_lcm_lora)), model compression ([SSD-1B](https://huggingface.co/segmind/SSD-1B)), and reusing adjacent features of the denoiser ([DeepCache](../optimization/deepcache)).
-
-However, you don't necessarily need to use these techniques to speed up inference. With PyTorch 2 alone, you can accelerate the inference latency of text-to-image diffusion pipelines by up to 3x. This tutorial will show you how to progressively apply the optimizations found in PyTorch 2 to reduce inference latency. You'll use the [Stable Diffusion XL (SDXL)](../using-diffusers/sdxl) pipeline in this tutorial, but these techniques are applicable to other text-to-image diffusion pipelines too.
-
-Make sure you're using the latest version of Diffusers:
-
-```bash
-pip install -U diffusers
-```
-
-Then upgrade the other required libraries too:
-
-```bash
-pip install -U transformers accelerate peft
-```
-
-Install [PyTorch nightly](https://pytorch.org/) to benefit from the latest and fastest kernels:
-
-```bash
-pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
-```
-
-> [!TIP]
-> The results reported below are from a 80GB 400W A100 with its clock rate set to the maximum.
-> If you're interested in the full benchmarking code, take a look at [huggingface/diffusion-fast](https://github.com/huggingface/diffusion-fast).
-
-
-## Baseline
-
-Let's start with a baseline. Disable reduced precision and the [`scaled_dot_product_attention` (SDPA)](../optimization/torch2.0#scaled-dot-product-attention) function which is automatically used by Diffusers:
-
-```python
-from diffusers import StableDiffusionXLPipeline
-
-# Load the pipeline in full-precision and place its model components on CUDA.
-pipe = StableDiffusionXLPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0"
-).to("cuda")
-
-# Run the attention ops without SDPA.
-pipe.unet.set_default_attn_processor()
-pipe.vae.set_default_attn_processor()
-
-prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
-image = pipe(prompt, num_inference_steps=30).images[0]
-```
-
-This default setup takes 7.36 seconds.
-
-
-
-
-
-## bfloat16
-
-Enable the first optimization, reduced precision or more specifically bfloat16. There are several benefits of using reduced precision:
-
-* Using a reduced numerical precision (such as float16 or bfloat16) for inference doesn’t affect the generation quality but significantly improves latency.
-* The benefits of using bfloat16 compared to float16 are hardware dependent, but modern GPUs tend to favor bfloat16.
-* bfloat16 is much more resilient when used with quantization compared to float16, but more recent versions of the quantization library ([torchao](https://github.com/pytorch-labs/ao)) we used don't have numerical issues with float16.
-
-```python
-from diffusers import StableDiffusionXLPipeline
-import torch
-
-pipe = StableDiffusionXLPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
-).to("cuda")
-
-# Run the attention ops without SDPA.
-pipe.unet.set_default_attn_processor()
-pipe.vae.set_default_attn_processor()
-
-prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
-image = pipe(prompt, num_inference_steps=30).images[0]
-```
-
-bfloat16 reduces the latency from 7.36 seconds to 4.63 seconds.
-
-
-
-
-
-
-
-In our later experiments with float16, recent versions of torchao do not incur numerical problems from float16.
-
-
-
-Take a look at the [Speed up inference](../optimization/fp16) guide to learn more about running inference with reduced precision.
-
-## SDPA
-
-Attention blocks are intensive to run. But with PyTorch's [`scaled_dot_product_attention`](../optimization/torch2.0#scaled-dot-product-attention) function, it is a lot more efficient. This function is used by default in Diffusers so you don't need to make any changes to the code.
-
-```python
-from diffusers import StableDiffusionXLPipeline
-import torch
-
-pipe = StableDiffusionXLPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
-).to("cuda")
-
-prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
-image = pipe(prompt, num_inference_steps=30).images[0]
-```
-
-Scaled dot product attention improves the latency from 4.63 seconds to 3.31 seconds.
-
-
-
-
-
-## torch.compile
-
-PyTorch 2 includes `torch.compile` which uses fast and optimized kernels. In Diffusers, the UNet and VAE are usually compiled because these are the most compute-intensive modules. First, configure a few compiler flags (refer to the [full list](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py) for more options):
-
-```python
-from diffusers import StableDiffusionXLPipeline
-import torch
-
-torch._inductor.config.conv_1x1_as_mm = True
-torch._inductor.config.coordinate_descent_tuning = True
-torch._inductor.config.epilogue_fusion = False
-torch._inductor.config.coordinate_descent_check_all_directions = True
-```
-
-It is also important to change the UNet and VAE's memory layout to "channels_last" when compiling them to ensure maximum speed.
-
-```python
-pipe.unet.to(memory_format=torch.channels_last)
-pipe.vae.to(memory_format=torch.channels_last)
-```
-
-Now compile and perform inference:
-
-```python
-# Compile the UNet and VAE.
-pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
-pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)
-
-prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
-
-# First call to `pipe` is slow, subsequent ones are faster.
-image = pipe(prompt, num_inference_steps=30).images[0]
-```
-
-`torch.compile` offers different backends and modes. For maximum inference speed, use "max-autotune" for the inductor backend. “max-autotune” uses CUDA graphs and optimizes the compilation graph specifically for latency. CUDA graphs greatly reduces the overhead of launching GPU operations by using a mechanism to launch multiple GPU operations through a single CPU operation.
-
-Using SDPA attention and compiling both the UNet and VAE cuts the latency from 3.31 seconds to 2.54 seconds.
-
-
-
-
-
-> [!TIP]
-> From PyTorch 2.3.1, you can control the caching behavior of `torch.compile()`. This is particularly beneficial for compilation modes like `"max-autotune"` which performs a grid-search over several compilation flags to find the optimal configuration. Learn more in the [Compile Time Caching in torch.compile](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html) tutorial.
-
-### Prevent graph breaks
-
-Specifying `fullgraph=True` ensures there are no graph breaks in the underlying model to take full advantage of `torch.compile` without any performance degradation. For the UNet and VAE, this means changing how you access the return variables.
-
-```diff
-- latents = unet(
-- latents, timestep=timestep, encoder_hidden_states=prompt_embeds
--).sample
-
-+ latents = unet(
-+ latents, timestep=timestep, encoder_hidden_states=prompt_embeds, return_dict=False
-+)[0]
-```
-
-### Remove GPU sync after compilation
-
-During the iterative reverse diffusion process, the `step()` function is [called](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L1228) on the scheduler each time after the denoiser predicts the less noisy latent embeddings. Inside `step()`, the `sigmas` variable is [indexed](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/schedulers/scheduling_euler_discrete.py#L476) which when placed on the GPU, causes a communication sync between the CPU and GPU. This introduces latency and it becomes more evident when the denoiser has already been compiled.
-
-But if the `sigmas` array always [stays on the CPU](https://github.com/huggingface/diffusers/blob/35a969d297cba69110d175ee79c59312b9f49e1e/src/diffusers/schedulers/scheduling_euler_discrete.py#L240), the CPU and GPU sync doesn’t occur and you don't get any latency. In general, any CPU and GPU communication sync should be none or be kept to a bare minimum because it can impact inference latency.
-
-## Combine the attention block's projection matrices
-
-The UNet and VAE in SDXL use Transformer-like blocks which consists of attention blocks and feed-forward blocks.
-
-In an attention block, the input is projected into three sub-spaces using three different projection matrices – Q, K, and V. These projections are performed separately on the input. But we can horizontally combine the projection matrices into a single matrix and perform the projection in one step. This increases the size of the matrix multiplications of the input projections and improves the impact of quantization.
-
-You can combine the projection matrices with just a single line of code:
-
-```python
-pipe.fuse_qkv_projections()
-```
-
-This provides a minor improvement from 2.54 seconds to 2.52 seconds.
-
-
-
-
-
-
-
-Support for [`~StableDiffusionXLPipeline.fuse_qkv_projections`] is limited and experimental. It's not available for many non-Stable Diffusion pipelines such as [Kandinsky](../using-diffusers/kandinsky). You can refer to this [PR](https://github.com/huggingface/diffusers/pull/6179) to get an idea about how to enable this for the other pipelines.
-
-
-
-## Dynamic quantization
-
-You can also use the ultra-lightweight PyTorch quantization library, [torchao](https://github.com/pytorch-labs/ao) (commit SHA `54bcd5a10d0abbe7b0c045052029257099f83fd9`), to apply [dynamic int8 quantization](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html) to the UNet and VAE. Quantization adds additional conversion overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization). If the matmuls are too small, these techniques may degrade performance.
-
-First, configure all the compiler tags:
-
-```python
-from diffusers import StableDiffusionXLPipeline
-import torch
-
-# Notice the two new flags at the end.
-torch._inductor.config.conv_1x1_as_mm = True
-torch._inductor.config.coordinate_descent_tuning = True
-torch._inductor.config.epilogue_fusion = False
-torch._inductor.config.coordinate_descent_check_all_directions = True
-torch._inductor.config.force_fuse_int_mm_with_mul = True
-torch._inductor.config.use_mixed_mm = True
-```
-
-Certain linear layers in the UNet and VAE don’t benefit from dynamic int8 quantization. You can filter out those layers with the [`dynamic_quant_filter_fn`](https://github.com/huggingface/diffusion-fast/blob/0f169640b1db106fe6a479f78c1ed3bfaeba3386/utils/pipeline_utils.py#L16) shown below.
-
-```python
-def dynamic_quant_filter_fn(mod, *args):
- return (
- isinstance(mod, torch.nn.Linear)
- and mod.in_features > 16
- and (mod.in_features, mod.out_features)
- not in [
- (1280, 640),
- (1920, 1280),
- (1920, 640),
- (2048, 1280),
- (2048, 2560),
- (2560, 1280),
- (256, 128),
- (2816, 1280),
- (320, 640),
- (512, 1536),
- (512, 256),
- (512, 512),
- (640, 1280),
- (640, 1920),
- (640, 320),
- (640, 5120),
- (640, 640),
- (960, 320),
- (960, 640),
- ]
- )
-
-
-def conv_filter_fn(mod, *args):
- return (
- isinstance(mod, torch.nn.Conv2d) and mod.kernel_size == (1, 1) and 128 in [mod.in_channels, mod.out_channels]
- )
-```
-
-Finally, apply all the optimizations discussed so far:
-
-```python
-# SDPA + bfloat16.
-pipe = StableDiffusionXLPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
-).to("cuda")
-
-# Combine attention projection matrices.
-pipe.fuse_qkv_projections()
-
-# Change the memory layout.
-pipe.unet.to(memory_format=torch.channels_last)
-pipe.vae.to(memory_format=torch.channels_last)
-```
-
-Since dynamic quantization is only limited to the linear layers, convert the appropriate pointwise convolution layers into linear layers to maximize its benefit.
-
-```python
-from torchao import swap_conv2d_1x1_to_linear
-
-swap_conv2d_1x1_to_linear(pipe.unet, conv_filter_fn)
-swap_conv2d_1x1_to_linear(pipe.vae, conv_filter_fn)
-```
-
-Apply dynamic quantization:
-
-```python
-from torchao import apply_dynamic_quant
-
-apply_dynamic_quant(pipe.unet, dynamic_quant_filter_fn)
-apply_dynamic_quant(pipe.vae, dynamic_quant_filter_fn)
-```
-
-Finally, compile and perform inference:
-
-```python
-pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
-pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)
-
-prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
-image = pipe(prompt, num_inference_steps=30).images[0]
-```
-
-Applying dynamic quantization improves the latency from 2.52 seconds to 2.43 seconds.
-
-
-
-
diff --git a/docs/source/en/tutorials/inference_with_big_models.md b/docs/source/en/tutorials/inference_with_big_models.md
deleted file mode 100644
index 6af2e9bd3253..000000000000
--- a/docs/source/en/tutorials/inference_with_big_models.md
+++ /dev/null
@@ -1,139 +0,0 @@
-
-
-# Working with big models
-
-A modern diffusion model, like [Stable Diffusion XL (SDXL)](../using-diffusers/sdxl), is not just a single model, but a collection of multiple models. SDXL has four different model-level components:
-
-* A variational autoencoder (VAE)
-* Two text encoders
-* A UNet for denoising
-
-Usually, the text encoders and the denoiser are much larger compared to the VAE.
-
-As models get bigger and better, it’s possible your model is so big that even a single copy won’t fit in memory. But that doesn’t mean it can’t be loaded. If you have more than one GPU, there is more memory available to store your model. In this case, it’s better to split your model checkpoint into several smaller *checkpoint shards*.
-
-When a text encoder checkpoint has multiple shards, like [T5-xxl for SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers/tree/main/text_encoder_3), it is automatically handled by the [Transformers](https://huggingface.co/docs/transformers/index) library as it is a required dependency of Diffusers when using the [`StableDiffusion3Pipeline`]. More specifically, Transformers will automatically handle the loading of multiple shards within the requested model class and get it ready so that inference can be performed.
-
-The denoiser checkpoint can also have multiple shards and supports inference thanks to the [Accelerate](https://huggingface.co/docs/accelerate/index) library.
-
-> [!TIP]
-> Refer to the [Handling big models for inference](https://huggingface.co/docs/accelerate/main/en/concept_guides/big_model_inference) guide for general guidance when working with big models that are hard to fit into memory.
-
-For example, let's save a sharded checkpoint for the [SDXL UNet](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/tree/main/unet):
-
-```python
-from diffusers import UNet2DConditionModel
-
-unet = UNet2DConditionModel.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
-)
-unet.save_pretrained("sdxl-unet-sharded", max_shard_size="5GB")
-```
-
-The size of the fp32 variant of the SDXL UNet checkpoint is ~10.4GB. Set the `max_shard_size` parameter to 5GB to create 3 shards. After saving, you can load them in [`StableDiffusionXLPipeline`]:
-
-```python
-from diffusers import UNet2DConditionModel, StableDiffusionXLPipeline
-import torch
-
-unet = UNet2DConditionModel.from_pretrained(
- "sayakpaul/sdxl-unet-sharded", torch_dtype=torch.float16
-)
-pipeline = StableDiffusionXLPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", unet=unet, torch_dtype=torch.float16
-).to("cuda")
-
-image = pipeline("a cute dog running on the grass", num_inference_steps=30).images[0]
-image.save("dog.png")
-```
-
-If placing all the model-level components on the GPU at once is not feasible, use [`~DiffusionPipeline.enable_model_cpu_offload`] to help you:
-
-```diff
-- pipeline.to("cuda")
-+ pipeline.enable_model_cpu_offload()
-```
-
-In general, we recommend sharding when a checkpoint is more than 5GB (in fp32).
-
-## Device placement
-
-On distributed setups, you can run inference across multiple GPUs with Accelerate.
-
-> [!WARNING]
-> This feature is experimental and its APIs might change in the future.
-
-With Accelerate, you can use the `device_map` to determine how to distribute the models of a pipeline across multiple devices. This is useful in situations where you have more than one GPU.
-
-For example, if you have two 8GB GPUs, then using [`~DiffusionPipeline.enable_model_cpu_offload`] may not work so well because:
-
-* it only works on a single GPU
-* a single model might not fit on a single GPU ([`~DiffusionPipeline.enable_sequential_cpu_offload`] might work but it will be extremely slow and it is also limited to a single GPU)
-
-To make use of both GPUs, you can use the "balanced" device placement strategy which splits the models across all available GPUs.
-
-> [!WARNING]
-> Only the "balanced" strategy is supported at the moment, and we plan to support additional mapping strategies in the future.
-
-```diff
-from diffusers import DiffusionPipeline
-import torch
-
-pipeline = DiffusionPipeline.from_pretrained(
-- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True,
-+ "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True, device_map="balanced"
-)
-image = pipeline("a dog").images[0]
-image
-```
-
-You can also pass a dictionary to enforce the maximum GPU memory that can be used on each device:
-
-```diff
-from diffusers import DiffusionPipeline
-import torch
-
-max_memory = {0:"1GB", 1:"1GB"}
-pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5",
- torch_dtype=torch.float16,
- use_safetensors=True,
- device_map="balanced",
-+ max_memory=max_memory
-)
-image = pipeline("a dog").images[0]
-image
-```
-
-If a device is not present in `max_memory`, then it will be completely ignored and will not participate in the device placement.
-
-By default, Diffusers uses the maximum memory of all devices. If the models don't fit on the GPUs, they are offloaded to the CPU. If the CPU doesn't have enough memory, then you might see an error. In that case, you could defer to using [`~DiffusionPipeline.enable_sequential_cpu_offload`] and [`~DiffusionPipeline.enable_model_cpu_offload`].
-
-Call [`~DiffusionPipeline.reset_device_map`] to reset the `device_map` of a pipeline. This is also necessary if you want to use methods like `to()`, [`~DiffusionPipeline.enable_sequential_cpu_offload`], and [`~DiffusionPipeline.enable_model_cpu_offload`] on a pipeline that was device-mapped.
-
-```py
-pipeline.reset_device_map()
-```
-
-Once a pipeline has been device-mapped, you can also access its device map via `hf_device_map`:
-
-```py
-print(pipeline.hf_device_map)
-```
-
-An example device map would look like so:
-
-
-```bash
-{'unet': 1, 'vae': 1, 'safety_checker': 0, 'text_encoder': 0}
-```
\ No newline at end of file
diff --git a/docs/source/en/tutorials/tutorial_overview.md b/docs/source/en/tutorials/tutorial_overview.md
deleted file mode 100644
index bb9cc3d354d4..000000000000
--- a/docs/source/en/tutorials/tutorial_overview.md
+++ /dev/null
@@ -1,23 +0,0 @@
-
-
-# Overview
-
-Welcome to 🧨 Diffusers! If you're new to diffusion models and generative AI, and want to learn more, then you've come to the right place. These beginner-friendly tutorials are designed to provide a gentle introduction to diffusion models and help you understand the library fundamentals - the core components and how 🧨 Diffusers is meant to be used.
-
-You'll learn how to use a pipeline for inference to rapidly generate things, and then deconstruct that pipeline to really understand how to use the library as a modular toolbox for building your own diffusion systems. In the next lesson, you'll learn how to train your own diffusion model to generate what you want.
-
-After completing the tutorials, you'll have gained the necessary skills to start exploring the library on your own and see how to use it for your own projects and applications.
-
-Feel free to join our community on [Discord](https://discord.com/invite/JfAtkvEtRb) or the [forums](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) to connect and collaborate with other users and developers!
-
-Let's start diffusing! 🧨
diff --git a/docs/source/en/tutorials/using_peft_for_inference.md b/docs/source/en/tutorials/using_peft_for_inference.md
index 33414a331ea7..7bdd2a1ee969 100644
--- a/docs/source/en/tutorials/using_peft_for_inference.md
+++ b/docs/source/en/tutorials/using_peft_for_inference.md
@@ -1,4 +1,4 @@
-
-[[open-in-colab]]
+# LoRA
-# Load LoRAs for inference
+[LoRA (Low-Rank Adaptation)](https://huggingface.co/papers/2106.09685) is a method for quickly training a model for a new task. It works by freezing the original model weights and adding a small number of *new* trainable parameters. This means it is significantly faster and cheaper to adapt an existing model to new tasks, such as generating images in a new style.
-There are many adapter types (with [LoRAs](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) being the most popular) trained in different styles to achieve different effects. You can even combine multiple adapters to create new and unique images.
+LoRA checkpoints are typically only a couple hundred MBs in size, so they're very lightweight and easy to store. Load these smaller set of weights into an existing base model with [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] and specify the file name.
-In this tutorial, you'll learn how to easily load and manage adapters for inference with the 🤗 [PEFT](https://huggingface.co/docs/peft/index) integration in 🤗 Diffusers. You'll use LoRA as the main adapter technique, so you'll see the terms LoRA and adapter used interchangeably.
+
+
-Let's first install all the required libraries.
+```py
+import torch
+from diffusers import AutoPipelineForText2Image
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+pipeline.load_lora_weights(
+ "ostris/super-cereal-sdxl-lora",
+ weight_name="cereal_box_sdxl_v1.safetensors",
+ adapter_name="cereal"
+)
+pipeline("bears, pizza bites").images[0]
+```
-```bash
-!pip install -q transformers accelerate peft diffusers
+
+
+
+```py
+import torch
+from diffusers import LTXConditionPipeline
+from diffusers.utils import export_to_video, load_image
+
+pipeline = LTXConditionPipeline.from_pretrained(
+ "Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16
+)
+
+pipeline.load_lora_weights(
+ "Lightricks/LTX-Video-Cakeify-LoRA",
+ weight_name="ltxv_095_cakeify_lora.safetensors",
+ adapter_name="cakeify"
+)
+pipeline.set_adapters("cakeify")
+
+# use "CAKEIFY" to trigger the LoRA
+prompt = "CAKEIFY a person using a knife to cut a cake shaped like a Pikachu plushie"
+image = load_image("https://huggingface.co/Lightricks/LTX-Video-Cakeify-LoRA/resolve/main/assets/images/pikachu.png")
+
+video = pipeline(
+ prompt=prompt,
+ image=image,
+ width=576,
+ height=576,
+ num_frames=161,
+ decode_timestep=0.03,
+ decode_noise_scale=0.025,
+ num_inference_steps=50,
+).frames[0]
+export_to_video(video, "output.mp4", fps=26)
```
-Now, load a pipeline with a [Stable Diffusion XL (SDXL)](../api/pipelines/stable_diffusion/stable_diffusion_xl) checkpoint:
+
+
-```python
-from diffusers import DiffusionPipeline
+The [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method is the preferred way to load LoRA weights into the UNet and text encoder because it can handle cases where:
+
+- the LoRA weights don't have separate UNet and text encoder identifiers
+- the LoRA weights have separate UNet and text encoder identifiers
+
+The [`~loaders.PeftAdapterMixin.load_lora_adapter`] method is used to directly load a LoRA adapter at the *model-level*, as long as the model is a Diffusers model that is a subclass of [`PeftAdapterMixin`]. It builds and prepares the necessary model configuration for the adapter. This method also loads the LoRA adapter into the UNet.
+
+For example, if you're only loading a LoRA into the UNet, [`~loaders.PeftAdapterMixin.load_lora_adapter`] ignores the text encoder keys. Use the `prefix` parameter to filter and load the appropriate state dicts, `"unet"` to load.
+
+```py
+import torch
+from diffusers import AutoPipelineForText2Image
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+pipeline.unet.load_lora_adapter(
+ "jbilcke-hf/sdxl-cinematic-1",
+ weight_name="pytorch_lora_weights.safetensors",
+ adapter_name="cinematic",
+ prefix="unet"
+)
+# use cnmt in the prompt to trigger the LoRA
+pipeline("A cute cnmt eating a slice of pizza, stunning color scheme, masterpiece, illustration").images[0]
+```
+
+## torch.compile
+
+[torch.compile](../optimization/fp16#torchcompile) speeds up inference by compiling the PyTorch model to use optimized kernels. Before compiling, the LoRA weights need to be fused into the base model and unloaded first.
+
+```py
import torch
+from diffusers import DiffusionPipeline
-pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
-pipe = DiffusionPipeline.from_pretrained(pipe_id, torch_dtype=torch.float16).to("cuda")
+# load base model and LoRA
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+pipeline.load_lora_weights(
+ "ostris/ikea-instructions-lora-sdxl",
+ weight_name="ikea_instructions_xl_v1_5.safetensors",
+ adapter_name="ikea"
+)
+
+# activate LoRA and set adapter weight
+pipeline.set_adapters("ikea", adapter_weights=0.7)
+
+# fuse LoRAs and unload weights
+pipeline.fuse_lora(adapter_names=["ikea"], lora_scale=1.0)
+pipeline.unload_lora_weights()
```
-Next, load a [CiroN2022/toy-face](https://huggingface.co/CiroN2022/toy-face) adapter with the [`~diffusers.loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] method. With the 🤗 PEFT integration, you can assign a specific `adapter_name` to the checkpoint, which lets you easily switch between different LoRA checkpoints. Let's call this adapter `"toy"`.
+Typically, the UNet is compiled because its the most compute intensive component of the pipeline.
+
+```py
+pipeline.unet.to(memory_format=torch.channels_last)
+pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
-```python
-pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
+pipeline("A bowl of ramen shaped like a cute kawaii bear").images[0]
```
-Make sure to include the token `toy_face` in the prompt and then you can perform inference:
+Refer to the [hotswapping](#hotswapping) section to learn how to avoid recompilation when working with compiled models and multiple LoRAs.
-```python
-prompt = "toy_face of a hacker with a hoodie"
+## Weight scale
-lora_scale = 0.9
-image = pipe(
- prompt, num_inference_steps=30, cross_attention_kwargs={"scale": lora_scale}, generator=torch.manual_seed(0)
-).images[0]
-image
+The `scale` parameter is used to control how much of a LoRA to apply. A value of `0` is equivalent to only using the base model weights and a value of `1` is equivalent to fully using the LoRA.
+
+
+
+
+For simple use cases, you can pass `cross_attention_kwargs={"scale": 1.0}` to the pipeline.
+
+```py
+import torch
+from diffusers import AutoPipelineForText2Image
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+pipeline.load_lora_weights(
+ "ostris/super-cereal-sdxl-lora",
+ weight_name="cereal_box_sdxl_v1.safetensors",
+ adapter_name="cereal"
+)
+pipeline("bears, pizza bites", cross_attention_kwargs={"scale": 1.0}).images[0]
```
-
+
+
-With the `adapter_name` parameter, it is really easy to use another adapter for inference! Load the [nerijs/pixel-art-xl](https://huggingface.co/nerijs/pixel-art-xl) adapter that has been fine-tuned to generate pixel art images and call it `"pixel"`.
+> [!WARNING]
+> The [`~loaders.PeftAdapterMixin.set_adapters`] method only scales attention weights. If a LoRA has ResNets or down and upsamplers, these components keep a scale value of `1.0`.
-The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter, but you can activate the `"pixel"` adapter with the [`~loaders.peft.PeftAdapterMixin.set_adapters`] method:
+For finer control over each individual component of the UNet or text encoder, pass a dictionary instead. In the example below, the `"down"` block in the UNet is scaled by 0.9 and you can further specify in the `"up"` block the scales of the transformers in `"block_0"` and `"block_1"`. If a block like `"mid"` isn't specified, the default value `1.0` is used.
-```python
-pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
-pipe.set_adapters("pixel")
+```py
+import torch
+from diffusers import AutoPipelineForText2Image
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+pipeline.load_lora_weights(
+ "ostris/super-cereal-sdxl-lora",
+ weight_name="cereal_box_sdxl_v1.safetensors",
+ adapter_name="cereal"
+)
+scales = {
+ "text_encoder": 0.5,
+ "text_encoder_2": 0.5,
+ "unet": {
+ "down": 0.9,
+ "up": {
+ "block_0": 0.6,
+ "block_1": [0.4, 0.8, 1.0],
+ }
+ }
+}
+pipeline.set_adapters("cereal", scales)
+pipeline("bears, pizza bites").images[0]
```
-Make sure you include the token `pixel art` in your prompt to generate a pixel art image:
+
+
+
+### Scale scheduling
+
+Dynamically adjusting the LoRA scale during sampling gives you better control over the overall composition and layout because certain steps may benefit more from an increased or reduced scale.
+
+The [character LoRA](https://huggingface.co/alvarobartt/ghibli-characters-flux-lora) in the example below starts with a higher scale that gradually decays over the first 20 steps to establish the character generation. In the later steps, only a scale of 0.2 is applied to avoid adding too much of the LoRA features to other parts of the image the LoRA wasn't trained on.
-```python
-prompt = "a hacker with a hoodie, pixel art"
-image = pipe(
- prompt, num_inference_steps=30, cross_attention_kwargs={"scale": lora_scale}, generator=torch.manual_seed(0)
+```py
+import torch
+from diffusers import FluxPipeline
+
+pipeline = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+).to("cuda")
+
+pipelne.load_lora_weights("alvarobartt/ghibli-characters-flux-lora", "lora")
+
+num_inference_steps = 30
+lora_steps = 20
+lora_scales = torch.linspace(1.5, 0.7, lora_steps).tolist()
+lora_scales += [0.2] * (num_inference_steps - lora_steps + 1)
+
+pipeline.set_adapters("lora", lora_scales[0])
+
+def callback(pipeline: FluxPipeline, step: int, timestep: torch.LongTensor, callback_kwargs: dict):
+ pipeline.set_adapters("lora", lora_scales[step + 1])
+ return callback_kwargs
+
+prompt = """
+Ghibli style The Grinch, a mischievous green creature with a sly grin, peeking out from behind a snow-covered tree while plotting his antics,
+in a quaint snowy village decorated for the holidays, warm light glowing from cozy homes, with playful snowflakes dancing in the air
+"""
+pipeline(
+ prompt=prompt,
+ guidance_scale=3.0,
+ num_inference_steps=num_inference_steps,
+ generator=torch.Generator().manual_seed(42),
+ callback_on_step_end=callback,
).images[0]
-image
```
-
+## Hotswapping
+
+Hotswapping LoRAs is an efficient way to work with multiple LoRAs while avoiding accumulating memory from multiple calls to [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] and in some cases, recompilation, if a model is compiled. This workflow requires a loaded LoRA because the new LoRA weights are swapped in place for the existing loaded LoRA.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
-
+# load base model and LoRAs
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+pipeline.load_lora_weights(
+ "ostris/ikea-instructions-lora-sdxl",
+ weight_name="ikea_instructions_xl_v1_5.safetensors",
+ adapter_name="ikea"
+)
+```
-By default, if the most up-to-date versions of PEFT and Transformers are detected, `low_cpu_mem_usage` is set to `True` to speed up the loading time of LoRA checkpoints.
+> [!WARNING]
+> Hotswapping is unsupported for LoRAs that target the text encoder.
-
+Set `hotswap=True` in [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] to swap the second LoRA. Use the `adapter_name` parameter to indicate which LoRA to swap (`default_0` is the default name).
-## Merge adapters
+```py
+pipeline.load_lora_weights(
+ "lordjia/by-feng-zikai",
+ hotswap=True,
+ adapter_name="ikea"
+)
+```
-You can also merge different adapter checkpoints for inference to blend their styles together.
+### Compiled models
-Once again, use the [`~loaders.peft.PeftAdapterMixin.set_adapters`] method to activate the `pixel` and `toy` adapters and specify the weights for how they should be merged.
+For compiled models, use [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] to avoid recompilation when hotswapping LoRAs. This method should be called *before* loading the first LoRA and `torch.compile` should be called *after* loading the first LoRA.
-```python
-pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
+> [!TIP]
+> The [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] method isn't always necessary if the second LoRA targets the identical LoRA ranks and scales as the first LoRA.
+
+Within [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`], the `target_rank` parameter is important for setting the rank for all LoRA adapters. Setting it to `max_rank` sets it to the highest value. For LoRAs with different ranks, you set it to a higher rank value. The default rank value is 128.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+
+# load base model and LoRAs
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+# 1. enable_lora_hotswap
+pipeline.enable_lora_hotswap(target_rank=max_rank)
+pipeline.load_lora_weights(
+ "ostris/ikea-instructions-lora-sdxl",
+ weight_name="ikea_instructions_xl_v1_5.safetensors",
+ adapter_name="ikea"
+)
+# 2. torch.compile
+pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
+
+# 3. hotswap
+pipeline.load_lora_weights(
+ "lordjia/by-feng-zikai",
+ hotswap=True,
+ adapter_name="ikea"
+)
```
-
+> [!TIP]
+> Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If a model is recompiled despite following all the steps above, please open an [issue](https://github.com/huggingface/diffusers/issues) with a reproducible example.
-LoRA checkpoints in the diffusion community are almost always obtained with [DreamBooth](https://huggingface.co/docs/diffusers/main/en/training/dreambooth). DreamBooth training often relies on "trigger" words in the input text prompts in order for the generation results to look as expected. When you combine multiple LoRA checkpoints, it's important to ensure the trigger words for the corresponding LoRA checkpoints are present in the input text prompts.
+If you expect to varied resolutions during inference with this feature, then make sure set `dynamic=True` during compilation. Refer to [this document](../optimization/fp16#dynamic-shape-compilation) for more details.
-
+There are still scenarios where recompulation is unavoidable, such as when the hotswapped LoRA targets more layers than the initial adapter. Try to load the LoRA that targets the most layers *first*. For more details about this limitation, refer to the PEFT [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) docs.
-Remember to use the trigger words for [CiroN2022/toy-face](https://hf.co/CiroN2022/toy-face) and [nerijs/pixel-art-xl](https://hf.co/nerijs/pixel-art-xl) (these are found in their repositories) in the prompt to generate an image.
+
+Technical details of hotswapping
-```python
-prompt = "toy_face of a hacker with a hoodie, pixel art"
-image = pipe(
- prompt, num_inference_steps=30, cross_attention_kwargs={"scale": 1.0}, generator=torch.manual_seed(0)
-).images[0]
-image
+The [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] method converts the LoRA scaling factor from floats to torch.tensors and pads the shape of the weights to the largest required shape to avoid reassigning the whole attribute when the data in the weights are replaced.
+
+This is why the `max_rank` argument is important. The results are unchanged even when the values are padded with zeros. Computation may be slower though depending on the padding size.
+
+Since no new LoRA attributes are added, each subsequent LoRA is only allowed to target the same layers, or subset of layers, the first LoRA targets. Choosing the LoRA loading order is important because if the LoRAs target disjoint layers, you may end up creating a dummy LoRA that targets the union of all target layers.
+
+For more implementation details, take a look at the [`hotswap.py`](https://github.com/huggingface/peft/blob/92d65cafa51c829484ad3d95cf71d09de57ff066/src/peft/utils/hotswap.py) file.
+
+
+
+## Merge
+
+The weights from each LoRA can be merged together to produce a blend of multiple existing styles. There are several methods for merging LoRAs, each of which differ in *how* the weights are merged (may affect generation quality).
+
+### set_adapters
+
+The [`~loaders.PeftAdapterMixin.set_adapters`] method merges LoRAs by concatenating their weighted matrices. Pass the LoRA names to [`~loaders.PeftAdapterMixin.set_adapters`] and use the `adapter_weights` parameter to control the scaling of each LoRA. For example, if `adapter_weights=[0.5, 0.5]`, the output is an average of both LoRAs.
+
+> [!TIP]
+> The `"scale"` parameter determines how much of the merged LoRA to apply. See the [Weight scale](#weight-scale) section for more details.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+pipeline.load_lora_weights(
+ "ostris/ikea-instructions-lora-sdxl",
+ weight_name="ikea_instructions_xl_v1_5.safetensors",
+ adapter_name="ikea"
+)
+pipeline.load_lora_weights(
+ "lordjia/by-feng-zikai",
+ weight_name="fengzikai_v1.0_XL.safetensors",
+ adapter_name="feng"
+)
+pipeline.set_adapters(["ikea", "feng"], adapter_weights=[0.7, 0.8])
+# use by Feng Zikai to activate the lordjia/by-feng-zikai LoRA
+pipeline("A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai", cross_attention_kwargs={"scale": 1.0}).images[0]
```
-
+
+
+
-Impressive! As you can see, the model generated an image that mixed the characteristics of both adapters.
+### add_weighted_adapter
> [!TIP]
-> Through its PEFT integration, Diffusers also offers more efficient merging methods which you can learn about in the [Merge LoRAs](../using-diffusers/merge_loras) guide!
+> This is an experimental method and you can refer to PEFTs [Model merging](https://huggingface.co/docs/peft/developer_guides/model_merging) for more details. Take a look at this [issue](https://github.com/huggingface/diffusers/issues/6892) if you're interested in the motivation and design behind this integration.
-To return to only using one adapter, use the [`~loaders.peft.PeftAdapterMixin.set_adapters`] method to activate the `"toy"` adapter:
+The [`~peft.LoraModel.add_weighted_adapter`] method enables more efficient merging methods like [TIES](https://huggingface.co/papers/2306.01708) or [DARE](https://huggingface.co/papers/2311.03099). These merging methods remove redundant and potentially interfering parameters from merged models. Keep in mind the LoRA ranks need to have identical ranks to be merged.
-```python
-pipe.set_adapters("toy")
+Make sure the latest stable version of Diffusers and PEFT is installed.
-prompt = "toy_face of a hacker with a hoodie"
-lora_scale = 0.9
-image = pipe(
- prompt, num_inference_steps=30, cross_attention_kwargs={"scale": lora_scale}, generator=torch.manual_seed(0)
-).images[0]
-image
+```bash
+pip install -U -q diffusers peft
+```
+
+Load a UNET that corresponds to the LoRA UNet.
+
+```py
+import copy
+import torch
+from diffusers import AutoModel, DiffusionPipeline
+from peft import get_peft_model, LoraConfig, PeftModel
+
+unet = AutoModel.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+ variant="fp16",
+ subfolder="unet",
+).to("cuda")
```
-Or to disable all adapters entirely, use the [`~loaders.peft.PeftAdapterMixin.disable_lora`] method to return the base model.
+Load a pipeline, pass the UNet to it, and load a LoRA.
+
+```py
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ variant="fp16",
+ torch_dtype=torch.float16,
+ unet=unet
+).to("cuda")
+pipeline.load_lora_weights(
+ "ostris/ikea-instructions-lora-sdxl",
+ weight_name="ikea_instructions_xl_v1_5.safetensors",
+ adapter_name="ikea"
+)
+```
-```python
-pipe.disable_lora()
+Create a [`~peft.PeftModel`] from the LoRA checkpoint by combining the first UNet you loaded and the LoRA UNet from the pipeline.
-prompt = "toy_face of a hacker with a hoodie"
-image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
-image
+```py
+sdxl_unet = copy.deepcopy(unet)
+ikea_peft_model = get_peft_model(
+ sdxl_unet,
+ pipeline.unet.peft_config["ikea"],
+ adapter_name="ikea"
+)
+
+original_state_dict = {f"base_model.model.{k}": v for k, v in pipeline.unet.state_dict().items()}
+ikea_peft_model.load_state_dict(original_state_dict, strict=True)
```
-
+> [!TIP]
+> You can save and reuse the `ikea_peft_model` by pushing it to the Hub as shown below.
+> ```py
+> ikea_peft_model.push_to_hub("ikea_peft_model", token=TOKEN)
+> ```
-### Customize adapters strength
+Repeat this process and create a [`~peft.PeftModel`] for the second LoRA.
-For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~loaders.peft.PeftAdapterMixin.set_adapters`].
+```py
+pipeline.delete_adapters("ikea")
+sdxl_unet.delete_adapters("ikea")
+
+pipeline.load_lora_weights(
+ "lordjia/by-feng-zikai",
+ weight_name="fengzikai_v1.0_XL.safetensors",
+ adapter_name="feng"
+)
+pipeline.set_adapters(adapter_names="feng")
+
+feng_peft_model = get_peft_model(
+ sdxl_unet,
+ pipeline.unet.peft_config["feng"],
+ adapter_name="feng"
+)
+
+original_state_dict = {f"base_model.model.{k}": v for k, v in pipe.unet.state_dict().items()}
+feng_peft_model.load_state_dict(original_state_dict, strict=True)
+```
+
+Load a base UNet model and load the adapters.
-For example, here's how you can turn on the adapter for the `down` parts, but turn it off for the `mid` and `up` parts:
-```python
-pipe.enable_lora() # enable lora again, after we disabled it above
-prompt = "toy_face of a hacker with a hoodie, pixel art"
-adapter_weight_scales = { "unet": { "down": 1, "mid": 0, "up": 0} }
-pipe.set_adapters("pixel", adapter_weight_scales)
-image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
-image
+```py
+base_unet = AutoModel.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+ variant="fp16",
+ subfolder="unet",
+).to("cuda")
+
+model = PeftModel.from_pretrained(
+ base_unet,
+ "stevhliu/ikea_peft_model",
+ use_safetensors=True,
+ subfolder="ikea",
+ adapter_name="ikea"
+)
+model.load_adapter(
+ "stevhliu/feng_peft_model",
+ use_safetensors=True,
+ subfolder="feng",
+ adapter_name="feng"
+)
```
-
+Merge the LoRAs with [`~peft.LoraModel.add_weighted_adapter`] and specify how you want to merge them with `combination_type`. The example below uses the `"dare_linear"` method (refer to this [blog post](https://huggingface.co/blog/peft_merging) to learn more about these merging methods), which randomly prunes some weights and then performs a weighted sum of the tensors based on the set weightage of each LoRA in `weights`.
+
+Activate the merged LoRAs with [`~loaders.PeftAdapterMixin.set_adapters`].
+
+```py
+model.add_weighted_adapter(
+ adapters=["ikea", "feng"],
+ combination_type="dare_linear",
+ weights=[1.0, 1.0],
+ adapter_name="ikea-feng"
+)
+model.set_adapters("ikea-feng")
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ unet=model,
+ variant="fp16",
+ torch_dtype=torch.float16,
+).to("cuda")
+pipeline("A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai").images[0]
+```
+
+
+
+
+
+### fuse_lora
-Let's see how turning off the `down` part and turning on the `mid` and `up` part respectively changes the image.
-```python
-adapter_weight_scales = { "unet": { "down": 0, "mid": 1, "up": 0} }
-pipe.set_adapters("pixel", adapter_weight_scales)
-image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
-image
+The [`~loaders.lora_base.LoraBaseMixin.fuse_lora`] method fuses the LoRA weights directly with the original UNet and text encoder weights of the underlying model. This reduces the overhead of loading the underlying model for each LoRA because it only loads the model once, which lowers memory usage and increases inference speed.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+pipeline.load_lora_weights(
+ "ostris/ikea-instructions-lora-sdxl",
+ weight_name="ikea_instructions_xl_v1_5.safetensors",
+ adapter_name="ikea"
+)
+pipeline.load_lora_weights(
+ "lordjia/by-feng-zikai",
+ weight_name="fengzikai_v1.0_XL.safetensors",
+ adapter_name="feng"
+)
+pipeline.set_adapters(["ikea", "feng"], adapter_weights=[0.7, 0.8])
```
-
+Call [`~loaders.lora_base.LoraBaseMixin.fuse_lora`] to fuse them. The `lora_scale` parameter controls how much to scale the output by with the LoRA weights. It is important to make this adjustment now because passing `scale` to `cross_attention_kwargs` won't work in the pipeline.
-```python
-adapter_weight_scales = { "unet": { "down": 0, "mid": 0, "up": 1} }
-pipe.set_adapters("pixel", adapter_weight_scales)
-image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
-image
+```py
+pipeline.fuse_lora(adapter_names=["ikea", "feng"], lora_scale=1.0)
```
-
+Unload the LoRA weights since they're already fused with the underlying model. Save the fused pipeline with either [`~DiffusionPipeline.save_pretrained`] to save it locally or [`~PushToHubMixin.push_to_hub`] to save it to the Hub.
-Looks cool!
+
+
-This is a really powerful feature. You can use it to control the adapter strengths down to per-transformer level. And you can even use it for multiple adapters.
-```python
-adapter_weight_scales_toy = 0.5
-adapter_weight_scales_pixel = {
- "unet": {
- "down": 0.9, # all transformers in the down-part will use scale 0.9
- # "mid" # because, in this example, "mid" is not given, all transformers in the mid part will use the default scale 1.0
- "up": {
- "block_0": 0.6, # all 3 transformers in the 0th block in the up-part will use scale 0.6
- "block_1": [0.4, 0.8, 1.0], # the 3 transformers in the 1st block in the up-part will use scales 0.4, 0.8 and 1.0 respectively
- }
- }
-}
-pipe.set_adapters(["toy", "pixel"], [adapter_weight_scales_toy, adapter_weight_scales_pixel])
-image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
-image
+```py
+pipeline.unload_lora_weights()
+pipeline.save_pretrained("path/to/fused-pipeline")
```
-
+
+
+
+```py
+pipeline.unload_lora_weights()
+pipeline.push_to_hub("fused-ikea-feng")
+```
-## Manage adapters
+
+
-You have attached multiple adapters in this tutorial, and if you're feeling a bit lost on what adapters have been attached to the pipeline's components, use the [`~diffusers.loaders.StableDiffusionLoraLoaderMixin.get_active_adapters`] method to check the list of active adapters:
+The fused pipeline can now be quickly loaded for inference without requiring each LoRA to be separately loaded.
```py
-active_adapters = pipe.get_active_adapters()
-active_adapters
-["toy", "pixel"]
+pipeline = DiffusionPipeline.from_pretrained(
+ "username/fused-ikea-feng", torch_dtype=torch.float16,
+).to("cuda")
+pipeline("A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai").images[0]
```
-You can also get the active adapters of each pipeline component with [`~diffusers.loaders.StableDiffusionLoraLoaderMixin.get_list_adapters`]:
+Use [`~loaders.LoraLoaderMixin.unfuse_lora`] to restore the underlying models weights, for example, if you want to use a different `lora_scale` value. You can only unfuse if there is a single LoRA fused. For example, it won't work with the pipeline from above because there are multiple fused LoRAs. In these cases, you'll need to reload the entire model.
```py
-list_adapters_component_wise = pipe.get_list_adapters()
-list_adapters_component_wise
-{"text_encoder": ["toy", "pixel"], "unet": ["toy", "pixel"], "text_encoder_2": ["toy", "pixel"]}
+pipeline.unfuse_lora()
```
-The [`~loaders.peft.PeftAdapterMixin.delete_adapters`] function completely removes an adapter and their LoRA layers from a model.
+
+
+
+
+## Manage
+
+Diffusers provides several methods to help you manage working with LoRAs. These methods can be especially useful if you're working with multiple LoRAs.
+
+### set_adapters
+
+[`~loaders.PeftAdapterMixin.set_adapters`] also activates the current LoRA to use if there are multiple active LoRAs. This allows you to switch between different LoRAs by specifying their name.
```py
-pipe.delete_adapters("toy")
-pipe.get_active_adapters()
-["pixel"]
+import torch
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+pipeline.load_lora_weights(
+ "ostris/ikea-instructions-lora-sdxl",
+ weight_name="ikea_instructions_xl_v1_5.safetensors",
+ adapter_name="ikea"
+)
+pipeline.load_lora_weights(
+ "lordjia/by-feng-zikai",
+ weight_name="fengzikai_v1.0_XL.safetensors",
+ adapter_name="feng"
+)
+# activates the feng LoRA instead of the ikea LoRA
+pipeline.set_adapters("feng")
+```
+
+### save_lora_adapter
+
+Save an adapter with [`~loaders.PeftAdapterMixin.save_lora_adapter`].
+
+```py
+import torch
+from diffusers import AutoPipelineForText2Image
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+pipeline.unet.load_lora_adapter(
+ "jbilcke-hf/sdxl-cinematic-1",
+ weight_name="pytorch_lora_weights.safetensors",
+ adapter_name="cinematic"
+ prefix="unet"
+)
+pipeline.save_lora_adapter("path/to/save", adapter_name="cinematic")
+```
+
+### unload_lora_weights
+
+The [`~loaders.lora_base.LoraBaseMixin.unload_lora_weights`] method unloads any LoRA weights in the pipeline to restore the underlying model weights.
+
+```py
+pipeline.unload_lora_weights()
+```
+
+### disable_lora
+
+The [`~loaders.PeftAdapterMixin.disable_lora`] method disables all LoRAs (but they're still kept on the pipeline) and restores the pipeline to the underlying model weights.
+
+```py
+pipeline.disable_lora()
```
-## PeftInputAutocastDisableHook
+### get_active_adapters
+
+The [`~loaders.lora_base.LoraBaseMixin.get_active_adapters`] method returns a list of active LoRAs attached to a pipeline.
+
+```py
+pipeline.get_active_adapters()
+["cereal", "ikea"]
+```
+
+### get_list_adapters
+
+The [`~loaders.lora_base.LoraBaseMixin.get_list_adapters`] method returns the active LoRAs for each component in the pipeline.
+
+```py
+pipeline.get_list_adapters()
+{"unet": ["cereal", "ikea"], "text_encoder_2": ["cereal"]}
+```
+
+### delete_adapters
+
+The [`~loaders.PeftAdapterMixin.delete_adapters`] method completely removes a LoRA and its layers from a model.
+
+```py
+pipeline.delete_adapters("ikea")
+```
+
+## Resources
+
+Browse the [LoRA Studio](https://lorastudio.co/models) for different LoRAs to use or you can upload your favorite LoRAs from Civitai to the Hub with the Space below.
+
+
+
+You can find additional LoRAs in the [FLUX LoRA the Explorer](https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer) and [LoRA the Explorer](https://huggingface.co/spaces/multimodalart/LoraTheExplorer) Spaces.
-[[autodoc]] hooks.layerwise_casting.PeftInputAutocastDisableHook
+Check out the [Fast LoRA inference for Flux with Diffusers and PEFT](https://huggingface.co/blog/lora-fast) blog post to learn how to optimize LoRA inference with methods like FlashAttention-3 and fp8 quantization.
diff --git a/docs/source/en/using-diffusers/automodel.md b/docs/source/en/using-diffusers/automodel.md
new file mode 100644
index 000000000000..957cbd17e3f7
--- /dev/null
+++ b/docs/source/en/using-diffusers/automodel.md
@@ -0,0 +1,46 @@
+
+
+# AutoModel
+
+The [`AutoModel`] class automatically detects and loads the correct model class (UNet, transformer, VAE) from a `config.json` file. You don't need to know the specific model class name ahead of time. It supports data types and device placement, and works across model types and libraries.
+
+The example below loads a transformer from Diffusers and a text encoder from Transformers. Use the `subfolder` parameter to specify where to load the `config.json` file from.
+
+```py
+import torch
+from diffusers import AutoModel, DiffusionPipeline
+
+transformer = AutoModel.from_pretrained(
+ "Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, device_map="cuda"
+)
+
+text_encoder = AutoModel.from_pretrained(
+ "Qwen/Qwen-Image", subfolder="text_encoder", torch_dtype=torch.bfloat16, device_map="cuda"
+)
+```
+
+[`AutoModel`] also loads models from the [Hub](https://huggingface.co/models) that aren't included in Diffusers. Set `trust_remote_code=True` in [`AutoModel.from_pretrained`] to load custom models.
+
+```py
+import torch
+from diffusers import AutoModel
+
+transformer = AutoModel.from_pretrained(
+ "custom/custom-transformer-model", trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="cuda"
+)
+```
+
+If the custom model inherits from the [`ModelMixin`] class, it gets access to the same features as Diffusers model classes, like [regional compilation](../optimization/fp16#regional-compilation) and [group offloading](../optimization/memory#group-offloading).
+
+> [!NOTE]
+> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/batched_inference.md b/docs/source/en/using-diffusers/batched_inference.md
new file mode 100644
index 000000000000..cdb16ac1212b
--- /dev/null
+++ b/docs/source/en/using-diffusers/batched_inference.md
@@ -0,0 +1,173 @@
+
+
+# Batch inference
+
+Batch inference processes multiple prompts at a time to increase throughput. It is more efficient because processing multiple prompts at once maximizes GPU usage versus processing a single prompt and underutilizing the GPU.
+
+The downside is increased latency because you must wait for the entire batch to complete, and more GPU memory is required for large batches.
+
+For text-to-image, pass a list of prompts to the pipeline and for image-to-image, pass a list of images and prompts to the pipeline. The example below demonstrates batched text-to-image inference.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16,
+ device_map="cuda"
+)
+
+prompts = [
+ "Cinematic shot of a cozy coffee shop interior, warm pastel light streaming through a window where a cat rests. Shallow depth of field, glowing cups in soft focus, dreamy lofi-inspired mood, nostalgic tones, framed like a quiet film scene.",
+ "Polaroid-style photograph of a cozy coffee shop interior, bathed in warm pastel light. A cat sits on the windowsill near steaming mugs. Soft, slightly faded tones and dreamy blur evoke nostalgia, a lofi mood, and the intimate, imperfect charm of instant film.",
+ "Soft watercolor illustration of a cozy coffee shop interior, pastel washes of color filling the space. A cat rests peacefully on the windowsill as warm light glows through. Gentle brushstrokes create a dreamy, lofi-inspired atmosphere with whimsical textures and nostalgic calm.",
+ "Isometric pixel-art illustration of a cozy coffee shop interior in detailed 8-bit style. Warm pastel light fills the space as a cat rests on the windowsill. Blocky furniture and tiny mugs add charm, low-res retro graphics enhance the nostalgic, lofi-inspired game aesthetic."
+]
+
+images = pipeline(
+ prompt=prompts,
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+
+
+
+
+To generate multiple variations of one prompt, use the `num_images_per_prompt` argument.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16,
+ device_map="cuda"
+)
+
+prompt="""
+Isometric pixel-art illustration of a cozy coffee shop interior in detailed 8-bit style. Warm pastel light fills the
+space as a cat rests on the windowsill. Blocky furniture and tiny mugs add charm, low-res retro graphics enhance the
+nostalgic, lofi-inspired game aesthetic.
+"""
+
+images = pipeline(
+ prompt=prompt,
+ num_images_per_prompt=4
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+
+
+
+
+Combine both approaches to generate different variations of different prompts.
+
+```py
+images = pipeline(
+ prompt=prompts,
+ num_images_per_prompt=2,
+).images
+
+fig, axes = plt.subplots(2, 4, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+
+
+
+
+## Deterministic generation
+
+Enable reproducible batch generation by passing a list of [Generator’s](https://pytorch.org/docs/stable/generated/torch.Generator.html) to the pipeline and tie each `Generator` to a seed to reuse it.
+
+> [!TIP]
+> Refer to the [Reproducibility](./reusing_seeds) docs to learn more about deterministic algorithms and the `Generator` object.
+
+Use a list comprehension to iterate over the batch size specified in `range()` to create a unique `Generator` object for each image in the batch. Don't multiply the `Generator` by the batch size because that only creates one `Generator` object that is used sequentially for each image in the batch.
+
+```py
+generator = [torch.Generator(device="cuda").manual_seed(0)] * 3
+```
+
+Pass the `generator` to the pipeline.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16,
+ device_map="cuda"
+)
+
+generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(3)]
+prompts = [
+ "Cinematic shot of a cozy coffee shop interior, warm pastel light streaming through a window where a cat rests. Shallow depth of field, glowing cups in soft focus, dreamy lofi-inspired mood, nostalgic tones, framed like a quiet film scene.",
+ "Polaroid-style photograph of a cozy coffee shop interior, bathed in warm pastel light. A cat sits on the windowsill near steaming mugs. Soft, slightly faded tones and dreamy blur evoke nostalgia, a lofi mood, and the intimate, imperfect charm of instant film.",
+ "Soft watercolor illustration of a cozy coffee shop interior, pastel washes of color filling the space. A cat rests peacefully on the windowsill as warm light glows through. Gentle brushstrokes create a dreamy, lofi-inspired atmosphere with whimsical textures and nostalgic calm.",
+ "Isometric pixel-art illustration of a cozy coffee shop interior in detailed 8-bit style. Warm pastel light fills the space as a cat rests on the windowsill. Blocky furniture and tiny mugs add charm, low-res retro graphics enhance the nostalgic, lofi-inspired game aesthetic."
+]
+
+images = pipeline(
+ prompt=prompts,
+ generator=generator
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+You can use this to select an image associated with a seed and iteratively improve on it by crafting a more detailed prompt.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/callback.md b/docs/source/en/using-diffusers/callback.md
index 2462fed1a3cf..60b839805ff2 100644
--- a/docs/source/en/using-diffusers/callback.md
+++ b/docs/source/en/using-diffusers/callback.md
@@ -1,4 +1,4 @@
-
-# CogVideoX
-
-CogVideoX is a text-to-video generation model focused on creating more coherent videos aligned with a prompt. It achieves this using several methods.
-
-- a 3D variational autoencoder that compresses videos spatially and temporally, improving compression rate and video accuracy.
-
-- an expert transformer block to help align text and video, and a 3D full attention module for capturing and creating spatially and temporally accurate videos.
-
-
-
-## Load model checkpoints
-Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method.
-
-
-```py
-from diffusers import CogVideoXPipeline, CogVideoXImageToVideoPipeline
-pipe = CogVideoXPipeline.from_pretrained(
- "THUDM/CogVideoX-2b",
- torch_dtype=torch.float16
-)
-
-pipe = CogVideoXImageToVideoPipeline.from_pretrained(
- "THUDM/CogVideoX-5b-I2V",
- torch_dtype=torch.bfloat16
-)
-
-```
-
-## Text-to-Video
-For text-to-video, pass a text prompt. By default, CogVideoX generates a 720x480 video for the best results.
-
-```py
-import torch
-from diffusers import CogVideoXPipeline
-from diffusers.utils import export_to_video
-
-prompt = "An elderly gentleman, with a serene expression, sits at the water's edge, a steaming cup of tea by his side. He is engrossed in his artwork, brush in hand, as he renders an oil painting on a canvas that's propped up against a small, weathered table. The sea breeze whispers through his silver hair, gently billowing his loose-fitting white shirt, while the salty air adds an intangible element to his masterpiece in progress. The scene is one of tranquility and inspiration, with the artist's canvas capturing the vibrant hues of the setting sun reflecting off the tranquil sea."
-
-pipe = CogVideoXPipeline.from_pretrained(
- "THUDM/CogVideoX-5b",
- torch_dtype=torch.bfloat16
-)
-
-pipe.enable_model_cpu_offload()
-pipe.vae.enable_tiling()
-
-video = pipe(
- prompt=prompt,
- num_videos_per_prompt=1,
- num_inference_steps=50,
- num_frames=49,
- guidance_scale=6,
- generator=torch.Generator(device="cuda").manual_seed(42),
-).frames[0]
-
-export_to_video(video, "output.mp4", fps=8)
-
-```
-
-
-
-
-
-
-
-## Image-to-Video
-
-
-You'll use the [THUDM/CogVideoX-5b-I2V](https://huggingface.co/THUDM/CogVideoX-5b-I2V) checkpoint for this guide.
-
-```py
-import torch
-from diffusers import CogVideoXImageToVideoPipeline
-from diffusers.utils import export_to_video, load_image
-
-prompt = "A vast, shimmering ocean flows gracefully under a twilight sky, its waves undulating in a mesmerizing dance of blues and greens. The surface glints with the last rays of the setting sun, casting golden highlights that ripple across the water. Seagulls soar above, their cries blending with the gentle roar of the waves. The horizon stretches infinitely, where the ocean meets the sky in a seamless blend of hues. Close-ups reveal the intricate patterns of the waves, capturing the fluidity and dynamic beauty of the sea in motion."
-image = load_image(image="cogvideox_rocket.png")
-pipe = CogVideoXImageToVideoPipeline.from_pretrained(
- "THUDM/CogVideoX-5b-I2V",
- torch_dtype=torch.bfloat16
-)
-
-pipe.vae.enable_tiling()
-pipe.vae.enable_slicing()
-
-video = pipe(
- prompt=prompt,
- image=image,
- num_videos_per_prompt=1,
- num_inference_steps=50,
- num_frames=49,
- guidance_scale=6,
- generator=torch.Generator(device="cuda").manual_seed(42),
-).frames[0]
-
-export_to_video(video, "output.mp4", fps=8)
-```
-
-
-
-
-
initial image
-
-
-
-
generated video
-
-
-
diff --git a/docs/source/en/using-diffusers/conditional_image_generation.md b/docs/source/en/using-diffusers/conditional_image_generation.md
index b58b3b74b91a..eb75b6b8a8b1 100644
--- a/docs/source/en/using-diffusers/conditional_image_generation.md
+++ b/docs/source/en/using-diffusers/conditional_image_generation.md
@@ -1,4 +1,4 @@
-
-# Load community pipelines and components
-
[[open-in-colab]]
-## Community pipelines
-
-> [!TIP] Take a look at GitHub Issue [#841](https://github.com/huggingface/diffusers/issues/841) for more context about why we're adding community pipelines to help everyone easily share their work without being slowed down.
-
-Community pipelines are any [`DiffusionPipeline`] class that are different from the original paper implementation (for example, the [`StableDiffusionControlNetPipeline`] corresponds to the [Text-to-Image Generation with ControlNet Conditioning](https://arxiv.org/abs/2302.05543) paper). They provide additional functionality or extend the original implementation of a pipeline.
-
-There are many cool community pipelines like [Marigold Depth Estimation](https://github.com/huggingface/diffusers/tree/main/examples/community#marigold-depth-estimation) or [InstantID](https://github.com/huggingface/diffusers/tree/main/examples/community#instantid-pipeline), and you can find all the official community pipelines [here](https://github.com/huggingface/diffusers/tree/main/examples/community).
-
-There are two types of community pipelines, those stored on the Hugging Face Hub and those stored on Diffusers GitHub repository. Hub pipelines are completely customizable (scheduler, models, pipeline code, etc.) while Diffusers GitHub pipelines are only limited to custom pipeline code.
-
-| | GitHub community pipeline | HF Hub community pipeline |
-|----------------|------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------|
-| usage | same | same |
-| review process | open a Pull Request on GitHub and undergo a review process from the Diffusers team before merging; may be slower | upload directly to a Hub repository without any review; this is the fastest workflow |
-| visibility | included in the official Diffusers repository and documentation | included on your HF Hub profile and relies on your own usage/promotion to gain visibility |
-
-
-
-### Load from a local file
-
-Community pipelines can also be loaded from a local file if you pass a file path instead. The path to the passed directory must contain a pipeline.py file that contains the pipeline class.
-
-```py
-pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5",
- custom_pipeline="./path/to/pipeline_directory/",
- clip_model=clip_model,
- feature_extractor=feature_extractor,
- use_safetensors=True,
-)
-```
-
-### Load from a specific version
-
-By default, community pipelines are loaded from the latest stable version of Diffusers. To load a community pipeline from another version, use the `custom_revision` parameter.
-
-
-
-
-For example, to load from the main branch:
-
-```py
-pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5",
- custom_pipeline="clip_guided_stable_diffusion",
- custom_revision="main",
- clip_model=clip_model,
- feature_extractor=feature_extractor,
- use_safetensors=True,
-)
-```
-
-
-
-
-For example, to load from a previous version of Diffusers like v0.25.0:
-```py
pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5",
- custom_pipeline="clip_guided_stable_diffusion",
- custom_revision="v0.25.0",
- clip_model=clip_model,
- feature_extractor=feature_extractor,
- use_safetensors=True,
+ "stabilityai/stable-diffusion-3-medium-diffusers",
+ custom_pipeline="pipeline_stable_diffusion_3_instruct_pix2pix",
+ torch_dtype=torch.float16,
+ device_map="cuda"
)
```
-
-
-
-### Load with from_pipe
-
-Community pipelines can also be loaded with the [`~DiffusionPipeline.from_pipe`] method which allows you to load and reuse multiple pipelines without any additional memory overhead (learn more in the [Reuse a pipeline](./loading#reuse-a-pipeline) guide). The memory requirement is determined by the largest single pipeline loaded.
-
-For example, let's load a community pipeline that supports [long prompts with weighting](https://github.com/huggingface/diffusers/tree/main/examples/community#long-prompt-weighting-stable-diffusion) from a Stable Diffusion pipeline.
+Add the `custom_revision` argument to [`~DiffusionPipeline.from_pretrained`] to load a community pipeline from a specific version (for example, `v0.30.0` or `main`). By default, community pipelines are loaded from the latest stable version of Diffusers.
```py
import torch
from diffusers import DiffusionPipeline
-pipe_sd = DiffusionPipeline.from_pretrained("emilianJR/CyberRealistic_V3", torch_dtype=torch.float16)
-pipe_sd.to("cuda")
-# load long prompt weighting pipeline
-pipe_lpw = DiffusionPipeline.from_pipe(
- pipe_sd,
- custom_pipeline="lpw_stable_diffusion",
-).to("cuda")
-
-prompt = "cat, hiding in the leaves, ((rain)), zazie rainyday, beautiful eyes, macro shot, colorful details, natural lighting, amazing composition, subsurface scattering, amazing textures, filmic, soft light, ultra-detailed eyes, intricate details, detailed texture, light source contrast, dramatic shadows, cinematic light, depth of field, film grain, noise, dark background, hyperrealistic dslr film still, dim volumetric cinematic lighting"
-neg_prompt = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers:1.4), (deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
-generator = torch.Generator(device="cpu").manual_seed(20)
-out_lpw = pipe_lpw(
- prompt,
- negative_prompt=neg_prompt,
- width=512,
- height=512,
- max_embeddings_multiples=3,
- num_inference_steps=50,
- generator=generator,
- ).images[0]
-out_lpw
-```
-
-
-
-
-
Stable Diffusion with long prompt weighting
-
-
-
-
Stable Diffusion
-
-
-
-## Example community pipelines
-
-Community pipelines are a really fun and creative way to extend the capabilities of the original pipeline with new and unique features. You can find all community pipelines in the [diffusers/examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) folder with inference and training examples for how to use them.
-
-This section showcases a couple of the community pipelines and hopefully it'll inspire you to create your own (feel free to open a PR for your community pipeline and ping us for a review)!
-
-> [!TIP]
-> The [`~DiffusionPipeline.from_pipe`] method is particularly useful for loading community pipelines because many of them don't have pretrained weights and add a feature on top of an existing pipeline like Stable Diffusion or Stable Diffusion XL. You can learn more about the [`~DiffusionPipeline.from_pipe`] method in the [Load with from_pipe](custom_pipeline_overview#load-with-from_pipe) section.
-
-
+ ```py
+ import torch
+ from diffusers import DiffusionPipeline
-## Community components
+ pipeline_sd = DiffusionPipeline.from_pretrained("emilianJR/CyberRealistic_V3", torch_dtype=torch.float16, device_map="cuda")
+ pipeline_lpw = DiffusionPipeline.from_pipe(
+ pipeline_sd, custom_pipeline="lpw_stable_diffusion", device_map="cuda"
+ )
+ ```
-Community components allow users to build pipelines that may have customized components that are not a part of Diffusers. If your pipeline has custom components that Diffusers doesn't already support, you need to provide their implementations as Python modules. These customized components could be a VAE, UNet, and scheduler. In most cases, the text encoder is imported from the Transformers library. The pipeline code itself can also be customized.
+ The [`~DiffusionPipeline.from_pipe`] method is especially useful for loading community pipelines because many of them don't have pretrained weights. Community pipelines generally add a feature on top of an existing pipeline.
-This section shows how users should use community components to build a community pipeline.
-
-You'll use the [showlab/show-1-base](https://huggingface.co/showlab/show-1-base) pipeline checkpoint as an example.
-
-1. Import and load the text encoder from Transformers:
+## Community components
-```python
-from transformers import T5Tokenizer, T5EncoderModel
+Community components let users build pipelines with custom transformers, UNets, VAEs, and schedulers not supported by Diffusers. These components require Python module implementations.
-pipe_id = "showlab/show-1-base"
-tokenizer = T5Tokenizer.from_pretrained(pipe_id, subfolder="tokenizer")
-text_encoder = T5EncoderModel.from_pretrained(pipe_id, subfolder="text_encoder")
-```
+This section shows how users can use community components to build a community pipeline using [showlab/show-1-base](https://huggingface.co/showlab/show-1-base) as an example.
-2. Load a scheduler:
+1. Load the required components, the scheduler and image processor. The text encoder is generally imported from [Transformers](https://huggingface.co/docs/transformers/index).
```python
+from transformers import T5Tokenizer, T5EncoderModel, CLIPImageProcessor
from diffusers import DPMSolverMultistepScheduler
+pipeline_id = "showlab/show-1-base"
+tokenizer = T5Tokenizer.from_pretrained(pipeline_id, subfolder="tokenizer")
+text_encoder = T5EncoderModel.from_pretrained(pipeline_id, subfolder="text_encoder")
scheduler = DPMSolverMultistepScheduler.from_pretrained(pipe_id, subfolder="scheduler")
-```
-
-3. Load an image processor:
-
-```python
-from transformers import CLIPImageProcessor
-
feature_extractor = CLIPImageProcessor.from_pretrained(pipe_id, subfolder="feature_extractor")
```
-
-
-In steps 4 and 5, the custom [UNet](https://github.com/showlab/Show-1/blob/main/showone/models/unet_3d_condition.py) and [pipeline](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py) implementation must match the format shown in their files for this example to work.
-
-
-
-4. Now you'll load a [custom UNet](https://github.com/showlab/Show-1/blob/main/showone/models/unet_3d_condition.py), which in this example, has already been implemented in [showone_unet_3d_condition.py](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py) for your convenience. You'll notice the [`UNet3DConditionModel`] class name is changed to `ShowOneUNet3DConditionModel` because [`UNet3DConditionModel`] already exists in Diffusers. Any components needed for the `ShowOneUNet3DConditionModel` class should be placed in showone_unet_3d_condition.py.
+> [!WARNING]
+> In steps 2 and 3, the custom [UNet](https://github.com/showlab/Show-1/blob/main/showone/models/unet_3d_condition.py) and [pipeline](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py) implementation must match the format shown in their files for this example to work.
- Once this is done, you can initialize the UNet:
+2. Load a [custom UNet](https://github.com/showlab/Show-1/blob/main/showone/models/unet_3d_condition.py) which is already implemented in [showone_unet_3d_condition.py](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py). The [`UNet3DConditionModel`] class name is renamed to the custom implementation, `ShowOneUNet3DConditionModel`, because [`UNet3DConditionModel`] already exists in Diffusers. Any components required for `ShowOneUNet3DConditionModel` class should be placed in `showone_unet_3d_condition.py`.
- ```python
- from showone_unet_3d_condition import ShowOneUNet3DConditionModel
+```python
+from showone_unet_3d_condition import ShowOneUNet3DConditionModel
- unet = ShowOneUNet3DConditionModel.from_pretrained(pipe_id, subfolder="unet")
- ```
+unet = ShowOneUNet3DConditionModel.from_pretrained(pipeline_id, subfolder="unet")
+```
-5. Finally, you'll load the custom pipeline code. For this example, it has already been created for you in [pipeline_t2v_base_pixel.py](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/pipeline_t2v_base_pixel.py). This script contains a custom `TextToVideoIFPipeline` class for generating videos from text. Just like the custom UNet, any code needed for the custom pipeline to work should go in pipeline_t2v_base_pixel.py.
+3. Load the custom pipeline code (already implemented in [pipeline_t2v_base_pixel.py](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/pipeline_t2v_base_pixel.py)). This script contains a custom `TextToVideoIFPipeline` class for generating videos from text. Like the custom UNet, any code required for `TextToVideIFPipeline` should be placed in `pipeline_t2v_base_pixel.py`.
-Once everything is in place, you can initialize the `TextToVideoIFPipeline` with the `ShowOneUNet3DConditionModel`:
+Initialize `TextToVideoIFPipeline` with `ShowOneUNet3DConditionModel`.
```python
-from pipeline_t2v_base_pixel import TextToVideoIFPipeline
import torch
+from pipeline_t2v_base_pixel import TextToVideoIFPipeline
pipeline = TextToVideoIFPipeline(
unet=unet,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
- feature_extractor=feature_extractor
+ feature_extractor=feature_extractor,
+ device_map="cuda",
+ torch_dtype=torch.float16
)
-pipeline = pipeline.to(device="cuda")
-pipeline.torch_dtype = torch.float16
```
-Push the pipeline to the Hub to share with the community!
+4. Push the pipeline to the Hub to share with the community.
```python
pipeline.push_to_hub("custom-t2v-pipeline")
```
-After the pipeline is successfully pushed, you need to make a few changes:
+After the pipeline is successfully pushed, make the following changes.
-1. Change the `_class_name` attribute in [model_index.json](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/model_index.json#L2) to `"pipeline_t2v_base_pixel"` and `"TextToVideoIFPipeline"`.
-2. Upload `showone_unet_3d_condition.py` to the [unet](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py) subfolder.
-3. Upload `pipeline_t2v_base_pixel.py` to the pipeline [repository](https://huggingface.co/sayakpaul/show-1-base-with-code/tree/main).
+- Change the `_class_name` attribute in [model_index.json](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/model_index.json#L2) to `"pipeline_t2v_base_pixel"` and `"TextToVideoIFPipeline"`.
+- Upload `showone_unet_3d_condition.py` to the [unet](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py) subfolder.
+- Upload `pipeline_t2v_base_pixel.py` to the pipeline [repository](https://huggingface.co/sayakpaul/show-1-base-with-code/tree/main).
To run inference, add the `trust_remote_code` argument while initializing the pipeline to handle all the "magic" behind the scenes.
-> [!WARNING]
-> As an additional precaution with `trust_remote_code=True`, we strongly encourage you to pass a commit hash to the `revision` parameter in [`~DiffusionPipeline.from_pretrained`] to make sure the code hasn't been updated with some malicious new lines of code (unless you fully trust the model owners).
-
```python
-from diffusers import DiffusionPipeline
import torch
+from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"/", trust_remote_code=True, torch_dtype=torch.float16
-).to("cuda")
-
-prompt = "hello"
-
-# Text embeds
-prompt_embeds, negative_embeds = pipeline.encode_prompt(prompt)
-
-# Keyframes generation (8x64x40, 2fps)
-video_frames = pipeline(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_embeds,
- num_frames=8,
- height=40,
- width=64,
- num_inference_steps=2,
- guidance_scale=9.0,
- output_type="pt"
-).frames
+)
```
-As an additional reference, take a look at the repository structure of [stabilityai/japanese-stable-diffusion-xl](https://huggingface.co/stabilityai/japanese-stable-diffusion-xl/) which also uses the `trust_remote_code` feature.
+> [!WARNING]
+> As an additional precaution with `trust_remote_code=True`, we strongly encourage passing a commit hash to the `revision` argument in [`~DiffusionPipeline.from_pretrained`] to make sure the code hasn't been updated with new malicious code (unless you fully trust the model owners).
-```python
-from diffusers import DiffusionPipeline
-import torch
+## Resources
-pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/japanese-stable-diffusion-xl", trust_remote_code=True
-)
-pipeline.to("cuda")
-```
+- Take a look at Issue [#841](https://github.com/huggingface/diffusers/issues/841) for more context about why we're adding community pipelines to help everyone easily share their work without being slowed down.
+- Check out the [stabilityai/japanese-stable-diffusion-xl](https://huggingface.co/stabilityai/japanese-stable-diffusion-xl/) repository for an additional example of a community pipeline that also uses the `trust_remote_code` feature.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/depth2img.md b/docs/source/en/using-diffusers/depth2img.md
index c0929727ff5f..7e04b04520fa 100644
--- a/docs/source/en/using-diffusers/depth2img.md
+++ b/docs/source/en/using-diffusers/depth2img.md
@@ -1,4 +1,4 @@
-
+
+# DreamBooth
+
+[DreamBooth](https://huggingface.co/papers/2208.12242) is a method for generating personalized images of a specific instance. It works by fine-tuning the model on 3-5 images of the subject (for example, a cat) that is associated with a unique identifier (`sks cat`). This allows you to use `sks cat` in your prompt to trigger the model to generate images of your cat in different settings, lighting, poses, and styles.
+
+DreamBooth checkpoints are typically a few GBs in size because it contains the full model weights.
+
+Load the DreamBooth checkpoint with [`~DiffusionPipeline.from_pretrained`] and include the unique identifier in the prompt to activate its generation.
+
+```py
+import torch
+from diffusers import AutoPipelineForText2Image
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "sd-dreambooth-library/herge-style",
+ torch_dtype=torch.float16
+).to("cuda")
+prompt = "A cute sks herge_style brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration"
+pipeline(prompt).images[0]
+```
+
+
+
+
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/image_quality.md b/docs/source/en/using-diffusers/image_quality.md
index 960a84105674..29ce483d5ecc 100644
--- a/docs/source/en/using-diffusers/image_quality.md
+++ b/docs/source/en/using-diffusers/image_quality.md
@@ -1,4 +1,4 @@
-
-# Controlling image quality
-
-The components of a diffusion model, like the UNet and scheduler, can be optimized to improve the quality of generated images leading to better details. These techniques are especially useful if you don't have the resources to simply use a larger model for inference. You can enable these techniques during inference without any additional training.
-
-This guide will show you how to turn these techniques on in your pipeline and how to configure them to improve the quality of your generated images.
-
-## Details
+# FreeU
[FreeU](https://hf.co/papers/2309.11497) improves image details by rebalancing the UNet's backbone and skip connection weights. The skip connections can cause the model to overlook some of the backbone semantics which may lead to unnatural image details in the generated image. This technique does not require any additional training and can be applied on the fly during inference for tasks like image-to-image and text-to-video.
@@ -139,7 +133,7 @@ export_to_video(video_frames, "teddy_bear.mp4", fps=10)
-Call the [`pipelines.StableDiffusionMixin.disable_freeu`] method to disable FreeU.
+Call the [`~pipelines.StableDiffusionMixin.disable_freeu`] method to disable FreeU.
```py
pipeline.disable_freeu()
diff --git a/docs/source/en/using-diffusers/img2img.md b/docs/source/en/using-diffusers/img2img.md
index d9902081fde5..ef00bf7f9b2b 100644
--- a/docs/source/en/using-diffusers/img2img.md
+++ b/docs/source/en/using-diffusers/img2img.md
@@ -1,4 +1,4 @@
-
-# Load pipelines
-
[[open-in-colab]]
-Diffusion systems consist of multiple components like parameterized models and schedulers that interact in complex ways. That is why we designed the [`DiffusionPipeline`] to wrap the complexity of the entire diffusion system into an easy-to-use API. At the same time, the [`DiffusionPipeline`] is entirely customizable so you can modify each component to build a diffusion system for your use case.
-
-This guide will show you how to load:
-
-- pipelines from the Hub and locally
-- different components into a pipeline
-- multiple pipelines without increasing memory usage
-- checkpoint variants such as different floating point types or non-exponential mean averaged (EMA) weights
+# DiffusionPipeline
-## Load a pipeline
-
-> [!TIP]
-> Skip to the [DiffusionPipeline explained](#diffusionpipeline-explained) section if you're interested in an explanation about how the [`DiffusionPipeline`] class works.
+Diffusion models consists of multiple components like UNets or diffusion transformers (DiTs), text encoders, variational autoencoders (VAEs), and schedulers. The [`DiffusionPipeline`] wraps all of these components into a single easy-to-use API without giving up the flexibility to modify it's components.
-There are two ways to load a pipeline for a task:
+This guide will show you how to load a [`DiffusionPipeline`].
-1. Load the generic [`DiffusionPipeline`] class and allow it to automatically detect the correct pipeline class from the checkpoint.
-2. Load a specific pipeline class for a specific task.
+## Loading a pipeline
-
-
+[`DiffusionPipeline`] is a base pipeline class that automatically selects and returns an instance of a model's pipeline subclass, like [`QwenImagePipeline`], by scanning the `model_index.json` file for the class name.
-The [`DiffusionPipeline`] class is a simple and generic way to load the latest trending diffusion model from the [Hub](https://huggingface.co/models?library=diffusers&sort=trending). It uses the [`~DiffusionPipeline.from_pretrained`] method to automatically detect the correct pipeline class for a task from the checkpoint, downloads and caches all the required configuration and weight files, and returns a pipeline ready for inference.
-
-```python
-from diffusers import DiffusionPipeline
-
-pipeline = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", use_safetensors=True)
-```
-
-This same checkpoint can also be used for an image-to-image task. The [`DiffusionPipeline`] class can handle any task as long as you provide the appropriate inputs. For example, for an image-to-image task, you need to pass an initial image to the pipeline.
+Pass a model id to [`~DiffusionPipeline.from_pretrained`] to load a pipeline.
```py
+import torch
from diffusers import DiffusionPipeline
-pipeline = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", use_safetensors=True)
-
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png")
-prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
-image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", image=init_image).images[0]
+pipeline = DiffusionPipeline.from_pretrained(
+ "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
+)
```
-
-
-
-Checkpoints can be loaded by their specific pipeline class if you already know it. For example, to load a Stable Diffusion model, use the [`StableDiffusionPipeline`] class.
+Every model has a specific pipeline subclass that inherits from [`DiffusionPipeline`]. A subclass usually has a narrow focus and are task-specific. See the table below for an example.
-```python
-from diffusers import StableDiffusionPipeline
-
-pipeline = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", use_safetensors=True)
-```
+| pipeline subclass | task |
+|---|---|
+| [`QwenImagePipeline`] | text-to-image |
+| [`QwenImageImg2ImgPipeline`] | image-to-image |
+| [`QwenImageInpaintPipeline`] | inpaint |
-This same checkpoint may also be used for another task like image-to-image. To differentiate what task you want to use the checkpoint for, you have to use the corresponding task-specific pipeline class. For example, to use the same checkpoint for image-to-image, use the [`StableDiffusionImg2ImgPipeline`] class.
+You could use the subclass directly by passing a model id to [`~QwenImagePipeline.from_pretrained`].
```py
-from diffusers import StableDiffusionImg2ImgPipeline
-
-pipeline = StableDiffusionImg2ImgPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", use_safetensors=True)
-```
-
-
-
-
-Use the Space below to gauge a pipeline's memory requirements before you download and load it to see if it runs on your hardware.
-
-
-
-
-
-
-
-
-### Specifying Component-Specific Data Types
-
-You can customize the data types for individual sub-models by passing a dictionary to the `torch_dtype` parameter. This allows you to load different components of a pipeline in different floating point precisions. For instance, if you want to load the transformer with `torch.bfloat16` and all other components with `torch.float16`, you can pass a dictionary mapping:
-
-```python
-from diffusers import HunyuanVideoPipeline
import torch
+from diffusers import QwenImagePipeline
-pipe = HunyuanVideoPipeline.from_pretrained(
- "hunyuanvideo-community/HunyuanVideo",
- torch_dtype={'transformer': torch.bfloat16, 'default': torch.float16},
+pipeline = QwenImagePipeline.from_pretrained(
+ "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
)
-print(pipe.transformer.dtype, pipe.vae.dtype) # (torch.bfloat16, torch.float16)
```
-If a component is not explicitly specified in the dictionary and no `default` is provided, it will be loaded with `torch.float32`.
+> [!TIP]
+> Refer to the [Single file format](./other-formats#single-file-format) docs to learn how to load single file models.
+
+### Local pipelines
-### Local pipeline
+Pipelines can also be run locally. Use [`~huggingface_hub.snapshot_download`] to download a model repository.
-To load a pipeline locally, use [git-lfs](https://git-lfs.github.com/) to manually download a checkpoint to your local disk.
+```py
+from huggingface_hub import snapshot_download
-```bash
-git-lfs install
-git clone https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5
+snapshot_download(repo_id="Qwen/Qwen-Image")
```
-This creates a local folder, ./stable-diffusion-v1-5, on your disk and you should pass its path to [`~DiffusionPipeline.from_pretrained`].
+The model is downloaded to your [cache](../installation#cache). Pass the folder path to [`~QwenImagePipeline.from_pretrained`] to load it.
-```python
-from diffusers import DiffusionPipeline
+```py
+import torch
+from diffusers import QwenImagePipeline
-stable_diffusion = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5", use_safetensors=True)
+pipeline = QwenImagePipeline.from_pretrained(
+ "path/to/your/cache", torch_dtype=torch.bfloat16, device_map="cuda"
+)
```
-The [`~DiffusionPipeline.from_pretrained`] method won't download files from the Hub when it detects a local path, but this also means it won't download and cache the latest changes to a checkpoint.
+The [`~QwenImagePipeline.from_pretrained`] method won't download files from the Hub when it detects a local path. But this also means it won't download and cache any updates that have been made to the model either.
-## Customize a pipeline
+## Pipeline data types
-You can customize a pipeline by loading different components into it. This is important because you can:
+Use the `torch_dtype` argument in [`~DiffusionPipeline.from_pretrained`] to load a model with a specific data type. This allows you to load different models in different precisions. For example, loading a large transformer model in half-precision reduces the memory required.
-- change to a scheduler with faster generation speed or higher generation quality depending on your needs (call the `scheduler.compatibles` method on your pipeline to see compatible schedulers)
-- change a default pipeline component to a newer and better performing one
-
-For example, let's customize the default [stabilityai/stable-diffusion-xl-base-1.0](https://hf.co/stabilityai/stable-diffusion-xl-base-1.0) checkpoint with:
-
-- The [`HeunDiscreteScheduler`] to generate higher quality images at the expense of slower generation speed. You must pass the `subfolder="scheduler"` parameter in [`~HeunDiscreteScheduler.from_pretrained`] to load the scheduler configuration into the correct [subfolder](https://hf.co/stabilityai/stable-diffusion-xl-base-1.0/tree/main/scheduler) of the pipeline repository.
-- A more stable VAE that runs in fp16.
+Pass the data type for each model as a dictionary to `torch_dtype`. Use the `default` key to set the default data type. If a model isn't in the dictionary and `default` isn't provided, it is loaded in full precision (`torch.float32`).
```py
-from diffusers import StableDiffusionXLPipeline, HeunDiscreteScheduler, AutoencoderKL
import torch
+from diffusers import QwenImagePipeline
-scheduler = HeunDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")
-vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
+pipeline = QwenImagePipeline.from_pretrained(
+ "Qwen/Qwen-Image",
+ torch_dtype={"transformer": torch.bfloat16, "default": torch.float16},
+)
+print(pipeline.transformer.dtype, pipeline.vae.dtype)
```
-Now pass the new scheduler and VAE to the [`StableDiffusionXLPipeline`].
+You don't need to use a dictionary if you're loading all the models in the same data type.
```py
-pipeline = StableDiffusionXLPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- scheduler=scheduler,
- vae=vae,
- torch_dtype=torch.float16,
- variant="fp16",
- use_safetensors=True
-).to("cuda")
-```
-
-## Reuse a pipeline
-
-When you load multiple pipelines that share the same model components, it makes sense to reuse the shared components instead of reloading everything into memory again, especially if your hardware is memory-constrained. For example:
-
-1. You generated an image with the [`StableDiffusionPipeline`] but you want to improve its quality with the [`StableDiffusionSAGPipeline`]. Both of these pipelines share the same pretrained model, so it'd be a waste of memory to load the same model twice.
-2. You want to add a model component, like a [`MotionAdapter`](../api/pipelines/animatediff#animatediffpipeline), to [`AnimateDiffPipeline`] which was instantiated from an existing [`StableDiffusionPipeline`]. Again, both pipelines share the same pretrained model, so it'd be a waste of memory to load an entirely new pipeline again.
-
-With the [`DiffusionPipeline.from_pipe`] API, you can switch between multiple pipelines to take advantage of their different features without increasing memory-usage. It is similar to turning on and off a feature in your pipeline.
-
-> [!TIP]
-> To switch between tasks (rather than features), use the [`~DiffusionPipeline.from_pipe`] method with the [AutoPipeline](../api/pipelines/auto_pipeline) class, which automatically identifies the pipeline class based on the task (learn more in the [AutoPipeline](../tutorials/autopipeline) tutorial).
-
-Let's start with a [`StableDiffusionPipeline`] and then reuse the loaded model components to create a [`StableDiffusionSAGPipeline`] to increase generation quality. You'll use the [`StableDiffusionPipeline`] with an [IP-Adapter](./ip_adapter) to generate a bear eating pizza.
-
-```python
-from diffusers import DiffusionPipeline, StableDiffusionSAGPipeline
import torch
-import gc
-from diffusers.utils import load_image
-from accelerate.utils import compute_module_sizes
-
-image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")
-
-pipe_sd = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", torch_dtype=torch.float16)
-pipe_sd.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
-pipe_sd.set_ip_adapter_scale(0.6)
-pipe_sd.to("cuda")
-
-generator = torch.Generator(device="cpu").manual_seed(33)
-out_sd = pipe_sd(
- prompt="bear eats pizza",
- negative_prompt="wrong white balance, dark, sketches,worst quality,low quality",
- ip_adapter_image=image,
- num_inference_steps=50,
- generator=generator,
-).images[0]
-out_sd
-```
+from diffusers import QwenImagePipeline
-
-
-
-
-For reference, you can check how much memory this process consumed.
-
-```python
-def bytes_to_giga_bytes(bytes):
- return bytes / 1024 / 1024 / 1024
-print(f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB")
-"Max memory allocated: 4.406213283538818 GB"
+pipeline = QwenImagePipeline.from_pretrained(
+ "Qwen/Qwen-Image", torch_dtype=torch.bfloat16
+)
+print(pipeline.transformer.dtype, pipeline.vae.dtype)
```
-Now, reuse the same pipeline components from [`StableDiffusionPipeline`] in [`StableDiffusionSAGPipeline`] with the [`~DiffusionPipeline.from_pipe`] method.
-
-> [!WARNING]
-> Some pipeline methods may not function properly on new pipelines created with [`~DiffusionPipeline.from_pipe`]. For instance, the [`~DiffusionPipeline.enable_model_cpu_offload`] method installs hooks on the model components based on a unique offloading sequence for each pipeline. If the models are executed in a different order in the new pipeline, the CPU offloading may not work correctly.
->
-> To ensure everything works as expected, we recommend re-applying a pipeline method on a new pipeline created with [`~DiffusionPipeline.from_pipe`].
+## Device placement
-```python
-pipe_sag = StableDiffusionSAGPipeline.from_pipe(
- pipe_sd
-)
+The `device_map` argument determines individual model or pipeline placement on an accelerator like a GPU. It is especially helpful when there are multiple GPUs.
-generator = torch.Generator(device="cpu").manual_seed(33)
-out_sag = pipe_sag(
- prompt="bear eats pizza",
- negative_prompt="wrong white balance, dark, sketches,worst quality,low quality",
- ip_adapter_image=image,
- num_inference_steps=50,
- generator=generator,
- guidance_scale=1.0,
- sag_scale=0.75
-).images[0]
-out_sag
-```
+A pipeline supports two options for `device_map`, `"cuda"` and `"balanced"`. Refer to the table below to compare the placement strategies.
-
-
-
+| parameter | description |
+|---|---|
+| `"cuda"` | places pipeline on a supported accelerator device like CUDA |
+| `"balanced"` | evenly distributes pipeline on all GPUs |
-If you check the memory usage, you'll see it remains the same as before because [`StableDiffusionPipeline`] and [`StableDiffusionSAGPipeline`] are sharing the same pipeline components. This allows you to use them interchangeably without any additional memory overhead.
+Use the `max_memory` argument in [`~DiffusionPipeline.from_pretrained`] to allocate a maximum amount of memory to use on each device. By default, Diffusers uses the maximum amount available.
```py
-print(f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB")
-"Max memory allocated: 4.406213283538818 GB"
-```
-
-Let's animate the image with the [`AnimateDiffPipeline`] and also add a [`MotionAdapter`] module to the pipeline. For the [`AnimateDiffPipeline`], you need to unload the IP-Adapter first and reload it *after* you've created your new pipeline (this only applies to the [`AnimateDiffPipeline`]).
+import torch
+from diffusers import DiffusionPipeline
-```py
-from diffusers import AnimateDiffPipeline, MotionAdapter, DDIMScheduler
-from diffusers.utils import export_to_gif
-
-pipe_sag.unload_ip_adapter()
-adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
-
-pipe_animate = AnimateDiffPipeline.from_pipe(pipe_sd, motion_adapter=adapter)
-pipe_animate.scheduler = DDIMScheduler.from_config(pipe_animate.scheduler.config, beta_schedule="linear")
-# load IP-Adapter and LoRA weights again
-pipe_animate.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
-pipe_animate.load_lora_weights("guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")
-pipe_animate.to("cuda")
-
-generator = torch.Generator(device="cpu").manual_seed(33)
-pipe_animate.set_adapters("zoom-out", adapter_weights=0.75)
-out = pipe_animate(
- prompt="bear eats pizza",
- num_frames=16,
- num_inference_steps=50,
- ip_adapter_image=image,
- generator=generator,
-).frames[0]
-export_to_gif(out, "out_animate.gif")
+max_memory = {0: "16GB", 1: "16GB"}
+pipeline = DiffusionPipeline.from_pretrained(
+ "Qwen/Qwen-Image",
+ torch_dtype=torch.bfloat16,
+ device_map="cuda",
+)
```
-
-
-
-
-The [`AnimateDiffPipeline`] is more memory-intensive and consumes 15GB of memory (see the [Memory-usage of from_pipe](#memory-usage-of-from_pipe) section to learn what this means for your memory-usage).
+The `hf_device_map` attribute allows you to access and view the `device_map`.
```py
-print(f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB")
-"Max memory allocated: 15.178664207458496 GB"
+print(pipeline.hf_device_map)
+# {'unet': 1, 'vae': 1, 'safety_checker': 0, 'text_encoder': 0}
```
-### Modify from_pipe components
-
-Pipelines loaded with [`~DiffusionPipeline.from_pipe`] can be customized with different model components or methods. However, whenever you modify the *state* of the model components, it affects all the other pipelines that share the same components. For example, if you call [`~diffusers.loaders.IPAdapterMixin.unload_ip_adapter`] on the [`StableDiffusionSAGPipeline`], you won't be able to use IP-Adapter with the [`StableDiffusionPipeline`] because it's been removed from their shared components.
+Reset a pipeline's `device_map` with the [`~DiffusionPipeline.reset_device_map`] method. This is necessary if you want to use methods such as `.to()`, [`~DiffusionPipeline.enable_sequential_cpu_offload`], and [`~DiffusionPipeline.enable_model_cpu_offload`].
```py
-pipe.sag_unload_ip_adapter()
-
-generator = torch.Generator(device="cpu").manual_seed(33)
-out_sd = pipe_sd(
- prompt="bear eats pizza",
- negative_prompt="wrong white balance, dark, sketches,worst quality,low quality",
- ip_adapter_image=image,
- num_inference_steps=50,
- generator=generator,
-).images[0]
-"AttributeError: 'NoneType' object has no attribute 'image_projection_layers'"
+pipeline.reset_device_map()
```
-### Memory usage of from_pipe
-
-The memory requirement of loading multiple pipelines with [`~DiffusionPipeline.from_pipe`] is determined by the pipeline with the highest memory-usage regardless of the number of pipelines you create.
+## Parallel loading
-| Pipeline | Memory usage (GB) |
-|---|---|
-| StableDiffusionPipeline | 4.400 |
-| StableDiffusionSAGPipeline | 4.400 |
-| AnimateDiffPipeline | 15.178 |
+Large models are often [sharded](../training/distributed_inference#model-sharding) into smaller files so that they are easier to load. Diffusers supports loading shards in parallel to speed up the loading process.
-The [`AnimateDiffPipeline`] has the highest memory requirement, so the *total memory-usage* is based only on the [`AnimateDiffPipeline`]. Your memory-usage will not increase if you create additional pipelines as long as their memory requirements doesn't exceed that of the [`AnimateDiffPipeline`]. Each pipeline can be used interchangeably without any additional memory overhead.
-
-## Safety checker
+Set `HF_ENABLE_PARALLEL_LOADING` to `"YES"` to enable parallel loading of shards.
-Diffusers implements a [safety checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) for Stable Diffusion models which can generate harmful content. The safety checker screens the generated output against known hardcoded not-safe-for-work (NSFW) content. If for whatever reason you'd like to disable the safety checker, pass `safety_checker=None` to the [`~DiffusionPipeline.from_pretrained`] method.
+The `device_map` argument should be set to `"cuda"` to pre-allocate a large chunk of memory based on the model size. This substantially reduces model load time because warming up the memory allocator now avoids many smaller calls to the allocator later.
-```python
+```py
+import os
+import torch
from diffusers import DiffusionPipeline
-pipeline = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, use_safetensors=True)
-"""
-You have disabled the safety checker for by passing `safety_checker=None`. Ensure that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling it only for use cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
-"""
-```
-
-## Checkpoint variants
-
-A checkpoint variant is usually a checkpoint whose weights are:
-
-- Stored in a different floating point type, such as [torch.float16](https://pytorch.org/docs/stable/tensors.html#data-types), because it only requires half the bandwidth and storage to download. You can't use this variant if you're continuing training or using a CPU.
-- Non-exponential mean averaged (EMA) weights which shouldn't be used for inference. You should use this variant to continue finetuning a model.
-
-> [!TIP]
-> When the checkpoints have identical model structures, but they were trained on different datasets and with a different training setup, they should be stored in separate repositories. For example, [stabilityai/stable-diffusion-2](https://hf.co/stabilityai/stable-diffusion-2) and [stabilityai/stable-diffusion-2-1](https://hf.co/stabilityai/stable-diffusion-2-1) are stored in separate repositories.
-
-Otherwise, a variant is **identical** to the original checkpoint. They have exactly the same serialization format (like [safetensors](./using_safetensors)), model structure, and their weights have identical tensor shapes.
+os.environ["HF_ENABLE_PARALLEL_LOADING"] = "YES"
-| **checkpoint type** | **weight name** | **argument for loading weights** |
-|---------------------|---------------------------------------------|----------------------------------|
-| original | diffusion_pytorch_model.safetensors | |
-| floating point | diffusion_pytorch_model.fp16.safetensors | `variant`, `torch_dtype` |
-| non-EMA | diffusion_pytorch_model.non_ema.safetensors | `variant` |
-
-There are two important arguments for loading variants:
-
-- `torch_dtype` specifies the floating point precision of the loaded checkpoint. For example, if you want to save bandwidth by loading a fp16 variant, you should set `variant="fp16"` and `torch_dtype=torch.float16` to *convert the weights* to fp16. Otherwise, the fp16 weights are converted to the default fp32 precision.
+pipeline = DiffusionPipeline.from_pretrained(
+ "Wan-AI/Wan2.2-I2V-A14B-Diffusers", torch_dtype=torch.bfloat16, device_map="cuda"
+)
+```
- If you only set `torch_dtype=torch.float16`, the default fp32 weights are downloaded first and then converted to fp16.
+## Replacing models in a pipeline
-- `variant` specifies which files should be loaded from the repository. For example, if you want to load a non-EMA variant of a UNet from [stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main/unet), set `variant="non_ema"` to download the `non_ema` file.
+[`DiffusionPipeline`] is flexible and accommodates loading different models or schedulers. You can experiment with different schedulers to optimize for generation speed or quality, and you can replace models with more performant ones.
-
-
+The example below uses a more stable VAE version.
```py
-from diffusers import DiffusionPipeline
import torch
+from diffusers import DiffusionPipeline, AutoModel
-pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16, use_safetensors=True
+vae = AutoModel.from_pretrained(
+ "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
-```
-
-
-
-```py
pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", variant="non_ema", use_safetensors=True
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ vae=vae,
+ torch_dtype=torch.float16,
+ device_map="cuda"
)
```
-
-
-
-Use the `variant` parameter in the [`DiffusionPipeline.save_pretrained`] method to save a checkpoint as a different floating point type or as a non-EMA variant. You should try save a variant to the same folder as the original checkpoint, so you have the option of loading both from the same folder.
+## Reusing models in multiple pipelines
-
-
+When working with multiple pipelines that use the same model, the [`~DiffusionPipeline.from_pipe`] method enables reusing a model instead of reloading it each time. This allows you to use multiple pipelines without increasing memory usage.
-```python
-from diffusers import DiffusionPipeline
+Memory usage is determined by the pipeline with the highest memory requirement regardless of the number of pipelines.
-pipeline.save_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", variant="fp16")
-```
+The example below loads a pipeline and then loads a second pipeline with [`~DiffusionPipeline.from_pipe`] to use [perturbed-attention guidance (PAG)](../api/pipelines/pag) to improve generation quality.
-
-
+> [!WARNING]
+> Use [`AutoPipelineForText2Image`] because [`DiffusionPipeline`] doesn't support PAG. Refer to the [AutoPipeline](../tutorials/autopipeline) docs to learn more.
```py
-pipeline.save_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", variant="non_ema")
-```
-
-
-
-
-If you don't save the variant to an existing folder, you must specify the `variant` argument otherwise it'll throw an `Exception` because it can't find the original checkpoint.
+import torch
+from diffusers import AutoPipelineForText2Image
-```python
-# 👎 this won't work
-pipeline = DiffusionPipeline.from_pretrained(
- "./stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
-)
-# 👍 this works
-pipeline = DiffusionPipeline.from_pretrained(
- "./stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16, use_safetensors=True
+pipeline_sdxl = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, device_map="cuda"
)
+prompt = """
+cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+image = pipeline_sdxl(prompt).images[0]
+print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
+# Max memory reserved: 10.47 GB
```
-## DiffusionPipeline explained
-
-As a class method, [`DiffusionPipeline.from_pretrained`] is responsible for two things:
-
-- Download the latest version of the folder structure required for inference and cache it. If the latest folder structure is available in the local cache, [`DiffusionPipeline.from_pretrained`] reuses the cache and won't redownload the files.
-- Load the cached weights into the correct pipeline [class](../api/pipelines/overview#diffusers-summary) - retrieved from the `model_index.json` file - and return an instance of it.
-
-The pipelines' underlying folder structure corresponds directly with their class instances. For example, the [`StableDiffusionPipeline`] corresponds to the folder structure in [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5).
+Set `enable_pag=True` in the second pipeline to enable PAG. The second pipeline uses the same amount of memory because it shares model weights with the first one.
-```python
-from diffusers import DiffusionPipeline
-
-repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
-pipeline = DiffusionPipeline.from_pretrained(repo_id, use_safetensors=True)
-print(pipeline)
+```py
+pipeline = AutoPipelineForText2Image.from_pipe(
+ pipeline_sdxl, enable_pag=True
+)
+prompt = """
+cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+image = pipeline(prompt).images[0]
+print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
+# Max memory reserved: 10.47 GB
```
-You'll see pipeline is an instance of [`StableDiffusionPipeline`], which consists of seven components:
-
-- `"feature_extractor"`: a [`~transformers.CLIPImageProcessor`] from 🤗 Transformers.
-- `"safety_checker"`: a [component](https://github.com/huggingface/diffusers/blob/e55687e1e15407f60f32242027b7bb8170e58266/src/diffusers/pipelines/stable_diffusion/safety_checker.py#L32) for screening against harmful content.
-- `"scheduler"`: an instance of [`PNDMScheduler`].
-- `"text_encoder"`: a [`~transformers.CLIPTextModel`] from 🤗 Transformers.
-- `"tokenizer"`: a [`~transformers.CLIPTokenizer`] from 🤗 Transformers.
-- `"unet"`: an instance of [`UNet2DConditionModel`].
-- `"vae"`: an instance of [`AutoencoderKL`].
-
-```json
-StableDiffusionPipeline {
- "feature_extractor": [
- "transformers",
- "CLIPImageProcessor"
- ],
- "safety_checker": [
- "stable_diffusion",
- "StableDiffusionSafetyChecker"
- ],
- "scheduler": [
- "diffusers",
- "PNDMScheduler"
- ],
- "text_encoder": [
- "transformers",
- "CLIPTextModel"
- ],
- "tokenizer": [
- "transformers",
- "CLIPTokenizer"
- ],
- "unet": [
- "diffusers",
- "UNet2DConditionModel"
- ],
- "vae": [
- "diffusers",
- "AutoencoderKL"
- ]
-}
-```
+> [!WARNING]
+> Pipelines created by [`~DiffusionPipeline.from_pipe`] share the same models and *state*. Modifying the state of a model in one pipeline affects all the other pipelines that share the same model.
-Compare the components of the pipeline instance to the [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main) folder structure, and you'll see there is a separate folder for each of the components in the repository:
+Some methods may not work correctly on pipelines created with [`~DiffusionPipeline.from_pipe`]. For example, [`~DiffusionPipeline.enable_model_cpu_offload`] relies on a unique model execution order, which may differ in the new pipeline. To ensure proper functionality, reapply these methods on the new pipeline.
-```
-.
-├── feature_extractor
-│ └── preprocessor_config.json
-├── model_index.json
-├── safety_checker
-│ ├── config.json
-| ├── model.fp16.safetensors
-│ ├── model.safetensors
-│ ├── pytorch_model.bin
-| └── pytorch_model.fp16.bin
-├── scheduler
-│ └── scheduler_config.json
-├── text_encoder
-│ ├── config.json
-| ├── model.fp16.safetensors
-│ ├── model.safetensors
-│ |── pytorch_model.bin
-| └── pytorch_model.fp16.bin
-├── tokenizer
-│ ├── merges.txt
-│ ├── special_tokens_map.json
-│ ├── tokenizer_config.json
-│ └── vocab.json
-├── unet
-│ ├── config.json
-│ ├── diffusion_pytorch_model.bin
-| |── diffusion_pytorch_model.fp16.bin
-│ |── diffusion_pytorch_model.f16.safetensors
-│ |── diffusion_pytorch_model.non_ema.bin
-│ |── diffusion_pytorch_model.non_ema.safetensors
-│ └── diffusion_pytorch_model.safetensors
-|── vae
-. ├── config.json
-. ├── diffusion_pytorch_model.bin
- ├── diffusion_pytorch_model.fp16.bin
- ├── diffusion_pytorch_model.fp16.safetensors
- └── diffusion_pytorch_model.safetensors
-```
+## Safety checker
-You can access each of the components of the pipeline as an attribute to view its configuration:
+Diffusers provides a [safety checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) for older Stable Diffusion models to prevent generating harmful content. It screens the generated output against a set of hardcoded harmful concepts.
+
+If you want to disable the safety checker, pass `safety_checker=None` in [`~DiffusionPipeline.from_pretrained`] as shown below.
```py
-pipeline.tokenizer
-CLIPTokenizer(
- name_or_path="/root/.cache/huggingface/hub/models--runwayml--stable-diffusion-v1-5/snapshots/39593d5650112b4cc580433f6b0435385882d819/tokenizer",
- vocab_size=49408,
- model_max_length=77,
- is_fast=False,
- padding_side="right",
- truncation_side="right",
- special_tokens={
- "bos_token": AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True),
- "eos_token": AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True),
- "unk_token": AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True),
- "pad_token": "<|endoftext|>",
- },
- clean_up_tokenization_spaces=True
-)
-```
+from diffusers import DiffusionPipeline
-Every pipeline expects a [`model_index.json`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/model_index.json) file that tells the [`DiffusionPipeline`]:
-
-- which pipeline class to load from `_class_name`
-- which version of 🧨 Diffusers was used to create the model in `_diffusers_version`
-- what components from which library are stored in the subfolders (`name` corresponds to the component and subfolder name, `library` corresponds to the name of the library to load the class from, and `class` corresponds to the class name)
-
-```json
-{
- "_class_name": "StableDiffusionPipeline",
- "_diffusers_version": "0.6.0",
- "feature_extractor": [
- "transformers",
- "CLIPImageProcessor"
- ],
- "safety_checker": [
- "stable_diffusion",
- "StableDiffusionSafetyChecker"
- ],
- "scheduler": [
- "diffusers",
- "PNDMScheduler"
- ],
- "text_encoder": [
- "transformers",
- "CLIPTextModel"
- ],
- "tokenizer": [
- "transformers",
- "CLIPTokenizer"
- ],
- "unet": [
- "diffusers",
- "UNet2DConditionModel"
- ],
- "vae": [
- "diffusers",
- "AutoencoderKL"
- ]
-}
-```
+pipeline = DiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None
+)
+"""
+You have disabled the safety checker for by passing `safety_checker=None`. Ensure that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling it only for use cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
+"""
+```
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/loading_adapters.md b/docs/source/en/using-diffusers/loading_adapters.md
deleted file mode 100644
index e16c1322e5d1..000000000000
--- a/docs/source/en/using-diffusers/loading_adapters.md
+++ /dev/null
@@ -1,363 +0,0 @@
-
-
-# Load adapters
-
-[[open-in-colab]]
-
-There are several [training](../training/overview) techniques for personalizing diffusion models to generate images of a specific subject or images in certain styles. Each of these training methods produces a different type of adapter. Some of the adapters generate an entirely new model, while other adapters only modify a smaller set of embeddings or weights. This means the loading process for each adapter is also different.
-
-This guide will show you how to load DreamBooth, textual inversion, and LoRA weights.
-
-
-
-Feel free to browse the [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer), [LoRA the Explorer](https://huggingface.co/spaces/multimodalart/LoraTheExplorer), and the [Diffusers Models Gallery](https://huggingface.co/spaces/huggingface-projects/diffusers-gallery) for checkpoints and embeddings to use.
-
-
-
-## DreamBooth
-
-[DreamBooth](https://dreambooth.github.io/) finetunes an *entire diffusion model* on just several images of a subject to generate images of that subject in new styles and settings. This method works by using a special word in the prompt that the model learns to associate with the subject image. Of all the training methods, DreamBooth produces the largest file size (usually a few GBs) because it is a full checkpoint model.
-
-Let's load the [herge_style](https://huggingface.co/sd-dreambooth-library/herge-style) checkpoint, which is trained on just 10 images drawn by Hergé, to generate images in that style. For it to work, you need to include the special word `herge_style` in your prompt to trigger the checkpoint:
-
-```py
-from diffusers import AutoPipelineForText2Image
-import torch
-
-pipeline = AutoPipelineForText2Image.from_pretrained("sd-dreambooth-library/herge-style", torch_dtype=torch.float16).to("cuda")
-prompt = "A cute herge_style brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration"
-image = pipeline(prompt).images[0]
-image
-```
-
-
-
-
-
-## Textual inversion
-
-[Textual inversion](https://textual-inversion.github.io/) is very similar to DreamBooth and it can also personalize a diffusion model to generate certain concepts (styles, objects) from just a few images. This method works by training and finding new embeddings that represent the images you provide with a special word in the prompt. As a result, the diffusion model weights stay the same and the training process produces a relatively tiny (a few KBs) file.
-
-Because textual inversion creates embeddings, it cannot be used on its own like DreamBooth and requires another model.
-
-```py
-from diffusers import AutoPipelineForText2Image
-import torch
-
-pipeline = AutoPipelineForText2Image.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
-```
-
-Now you can load the textual inversion embeddings with the [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] method and generate some images. Let's load the [sd-concepts-library/gta5-artwork](https://huggingface.co/sd-concepts-library/gta5-artwork) embeddings and you'll need to include the special word `` in your prompt to trigger it:
-
-```py
-pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")
-prompt = "A cute brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration, style"
-image = pipeline(prompt).images[0]
-image
-```
-
-
-
-
-
-Textual inversion can also be trained on undesirable things to create *negative embeddings* to discourage a model from generating images with those undesirable things like blurry images or extra fingers on a hand. This can be an easy way to quickly improve your prompt. You'll also load the embeddings with [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`], but this time, you'll need two more parameters:
-
-- `weight_name`: specifies the weight file to load if the file was saved in the 🤗 Diffusers format with a specific name or if the file is stored in the A1111 format
-- `token`: specifies the special word to use in the prompt to trigger the embeddings
-
-Let's load the [sayakpaul/EasyNegative-test](https://huggingface.co/sayakpaul/EasyNegative-test) embeddings:
-
-```py
-pipeline.load_textual_inversion(
- "sayakpaul/EasyNegative-test", weight_name="EasyNegative.safetensors", token="EasyNegative"
-)
-```
-
-Now you can use the `token` to generate an image with the negative embeddings:
-
-```py
-prompt = "A cute brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration, EasyNegative"
-negative_prompt = "EasyNegative"
-
-image = pipeline(prompt, negative_prompt=negative_prompt, num_inference_steps=50).images[0]
-image
-```
-
-
-
-
-
-## LoRA
-
-[Low-Rank Adaptation (LoRA)](https://huggingface.co/papers/2106.09685) is a popular training technique because it is fast and generates smaller file sizes (a couple hundred MBs). Like the other methods in this guide, LoRA can train a model to learn new styles from just a few images. It works by inserting new weights into the diffusion model and then only the new weights are trained instead of the entire model. This makes LoRAs faster to train and easier to store.
-
-
-
-LoRA is a very general training technique that can be used with other training methods. For example, it is common to train a model with DreamBooth and LoRA. It is also increasingly common to load and merge multiple LoRAs to create new and unique images. You can learn more about it in the in-depth [Merge LoRAs](merge_loras) guide since merging is outside the scope of this loading guide.
-
-
-
-LoRAs also need to be used with another model:
-
-```py
-from diffusers import AutoPipelineForText2Image
-import torch
-
-pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
-```
-
-Then use the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method to load the [ostris/super-cereal-sdxl-lora](https://huggingface.co/ostris/super-cereal-sdxl-lora) weights and specify the weights filename from the repository:
-
-```py
-pipeline.load_lora_weights("ostris/super-cereal-sdxl-lora", weight_name="cereal_box_sdxl_v1.safetensors")
-prompt = "bears, pizza bites"
-image = pipeline(prompt).images[0]
-image
-```
-
-
-
-
-
-The [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method loads LoRA weights into both the UNet and text encoder. It is the preferred way for loading LoRAs because it can handle cases where:
-
-- the LoRA weights don't have separate identifiers for the UNet and text encoder
-- the LoRA weights have separate identifiers for the UNet and text encoder
-
-To directly load (and save) a LoRA adapter at the *model-level*, use [`~PeftAdapterMixin.load_lora_adapter`], which builds and prepares the necessary model configuration for the adapter. Like [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`], [`PeftAdapterMixin.load_lora_adapter`] can load LoRAs for both the UNet and text encoder. For example, if you're loading a LoRA for the UNet, [`PeftAdapterMixin.load_lora_adapter`] ignores the keys for the text encoder.
-
-Use the `weight_name` parameter to specify the specific weight file and the `prefix` parameter to filter for the appropriate state dicts (`"unet"` in this case) to load.
-
-```py
-from diffusers import AutoPipelineForText2Image
-import torch
-
-pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
-pipeline.unet.load_lora_adapter("jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", prefix="unet")
-
-# use cnmt in the prompt to trigger the LoRA
-prompt = "A cute cnmt eating a slice of pizza, stunning color scheme, masterpiece, illustration"
-image = pipeline(prompt).images[0]
-image
-```
-
-
-
-
-
-Save an adapter with [`~PeftAdapterMixin.save_lora_adapter`].
-
-To unload the LoRA weights, use the [`~loaders.StableDiffusionLoraLoaderMixin.unload_lora_weights`] method to discard the LoRA weights and restore the model to its original weights:
-
-```py
-pipeline.unload_lora_weights()
-```
-
-### Adjust LoRA weight scale
-
-For both [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] and [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`], you can pass the `cross_attention_kwargs={"scale": 0.5}` parameter to adjust how much of the LoRA weights to use. A value of `0` is the same as only using the base model weights, and a value of `1` is equivalent to using the fully finetuned LoRA.
-
-For more granular control on the amount of LoRA weights used per layer, you can use [`~loaders.StableDiffusionLoraLoaderMixin.set_adapters`] and pass a dictionary specifying by how much to scale the weights in each layer by.
-```python
-pipe = ... # create pipeline
-pipe.load_lora_weights(..., adapter_name="my_adapter")
-scales = {
- "text_encoder": 0.5,
- "text_encoder_2": 0.5, # only usable if pipe has a 2nd text encoder
- "unet": {
- "down": 0.9, # all transformers in the down-part will use scale 0.9
- # "mid" # in this example "mid" is not given, therefore all transformers in the mid part will use the default scale 1.0
- "up": {
- "block_0": 0.6, # all 3 transformers in the 0th block in the up-part will use scale 0.6
- "block_1": [0.4, 0.8, 1.0], # the 3 transformers in the 1st block in the up-part will use scales 0.4, 0.8 and 1.0 respectively
- }
- }
-}
-pipe.set_adapters("my_adapter", scales)
-```
-
-This also works with multiple adapters - see [this guide](https://huggingface.co/docs/diffusers/tutorials/using_peft_for_inference#customize-adapters-strength) for how to do it.
-
-
-
-Currently, [`~loaders.StableDiffusionLoraLoaderMixin.set_adapters`] only supports scaling attention weights. If a LoRA has other parts (e.g., resnets or down-/upsamplers), they will keep a scale of 1.0.
-
-
-
-### Kohya and TheLastBen
-
-Other popular LoRA trainers from the community include those by [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). These trainers create different LoRA checkpoints than those trained by 🤗 Diffusers, but they can still be loaded in the same way.
-
-
-
-
-To load a Kohya LoRA, let's download the [Blueprintify SD XL 1.0](https://civitai.com/models/150986/blueprintify-sd-xl-10) checkpoint from [Civitai](https://civitai.com/) as an example:
-
-```sh
-!wget https://civitai.com/api/download/models/168776 -O blueprintify-sd-xl-10.safetensors
-```
-
-Load the LoRA checkpoint with the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method, and specify the filename in the `weight_name` parameter:
-
-```py
-from diffusers import AutoPipelineForText2Image
-import torch
-
-pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
-pipeline.load_lora_weights("path/to/weights", weight_name="blueprintify-sd-xl-10.safetensors")
-```
-
-Generate an image:
-
-```py
-# use bl3uprint in the prompt to trigger the LoRA
-prompt = "bl3uprint, a highly detailed blueprint of the eiffel tower, explaining how to build all parts, many txt, blueprint grid backdrop"
-image = pipeline(prompt).images[0]
-image
-```
-
-
-
-Some limitations of using Kohya LoRAs with 🤗 Diffusers include:
-
-- Images may not look like those generated by UIs - like ComfyUI - for multiple reasons, which are explained [here](https://github.com/huggingface/diffusers/pull/4287/#issuecomment-1655110736).
-- [LyCORIS checkpoints](https://github.com/KohakuBlueleaf/LyCORIS) aren't fully supported. The [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method loads LyCORIS checkpoints with LoRA and LoCon modules, but Hada and LoKR are not supported.
-
-
-
-
-
-
-Loading a checkpoint from TheLastBen is very similar. For example, to load the [TheLastBen/William_Eggleston_Style_SDXL](https://huggingface.co/TheLastBen/William_Eggleston_Style_SDXL) checkpoint:
-
-```py
-from diffusers import AutoPipelineForText2Image
-import torch
-
-pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
-pipeline.load_lora_weights("TheLastBen/William_Eggleston_Style_SDXL", weight_name="wegg.safetensors")
-
-# use by william eggleston in the prompt to trigger the LoRA
-prompt = "a house by william eggleston, sunrays, beautiful, sunlight, sunrays, beautiful"
-image = pipeline(prompt=prompt).images[0]
-image
-```
-
-
-
-
-## IP-Adapter
-
-[IP-Adapter](https://ip-adapter.github.io/) is a lightweight adapter that enables image prompting for any diffusion model. This adapter works by decoupling the cross-attention layers of the image and text features. All the other model components are frozen and only the embedded image features in the UNet are trained. As a result, IP-Adapter files are typically only ~100MBs.
-
-You can learn more about how to use IP-Adapter for different tasks and specific use cases in the [IP-Adapter](../using-diffusers/ip_adapter) guide.
-
-> [!TIP]
-> Diffusers currently only supports IP-Adapter for some of the most popular pipelines. Feel free to open a feature request if you have a cool use case and want to integrate IP-Adapter with an unsupported pipeline!
-> Official IP-Adapter checkpoints are available from [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter).
-
-To start, load a Stable Diffusion checkpoint.
-
-```py
-from diffusers import AutoPipelineForText2Image
-import torch
-from diffusers.utils import load_image
-
-pipeline = AutoPipelineForText2Image.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
-```
-
-Then load the IP-Adapter weights and add it to the pipeline with the [`~loaders.IPAdapterMixin.load_ip_adapter`] method.
-
-```py
-pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
-```
-
-Once loaded, you can use the pipeline with an image and text prompt to guide the image generation process.
-
-```py
-image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")
-generator = torch.Generator(device="cpu").manual_seed(33)
-images = pipeline(
- prompt='best quality, high quality, wearing sunglasses',
- ip_adapter_image=image,
- negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
- num_inference_steps=50,
- generator=generator,
-).images[0]
-images
-```
-
-
-
-
-
-### IP-Adapter Plus
-
-IP-Adapter relies on an image encoder to generate image features. If the IP-Adapter repository contains an `image_encoder` subfolder, the image encoder is automatically loaded and registered to the pipeline. Otherwise, you'll need to explicitly load the image encoder with a [`~transformers.CLIPVisionModelWithProjection`] model and pass it to the pipeline.
-
-This is the case for *IP-Adapter Plus* checkpoints which use the ViT-H image encoder.
-
-```py
-from transformers import CLIPVisionModelWithProjection
-
-image_encoder = CLIPVisionModelWithProjection.from_pretrained(
- "h94/IP-Adapter",
- subfolder="models/image_encoder",
- torch_dtype=torch.float16
-)
-
-pipeline = AutoPipelineForText2Image.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- image_encoder=image_encoder,
- torch_dtype=torch.float16
-).to("cuda")
-
-pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter-plus_sdxl_vit-h.safetensors")
-```
-
-### IP-Adapter Face ID models
-
-The IP-Adapter FaceID models are experimental IP Adapters that use image embeddings generated by `insightface` instead of CLIP image embeddings. Some of these models also use LoRA to improve ID consistency.
-You need to install `insightface` and all its requirements to use these models.
-
-
-As InsightFace pretrained models are available for non-commercial research purposes, IP-Adapter-FaceID models are released exclusively for research purposes and are not intended for commercial use.
-
-
-```py
-pipeline = AutoPipelineForText2Image.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- torch_dtype=torch.float16
-).to("cuda")
-
-pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid_sdxl.bin", image_encoder_folder=None)
-```
-
-If you want to use one of the two IP-Adapter FaceID Plus models, you must also load the CLIP image encoder, as this models use both `insightface` and CLIP image embeddings to achieve better photorealism.
-
-```py
-from transformers import CLIPVisionModelWithProjection
-
-image_encoder = CLIPVisionModelWithProjection.from_pretrained(
- "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
- torch_dtype=torch.float16,
-)
-
-pipeline = AutoPipelineForText2Image.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5",
- image_encoder=image_encoder,
- torch_dtype=torch.float16
-).to("cuda")
-
-pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid-plus_sd15.bin")
-```
diff --git a/docs/source/en/using-diffusers/marigold_usage.md b/docs/source/en/using-diffusers/marigold_usage.md
index b8e9a5838e8d..f66e47bada09 100644
--- a/docs/source/en/using-diffusers/marigold_usage.md
+++ b/docs/source/en/using-diffusers/marigold_usage.md
@@ -288,7 +288,7 @@ Speeding them up can be achieved by using a more efficient attention processor:
depth = pipe(image, num_inference_steps=1)
```
-Finally, as suggested in [Optimizations](../optimization/torch2.0#torch.compile), enabling `torch.compile` can further enhance performance depending on
+Finally, as suggested in [Optimizations](../optimization/fp16#torchcompile), enabling `torch.compile` can further enhance performance depending on
the target hardware.
However, compilation incurs a significant overhead during the first pipeline invocation, making it beneficial only when
the same pipeline instance is called repeatedly, such as within a loop.
diff --git a/docs/source/en/using-diffusers/merge_loras.md b/docs/source/en/using-diffusers/merge_loras.md
deleted file mode 100644
index eb7d7d57ef3d..000000000000
--- a/docs/source/en/using-diffusers/merge_loras.md
+++ /dev/null
@@ -1,266 +0,0 @@
-
-
-# Merge LoRAs
-
-It can be fun and creative to use multiple [LoRAs]((https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora)) together to generate something entirely new and unique. This works by merging multiple LoRA weights together to produce images that are a blend of different styles. Diffusers provides a few methods to merge LoRAs depending on *how* you want to merge their weights, which can affect image quality.
-
-This guide will show you how to merge LoRAs using the [`~loaders.PeftAdapterMixin.set_adapters`] and [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) methods. To improve inference speed and reduce memory-usage of merged LoRAs, you'll also see how to use the [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] method to fuse the LoRA weights with the original weights of the underlying model.
-
-For this guide, load a Stable Diffusion XL (SDXL) checkpoint and the [KappaNeuro/studio-ghibli-style](https://huggingface.co/KappaNeuro/studio-ghibli-style) and [Norod78/sdxl-chalkboarddrawing-lora](https://huggingface.co/Norod78/sdxl-chalkboarddrawing-lora) LoRAs with the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method. You'll need to assign each LoRA an `adapter_name` to combine them later.
-
-```py
-from diffusers import DiffusionPipeline
-import torch
-
-pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
-pipeline.load_lora_weights("ostris/ikea-instructions-lora-sdxl", weight_name="ikea_instructions_xl_v1_5.safetensors", adapter_name="ikea")
-pipeline.load_lora_weights("lordjia/by-feng-zikai", weight_name="fengzikai_v1.0_XL.safetensors", adapter_name="feng")
-```
-
-## set_adapters
-
-The [`~loaders.PeftAdapterMixin.set_adapters`] method merges LoRA adapters by concatenating their weighted matrices. Use the adapter name to specify which LoRAs to merge, and the `adapter_weights` parameter to control the scaling for each LoRA. For example, if `adapter_weights=[0.5, 0.5]`, then the merged LoRA output is an average of both LoRAs. Try adjusting the adapter weights to see how it affects the generated image!
-
-```py
-pipeline.set_adapters(["ikea", "feng"], adapter_weights=[0.7, 0.8])
-
-generator = torch.manual_seed(0)
-prompt = "A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai"
-image = pipeline(prompt, generator=generator, cross_attention_kwargs={"scale": 1.0}).images[0]
-image
-```
-
-
-
-
-
-## add_weighted_adapter
-
-> [!WARNING]
-> This is an experimental method that adds PEFTs [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) method to Diffusers to enable more efficient merging methods. Check out this [issue](https://github.com/huggingface/diffusers/issues/6892) if you're interested in learning more about the motivation and design behind this integration.
-
-The [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) method provides access to more efficient merging method such as [TIES and DARE](https://huggingface.co/docs/peft/developer_guides/model_merging). To use these merging methods, make sure you have the latest stable version of Diffusers and PEFT installed.
-
-```bash
-pip install -U diffusers peft
-```
-
-There are three steps to merge LoRAs with the [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) method:
-
-1. Create a [PeftModel](https://huggingface.co/docs/peft/package_reference/peft_model#peft.PeftModel) from the underlying model and LoRA checkpoint.
-2. Load a base UNet model and the LoRA adapters.
-3. Merge the adapters using the [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) method and the merging method of your choice.
-
-Let's dive deeper into what these steps entail.
-
-1. Load a UNet that corresponds to the UNet in the LoRA checkpoint. In this case, both LoRAs use the SDXL UNet as their base model.
-
-```python
-from diffusers import UNet2DConditionModel
-import torch
-
-unet = UNet2DConditionModel.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- torch_dtype=torch.float16,
- use_safetensors=True,
- variant="fp16",
- subfolder="unet",
-).to("cuda")
-```
-
-Load the SDXL pipeline and the LoRA checkpoints, starting with the [ostris/ikea-instructions-lora-sdxl](https://huggingface.co/ostris/ikea-instructions-lora-sdxl) LoRA.
-
-```python
-from diffusers import DiffusionPipeline
-
-pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- variant="fp16",
- torch_dtype=torch.float16,
- unet=unet
-).to("cuda")
-pipeline.load_lora_weights("ostris/ikea-instructions-lora-sdxl", weight_name="ikea_instructions_xl_v1_5.safetensors", adapter_name="ikea")
-```
-
-Now you'll create a [PeftModel](https://huggingface.co/docs/peft/package_reference/peft_model#peft.PeftModel) from the loaded LoRA checkpoint by combining the SDXL UNet and the LoRA UNet from the pipeline.
-
-```python
-from peft import get_peft_model, LoraConfig
-import copy
-
-sdxl_unet = copy.deepcopy(unet)
-ikea_peft_model = get_peft_model(
- sdxl_unet,
- pipeline.unet.peft_config["ikea"],
- adapter_name="ikea"
-)
-
-original_state_dict = {f"base_model.model.{k}": v for k, v in pipeline.unet.state_dict().items()}
-ikea_peft_model.load_state_dict(original_state_dict, strict=True)
-```
-
-> [!TIP]
-> You can optionally push the ikea_peft_model to the Hub by calling `ikea_peft_model.push_to_hub("ikea_peft_model", token=TOKEN)`.
-
-Repeat this process to create a [PeftModel](https://huggingface.co/docs/peft/package_reference/peft_model#peft.PeftModel) from the [lordjia/by-feng-zikai](https://huggingface.co/lordjia/by-feng-zikai) LoRA.
-
-```python
-pipeline.delete_adapters("ikea")
-sdxl_unet.delete_adapters("ikea")
-
-pipeline.load_lora_weights("lordjia/by-feng-zikai", weight_name="fengzikai_v1.0_XL.safetensors", adapter_name="feng")
-pipeline.set_adapters(adapter_names="feng")
-
-feng_peft_model = get_peft_model(
- sdxl_unet,
- pipeline.unet.peft_config["feng"],
- adapter_name="feng"
-)
-
-original_state_dict = {f"base_model.model.{k}": v for k, v in pipe.unet.state_dict().items()}
-feng_peft_model.load_state_dict(original_state_dict, strict=True)
-```
-
-2. Load a base UNet model and then load the adapters onto it.
-
-```python
-from peft import PeftModel
-
-base_unet = UNet2DConditionModel.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- torch_dtype=torch.float16,
- use_safetensors=True,
- variant="fp16",
- subfolder="unet",
-).to("cuda")
-
-model = PeftModel.from_pretrained(base_unet, "stevhliu/ikea_peft_model", use_safetensors=True, subfolder="ikea", adapter_name="ikea")
-model.load_adapter("stevhliu/feng_peft_model", use_safetensors=True, subfolder="feng", adapter_name="feng")
-```
-
-3. Merge the adapters using the [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) method and the merging method of your choice (learn more about other merging methods in this [blog post](https://huggingface.co/blog/peft_merging)). For this example, let's use the `"dare_linear"` method to merge the LoRAs.
-
-> [!WARNING]
-> Keep in mind the LoRAs need to have the same rank to be merged!
-
-```python
-model.add_weighted_adapter(
- adapters=["ikea", "feng"],
- weights=[1.0, 1.0],
- combination_type="dare_linear",
- adapter_name="ikea-feng"
-)
-model.set_adapters("ikea-feng")
-```
-
-Now you can generate an image with the merged LoRA.
-
-```python
-model = model.to(dtype=torch.float16, device="cuda")
-
-pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", unet=model, variant="fp16", torch_dtype=torch.float16,
-).to("cuda")
-
-image = pipeline("A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai", generator=torch.manual_seed(0)).images[0]
-image
-```
-
-
-
-
-
-## fuse_lora
-
-Both the [`~loaders.PeftAdapterMixin.set_adapters`] and [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) methods require loading the base model and the LoRA adapters separately which incurs some overhead. The [`~loaders.lora_base.LoraBaseMixin.fuse_lora`] method allows you to fuse the LoRA weights directly with the original weights of the underlying model. This way, you're only loading the model once which can increase inference and lower memory-usage.
-
-You can use PEFT to easily fuse/unfuse multiple adapters directly into the model weights (both UNet and text encoder) using the [`~loaders.lora_base.LoraBaseMixin.fuse_lora`] method, which can lead to a speed-up in inference and lower VRAM usage.
-
-For example, if you have a base model and adapters loaded and set as active with the following adapter weights:
-
-```py
-from diffusers import DiffusionPipeline
-import torch
-
-pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
-pipeline.load_lora_weights("ostris/ikea-instructions-lora-sdxl", weight_name="ikea_instructions_xl_v1_5.safetensors", adapter_name="ikea")
-pipeline.load_lora_weights("lordjia/by-feng-zikai", weight_name="fengzikai_v1.0_XL.safetensors", adapter_name="feng")
-
-pipeline.set_adapters(["ikea", "feng"], adapter_weights=[0.7, 0.8])
-```
-
-Fuse these LoRAs into the UNet with the [`~loaders.lora_base.LoraBaseMixin.fuse_lora`] method. The `lora_scale` parameter controls how much to scale the output by with the LoRA weights. It is important to make the `lora_scale` adjustments in the [`~loaders.lora_base.LoraBaseMixin.fuse_lora`] method because it won’t work if you try to pass `scale` to the `cross_attention_kwargs` in the pipeline.
-
-```py
-pipeline.fuse_lora(adapter_names=["ikea", "feng"], lora_scale=1.0)
-```
-
-Then you should use [`~loaders.StableDiffusionLoraLoaderMixin.unload_lora_weights`] to unload the LoRA weights since they've already been fused with the underlying base model. Finally, call [`~DiffusionPipeline.save_pretrained`] to save the fused pipeline locally or you could call [`~DiffusionPipeline.push_to_hub`] to push the fused pipeline to the Hub.
-
-```py
-pipeline.unload_lora_weights()
-# save locally
-pipeline.save_pretrained("path/to/fused-pipeline")
-# save to the Hub
-pipeline.push_to_hub("fused-ikea-feng")
-```
-
-Now you can quickly load the fused pipeline and use it for inference without needing to separately load the LoRA adapters.
-
-```py
-pipeline = DiffusionPipeline.from_pretrained(
- "username/fused-ikea-feng", torch_dtype=torch.float16,
-).to("cuda")
-
-image = pipeline("A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai", generator=torch.manual_seed(0)).images[0]
-image
-```
-
-You can call [`~~loaders.lora_base.LoraBaseMixin.unfuse_lora`] to restore the original model's weights (for example, if you want to use a different `lora_scale` value). However, this only works if you've only fused one LoRA adapter to the original model. If you've fused multiple LoRAs, you'll need to reload the model.
-
-```py
-pipeline.unfuse_lora()
-```
-
-### torch.compile
-
-[torch.compile](../optimization/torch2.0#torchcompile) can speed up your pipeline even more, but the LoRA weights must be fused first and then unloaded. Typically, the UNet is compiled because it is such a computationally intensive component of the pipeline.
-
-```py
-from diffusers import DiffusionPipeline
-import torch
-
-# load base model and LoRAs
-pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
-pipeline.load_lora_weights("ostris/ikea-instructions-lora-sdxl", weight_name="ikea_instructions_xl_v1_5.safetensors", adapter_name="ikea")
-pipeline.load_lora_weights("lordjia/by-feng-zikai", weight_name="fengzikai_v1.0_XL.safetensors", adapter_name="feng")
-
-# activate both LoRAs and set adapter weights
-pipeline.set_adapters(["ikea", "feng"], adapter_weights=[0.7, 0.8])
-
-# fuse LoRAs and unload weights
-pipeline.fuse_lora(adapter_names=["ikea", "feng"], lora_scale=1.0)
-pipeline.unload_lora_weights()
-
-# torch.compile
-pipeline.unet.to(memory_format=torch.channels_last)
-pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
-
-image = pipeline("A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai", generator=torch.manual_seed(0)).images[0]
-```
-
-Learn more about torch.compile in the [Accelerate inference of text-to-image diffusion models](../tutorials/fast_diffusion#torchcompile) guide.
-
-## Next steps
-
-For more conceptual details about how each merging method works, take a look at the [🤗 PEFT welcomes new merging methods](https://huggingface.co/blog/peft_merging#concatenation-cat) blog post!
diff --git a/docs/source/en/using-diffusers/omnigen.md b/docs/source/en/using-diffusers/omnigen.md
index 40a9e81bcd52..2880fedb3392 100644
--- a/docs/source/en/using-diffusers/omnigen.md
+++ b/docs/source/en/using-diffusers/omnigen.md
@@ -1,4 +1,4 @@
-
-# Model files and layouts
-
[[open-in-colab]]
-Diffusion models are saved in various file types and organized in different layouts. Diffusers stores model weights as safetensors files in *Diffusers-multifolder* layout and it also supports loading files (like safetensors and ckpt files) from a *single-file* layout which is commonly used in the diffusion ecosystem.
-
-Each layout has its own benefits and use cases, and this guide will show you how to load the different files and layouts, and how to convert them.
+# Model formats
-## Files
-
-PyTorch model weights are typically saved with Python's [pickle](https://docs.python.org/3/library/pickle.html) utility as ckpt or bin files. However, pickle is not secure and pickled files may contain malicious code that can be executed. This vulnerability is a serious concern given the popularity of model sharing. To address this security issue, the [Safetensors](https://hf.co/docs/safetensors) library was developed as a secure alternative to pickle, which saves models as safetensors files.
-
-### safetensors
+Diffusion models are typically stored in the Diffusers format or single-file format. Model files can be stored in various file types such as safetensors, dduf, or ckpt.
> [!TIP]
-> Learn more about the design decisions and why safetensor files are preferred for saving and loading model weights in the [Safetensors audited as really safe and becoming the default](https://blog.eleuther.ai/safetensors-security-audit/) blog post.
+> Format refers to whether the weights are stored in a directory structure and file refers to the file type.
-[Safetensors](https://hf.co/docs/safetensors) is a safe and fast file format for securely storing and loading tensors. Safetensors restricts the header size to limit certain types of attacks, supports lazy loading (useful for distributed setups), and has generally faster loading speeds.
+This guide will show you how to load pipelines and models from these formats and files.
-Make sure you have the [Safetensors](https://hf.co/docs/safetensors) library installed.
+## Diffusers format
-```py
-!pip install safetensors
-```
+The Diffusers format stores each model (UNet, transformer, text encoder) in a separate subfolder. There are several benefits to storing models separately.
-Safetensors stores weights in a safetensors file. Diffusers loads safetensors files by default if they're available and the Safetensors library is installed. There are two ways safetensors files can be organized:
+- Faster overall pipeline initialization because you can load the individual model you need or load them all in parallel.
+- Reduced memory usage because you don't need to load all the pipeline components if you only need one model. [Reuse](./loading#reusing-models-in-multiple-pipelines) a model that is shared between multiple pipelines.
+- Lower storage requirements because common models shared between multiple pipelines are only downloaded once.
+- Flexibility to use new or improved models in a pipeline.
-1. Diffusers-multifolder layout: there may be several separate safetensors files, one for each pipeline component (text encoder, UNet, VAE), organized in subfolders (check out the [stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main) repository as an example)
-2. single-file layout: all the model weights may be saved in a single file (check out the [WarriorMama777/OrangeMixs](https://hf.co/WarriorMama777/OrangeMixs/tree/main/Models/AbyssOrangeMix) repository as an example)
+## Single file format
-
-
+A single-file format stores *all* the model (UNet, transformer, text encoder) weights in a single file. Benefits of single-file formats include the following.
-Use the [`~DiffusionPipeline.from_pretrained`] method to load a model with safetensors files stored in multiple folders.
+- Greater compatibility with [ComfyUI](https://github.com/comfyanonymous/ComfyUI) or [Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
+- Easier to download and share a single file.
+
+Use [`~loaders.FromSingleFileMixin.from_single_file`] to load a single file.
```py
-from diffusers import DiffusionPipeline
+import torch
+from diffusers import StableDiffusionXLPipeline
-pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5",
- use_safetensors=True
+pipeline = StableDiffusionXLPipeline.from_single_file(
+ "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors",
+ torch_dtype=torch.float16,
+ device_map="cuda"
)
```
-
-
-
-Use the [`~loaders.FromSingleFileMixin.from_single_file`] method to load a model with all the weights stored in a single safetensors file.
+The [`~loaders.FromSingleFileMixin.from_single_file`] method also supports passing new models or schedulers.
```py
-from diffusers import StableDiffusionPipeline
+import torch
+from diffusers import FluxPipeline, FluxTransformer2DModel
-pipeline = StableDiffusionPipeline.from_single_file(
- "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
+transformer = FluxTransformer2DModel.from_single_file(
+ "https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors", torch_dtype=torch.bfloat16
+)
+pipeline = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ transformer=transformer,
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
)
```
-
-
+### Configuration options
-#### LoRA files
+Diffusers format models have a `config.json` file in their repositories with important attributes such as the number of layers and attention heads. The [`~loaders.FromSingleFileMixin.from_single_file`] method automatically determines the appropriate config to use from `config.json`. This may fail in a few rare instances though, in which case, you should use the `config` argument.
-[LoRA](https://hf.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a lightweight adapter that is fast and easy to train, making them especially popular for generating images in a certain way or style. These adapters are commonly stored in a safetensors file, and are widely popular on model sharing platforms like [civitai](https://civitai.com/).
-
-LoRAs are loaded into a base model with the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method.
+You should also use the `config` argument if the models in a pipeline are different from the original implementation or if it doesn't have the necessary metadata to determine the correct config.
```py
from diffusers import StableDiffusionXLPipeline
-import torch
-
-# base model
-pipeline = StableDiffusionXLPipeline.from_pretrained(
- "Lykon/dreamshaper-xl-1-0", torch_dtype=torch.float16, variant="fp16"
-).to("cuda")
-
-# download LoRA weights
-!wget https://civitai.com/api/download/models/168776 -O blueprintify.safetensors
-# load LoRA weights
-pipeline.load_lora_weights(".", weight_name="blueprintify.safetensors")
-prompt = "bl3uprint, a highly detailed blueprint of the empire state building, explaining how to build all parts, many txt, blueprint grid backdrop"
-negative_prompt = "lowres, cropped, worst quality, low quality, normal quality, artifacts, signature, watermark, username, blurry, more than one bridge, bad architecture"
+ckpt_path = "https://huggingface.co/segmind/SSD-1B/blob/main/SSD-1B.safetensors"
-image = pipeline(
- prompt=prompt,
- negative_prompt=negative_prompt,
- generator=torch.manual_seed(0),
-).images[0]
-image
+pipeline = StableDiffusionXLPipeline.from_single_file(ckpt_path, config="segmind/SSD-1B")
```
-
-
-
-
-### ckpt
-
-> [!WARNING]
-> Pickled files may be unsafe because they can be exploited to execute malicious code. It is recommended to use safetensors files instead where possible, or convert the weights to safetensors files.
-
-PyTorch's [torch.save](https://pytorch.org/docs/stable/generated/torch.save.html) function uses Python's [pickle](https://docs.python.org/3/library/pickle.html) utility to serialize and save models. These files are saved as a ckpt file and they contain the entire model's weights.
+Diffusers attempts to infer the pipeline components based on the signature types of the pipeline class when using `original_config` with `local_files_only=True`. It won't download the config files from a Hub repository to avoid backward breaking changes when you can't connect to the internet. This method isn't as reliable as providing a path to a local model with the `config` argument and may lead to errors. You should run the pipeline with `local_files_only=False` to download the config files to the local cache to avoid errors.
-Use the [`~loaders.FromSingleFileMixin.from_single_file`] method to directly load a ckpt file.
+Override default configs by passing the arguments directly to [`~loaders.FromSingleFileMixin.from_single_file`]. The examples below demonstrate how to override the configs in a pipeline or model.
```py
-from diffusers import StableDiffusionPipeline
+from diffusers import StableDiffusionXLInstructPix2PixPipeline
-pipeline = StableDiffusionPipeline.from_single_file(
- "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned.ckpt"
+ckpt_path = "https://huggingface.co/stabilityai/cosxl/blob/main/cosxl_edit.safetensors"
+pipeline = StableDiffusionXLInstructPix2PixPipeline.from_single_file(
+ ckpt_path, config="diffusers/sdxl-instructpix2pix-768", is_cosxl_edit=True
)
```
-## Storage layout
-
-There are two ways model files are organized, either in a Diffusers-multifolder layout or in a single-file layout. The Diffusers-multifolder layout is the default, and each component file (text encoder, UNet, VAE) is stored in a separate subfolder. Diffusers also supports loading models from a single-file layout where all the components are bundled together.
+```py
+from diffusers import UNet2DConditionModel
-### Diffusers-multifolder
+ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
+model = UNet2DConditionModel.from_single_file(ckpt_path, upcast_attention=True)
+```
-The Diffusers-multifolder layout is the default storage layout for Diffusers. Each component's (text encoder, UNet, VAE) weights are stored in a separate subfolder. The weights can be stored as safetensors or ckpt files.
+### Local files
-
-
-
-
multifolder layout
-
-
-
-
UNet subfolder
-
-
+The [`~loaders.FromSingleFileMixin.from_single_file`] method attempts to configure a pipeline or model by inferring the model type from the keys in the checkpoint file. For example, any single file checkpoint based on the Stable Diffusion XL base model is configured from [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).
-To load from Diffusers-multifolder layout, use the [`~DiffusionPipeline.from_pretrained`] method.
+If you're working with local files, download the config files with the [`~huggingface_hub.snapshot_download`] method and the model checkpoint with [`~huggingface_hub.hf_hub_download`]. These files are downloaded to your [cache directory](https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache), but you can download them to a specific directory with the `local_dir` argument.
```py
-from diffusers import DiffusionPipeline
-
-pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- torch_dtype=torch.float16,
- variant="fp16",
- use_safetensors=True,
-).to("cuda")
-```
+from huggingface_hub import hf_hub_download, snapshot_download
+from diffusers import StableDiffusionXLPipeline
-Benefits of using the Diffusers-multifolder layout include:
-
-1. Faster to load each component file individually or in parallel.
-2. Reduced memory usage because you only load the components you need. For example, models like [SDXL Turbo](https://hf.co/stabilityai/sdxl-turbo), [SDXL Lightning](https://hf.co/ByteDance/SDXL-Lightning), and [Hyper-SD](https://hf.co/ByteDance/Hyper-SD) have the same components except for the UNet. You can reuse their shared components with the [`~DiffusionPipeline.from_pipe`] method without consuming any additional memory (take a look at the [Reuse a pipeline](./loading#reuse-a-pipeline) guide) and only load the UNet. This way, you don't need to download redundant components and unnecessarily use more memory.
-
- ```py
- import torch
- from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
-
- # download one model
- sdxl_pipeline = StableDiffusionXLPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- torch_dtype=torch.float16,
- variant="fp16",
- use_safetensors=True,
- ).to("cuda")
-
- # switch UNet for another model
- unet = UNet2DConditionModel.from_pretrained(
- "stabilityai/sdxl-turbo",
- subfolder="unet",
- torch_dtype=torch.float16,
- variant="fp16",
- use_safetensors=True
- )
- # reuse all the same components in new model except for the UNet
- turbo_pipeline = StableDiffusionXLPipeline.from_pipe(
- sdxl_pipeline, unet=unet,
- ).to("cuda")
- turbo_pipeline.scheduler = EulerDiscreteScheduler.from_config(
- turbo_pipeline.scheduler.config,
- timestep+spacing="trailing"
- )
- image = turbo_pipeline(
- "an astronaut riding a unicorn on mars",
- num_inference_steps=1,
- guidance_scale=0.0,
- ).images[0]
- image
- ```
-
-3. Reduced storage requirements because if a component, such as the SDXL [VAE](https://hf.co/madebyollin/sdxl-vae-fp16-fix), is shared across multiple models, you only need to download and store a single copy of it instead of downloading and storing it multiple times. For 10 SDXL models, this can save ~3.5GB of storage. The storage savings is even greater for newer models like PixArt Sigma, where the [text encoder](https://hf.co/PixArt-alpha/PixArt-Sigma-XL-2-1024-MS/tree/main/text_encoder) alone is ~19GB!
-4. Flexibility to replace a component in the model with a newer or better version.
-
- ```py
- from diffusers import DiffusionPipeline, AutoencoderKL
-
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- vae=vae,
- torch_dtype=torch.float16,
- variant="fp16",
- use_safetensors=True,
- ).to("cuda")
- ```
-
-5. More visibility and information about a model's components, which are stored in a [config.json](https://hf.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json) file in each component subfolder.
-
-### Single-file
-
-The single-file layout stores all the model weights in a single file. All the model components (text encoder, UNet, VAE) weights are kept together instead of separately in subfolders. This can be a safetensors or ckpt file.
-
-
-
-
-
-To load from a single-file layout, use the [`~loaders.FromSingleFileMixin.from_single_file`] method.
+my_local_checkpoint_path = hf_hub_download(
+ repo_id="segmind/SSD-1B",
+ filename="SSD-1B.safetensors"
+)
-```py
-import torch
-from diffusers import StableDiffusionXLPipeline
+my_local_config_path = snapshot_download(
+ repo_id="segmind/SSD-1B",
+ allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
+)
pipeline = StableDiffusionXLPipeline.from_single_file(
- "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors",
- torch_dtype=torch.float16,
- variant="fp16",
- use_safetensors=True,
-).to("cuda")
+ my_local_checkpoint_path, config=my_local_config_path, local_files_only=True
+)
```
-Benefits of using a single-file layout include:
-
-1. Easy compatibility with diffusion interfaces such as [ComfyUI](https://github.com/comfyanonymous/ComfyUI) or [Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) which commonly use a single-file layout.
-2. Easier to manage (download and share) a single file.
-
-### DDUF
-
-> [!WARNING]
-> DDUF is an experimental file format and APIs related to it can change in the future.
-
-DDUF (**D**DUF **D**iffusion **U**nified **F**ormat) is a file format designed to make storing, distributing, and using diffusion models much easier. Built on the ZIP file format, DDUF offers a standardized, efficient, and flexible way to package all parts of a diffusion model into a single, easy-to-manage file. It provides a balance between Diffusers multi-folder format and the widely popular single-file format.
-
-Learn more details about DDUF on the Hugging Face Hub [documentation](https://huggingface.co/docs/hub/dduf).
-
-Pass a checkpoint to the `dduf_file` parameter to load it in [`DiffusionPipeline`].
-
-```py
-from diffusers import DiffusionPipeline
-import torch
+### Symlink
-pipe = DiffusionPipeline.from_pretrained(
- "DDUF/FLUX.1-dev-DDUF", dduf_file="FLUX.1-dev.dduf", torch_dtype=torch.bfloat16
-).to("cuda")
-image = pipe(
- "photo a cat holding a sign that says Diffusers", num_inference_steps=50, guidance_scale=3.5
-).images[0]
-image.save("cat.png")
-```
-
-To save a pipeline as a `.dduf` checkpoint, use the [`~huggingface_hub.export_folder_as_dduf`] utility, which takes care of all the necessary file-level validations.
+If you're working with a file system that does not support symlinking, download the checkpoint file to a local directory first with the `local_dir` parameter. Using the `local_dir` parameter automatically disables symlinks.
```py
-from huggingface_hub import export_folder_as_dduf
-from diffusers import DiffusionPipeline
-import torch
-
-pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
-
-save_folder = "flux-dev"
-pipe.save_pretrained("flux-dev")
-export_folder_as_dduf("flux-dev.dduf", folder_path=save_folder)
-
-> [!TIP]
-> Packaging and loading quantized checkpoints in the DDUF format is supported as long as they respect the multi-folder structure.
-
-## Convert layout and files
-
-Diffusers provides many scripts and methods to convert storage layouts and file formats to enable broader support across the diffusion ecosystem.
-
-Take a look at the [diffusers/scripts](https://github.com/huggingface/diffusers/tree/main/scripts) collection to find a script that fits your conversion needs.
-
-> [!TIP]
-> Scripts that have "`to_diffusers`" appended at the end mean they convert a model to the Diffusers-multifolder layout. Each script has their own specific set of arguments for configuring the conversion, so make sure you check what arguments are available!
+from huggingface_hub import hf_hub_download, snapshot_download
+from diffusers import StableDiffusionXLPipeline
-For example, to convert a Stable Diffusion XL model stored in Diffusers-multifolder layout to a single-file layout, run the [convert_diffusers_to_original_sdxl.py](https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_sdxl.py) script. Provide the path to the model to convert, and the path to save the converted model to. You can optionally specify whether you want to save the model as a safetensors file and whether to save the model in half-precision.
+my_local_checkpoint_path = hf_hub_download(
+ repo_id="segmind/SSD-1B",
+ filename="SSD-1B.safetensors"
+ local_dir="my_local_checkpoints",
+)
+print("My local checkpoint: ", my_local_checkpoint_path)
-```bash
-python convert_diffusers_to_original_sdxl.py --model_path path/to/model/to/convert --checkpoint_path path/to/save/model/to --use_safetensors
+my_local_config_path = snapshot_download(
+ repo_id="segmind/SSD-1B",
+ allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
+)
+print("My local config: ", my_local_config_path)
```
-You can also save a model to Diffusers-multifolder layout with the [`~DiffusionPipeline.save_pretrained`] method. This creates a directory for you if it doesn't already exist, and it also saves the files as a safetensors file by default.
+Pass these paths to [`~loaders.FromSingleFileMixin.from_single_file`].
```py
-from diffusers import StableDiffusionXLPipeline
-
pipeline = StableDiffusionXLPipeline.from_single_file(
- "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors",
+ my_local_checkpoint_path, config=my_local_config_path, local_files_only=True
)
-pipeline.save_pretrained()
```
-Lastly, there are also Spaces, such as [SD To Diffusers](https://hf.co/spaces/diffusers/sd-to-diffusers) and [SD-XL To Diffusers](https://hf.co/spaces/diffusers/sdxl-to-diffusers), that provide a more user-friendly interface for converting models to Diffusers-multifolder layout. This is the easiest and most convenient option for converting layouts, and it'll open a PR on your model repository with the converted files. However, this option is not as reliable as running a script, and the Space may fail for more complicated models.
+## File types
-## Single-file layout usage
+Models can be stored in several file types. Safetensors is the most common file type but you may encounter other file types on the Hub or diffusion community.
-Now that you're familiar with the differences between the Diffusers-multifolder and single-file layout, this section shows you how to load models and pipeline components, customize configuration options for loading, and load local files with the [`~loaders.FromSingleFileMixin.from_single_file`] method.
+### safetensors
-### Load a pipeline or model
+[Safetensors](https://hf.co/docs/safetensors) is a safe and fast file type for securely storing and loading tensors. It restricts the header size to limit certain types of attacks, supports lazy loading (useful for distributed setups), and generally loads faster.
-Pass the file path of the pipeline or model to the [`~loaders.FromSingleFileMixin.from_single_file`] method to load it.
+Diffusers loads safetensors file by default (a required dependency) if they are available and the Safetensors library is installed.
-
-
+Use [`~DiffusionPipeline.from_pretrained`] or [`~loaders.FromSingleFileMixin.from_single_file`] to load safetensor files.
```py
-from diffusers import StableDiffusionXLPipeline
-
-ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
-pipeline = StableDiffusionXLPipeline.from_single_file(ckpt_path)
-```
-
-
-
+import torch
+from diffusers import DiffusionPipeline
-```py
-from diffusers import StableCascadeUNet
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch.dtype=torch.float16,
+ device_map="cuda"
+)
-ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
-model = StableCascadeUNet.from_single_file(ckpt_path)
+pipeline = DiffusionPipeline.from_single_file(
+ "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors",
+ torch_dtype=torch.float16,
+)
```
-
-
-
-Customize components in the pipeline by passing them directly to the [`~loaders.FromSingleFileMixin.from_single_file`] method. For example, you can use a different scheduler in a pipeline.
-
-```py
-from diffusers import StableDiffusionXLPipeline, DDIMScheduler
-
-ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
-scheduler = DDIMScheduler()
-pipeline = StableDiffusionXLPipeline.from_single_file(ckpt_path, scheduler=scheduler)
-```
+If you're using a checkpoint trained with a Diffusers training script, metadata such as the LoRA configuration, is automatically saved. When the file is loaded, the metadata is parsed to correctly configure the LoRA and avoid missing or incorrect LoRA configs. Inspect the metadata of a safetensors file by clicking on the logo next to the file on the Hub.
-Or you could use a ControlNet model in the pipeline.
+Save the metadata for LoRAs that aren't trained with Diffusers with either `transformer_lora_adapter_metadata` or `unet_lora_adapter_metadata` depending on your model. For the text encoder, use the `text_encoder_lora_adapter_metadata` and `text_encoder_2_lora_adapter_metadata` arguments in [`~loaders.FluxLoraLoaderMixin.save_lora_weights`]. This is only supported for safetensors files.
```py
-from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+import torch
+from diffusers import FluxPipeline
-ckpt_path = "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
-controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
-pipeline = StableDiffusionControlNetPipeline.from_single_file(ckpt_path, controlnet=controlnet)
+pipeline = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+).to("cuda")
+pipeline.load_lora_weights("linoyts/yarn_art_Flux_LoRA")
+pipeline.save_lora_weights(
+ text_encoder_lora_adapter_metadata={"r": 8, "lora_alpha": 8},
+ text_encoder_2_lora_adapter_metadata={"r": 8, "lora_alpha": 8}
+)
```
-### Customize configuration options
-
-Models have a configuration file that define their attributes like the number of inputs in a UNet. Pipelines configuration options are available in the pipeline's class. For example, if you look at the [`StableDiffusionXLInstructPix2PixPipeline`] class, there is an option to scale the image latents with the `is_cosxl_edit` parameter.
-
-These configuration files can be found in the models Hub repository or another location from which the configuration file originated (for example, a GitHub repository or locally on your device).
+### ckpt
-
-
+Older model weights are commonly saved with Python's [pickle](https://docs.python.org/3/library/pickle.html) utility in a ckpt file.
-> [!TIP]
-> The [`~loaders.FromSingleFileMixin.from_single_file`] method automatically maps the checkpoint to the appropriate model repository, but there are cases where it is useful to use the `config` parameter. For example, if the model components in the checkpoint are different from the original checkpoint or if a checkpoint doesn't have the necessary metadata to correctly determine the configuration to use for the pipeline.
+Pickled files may be unsafe because they can be exploited to execute malicious code. It is recommended to use safetensors files or convert the weights to safetensors files.
-The [`~loaders.FromSingleFileMixin.from_single_file`] method automatically determines the configuration to use from the configuration file in the model repository. You could also explicitly specify the configuration to use by providing the repository id to the `config` parameter.
+Use [`~loaders.FromSingleFileMixin.from_single_file`] to load a ckpt file.
```py
-from diffusers import StableDiffusionXLPipeline
-
-ckpt_path = "https://huggingface.co/segmind/SSD-1B/blob/main/SSD-1B.safetensors"
-repo_id = "segmind/SSD-1B"
+from diffusers import DiffusionPipeline
-pipeline = StableDiffusionXLPipeline.from_single_file(ckpt_path, config=repo_id)
+pipeline = DiffusionPipeline.from_single_file(
+ "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned.ckpt"
+)
```
-The model loads the configuration file for the [UNet](https://huggingface.co/segmind/SSD-1B/blob/main/unet/config.json), [VAE](https://huggingface.co/segmind/SSD-1B/blob/main/vae/config.json), and [text encoder](https://huggingface.co/segmind/SSD-1B/blob/main/text_encoder/config.json) from their respective subfolders in the repository.
-
-
-
-
-The [`~loaders.FromSingleFileMixin.from_single_file`] method can also load the original configuration file of a pipeline that is stored elsewhere. Pass a local path or URL of the original configuration file to the `original_config` parameter.
-
-```py
-from diffusers import StableDiffusionXLPipeline
-
-ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
-original_config = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
-
-pipeline = StableDiffusionXLPipeline.from_single_file(ckpt_path, original_config=original_config)
-```
+### dduf
> [!TIP]
-> Diffusers attempts to infer the pipeline components based on the type signatures of the pipeline class when you use `original_config` with `local_files_only=True`, instead of fetching the configuration files from the model repository on the Hub. This prevents backward breaking changes in code that can't connect to the internet to fetch the necessary configuration files.
->
-> This is not as reliable as providing a path to a local model repository with the `config` parameter, and might lead to errors during pipeline configuration. To avoid errors, run the pipeline with `local_files_only=False` once to download the appropriate pipeline configuration files to the local cache.
-
-
-
-
-While the configuration files specify the pipeline or models default parameters, you can override them by providing the parameters directly to the [`~loaders.FromSingleFileMixin.from_single_file`] method. Any parameter supported by the model or pipeline class can be configured in this way.
+> DDUF is an experimental file type and the API may change. Refer to the DDUF [docs](https://huggingface.co/docs/hub/dduf) to learn more.
-
-
+DDUF is a file type designed to unify different diffusion model distribution methods and weight-saving formats. It is a standardized and flexible method to package all components of a diffusion model into a single file, providing a balance between the Diffusers and single-file formats.
-For example, to scale the image latents in [`StableDiffusionXLInstructPix2PixPipeline`] pass the `is_cosxl_edit` parameter.
+Use the `dduf_file` argument in [`~DiffusionPipeline.from_pretrained`] to load a DDUF file. You can also load quantized dduf files as long as they are stored in the Diffusers format.
-```python
-from diffusers import StableDiffusionXLInstructPix2PixPipeline
+```py
+import torch
+from diffusers import DiffusionPipeline
-ckpt_path = "https://huggingface.co/stabilityai/cosxl/blob/main/cosxl_edit.safetensors"
-pipeline = StableDiffusionXLInstructPix2PixPipeline.from_single_file(ckpt_path, config="diffusers/sdxl-instructpix2pix-768", is_cosxl_edit=True)
+pipeline = DiffusionPipeline.from_pretrained(
+ "DDUF/FLUX.1-dev-DDUF",
+ dduf_file="FLUX.1-dev.dduf",
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
+)
```
-
-
+To save a pipeline as a dduf file, use the [`~huggingface_hub.export_folder_as_dduf`] utility.
-For example, to upcast the attention dimensions in a [`UNet2DConditionModel`] pass the `upcast_attention` parameter.
+```py
+import torch
+from diffusers import DiffusionPipeline
+from huggingface_hub import export_folder_as_dduf
-```python
-from diffusers import UNet2DConditionModel
+pipeline = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
-ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
-model = UNet2DConditionModel.from_single_file(ckpt_path, upcast_attention=True)
+save_folder = "flux-dev"
+pipeline.save_pretrained("flux-dev")
+export_folder_as_dduf("flux-dev.dduf", folder_path=save_folder)
```
-
-
-
-### Local files
-
-In Diffusers>=v0.28.0, the [`~loaders.FromSingleFileMixin.from_single_file`] method attempts to configure a pipeline or model by inferring the model type from the keys in the checkpoint file. The inferred model type is used to determine the appropriate model repository on the Hugging Face Hub to configure the model or pipeline.
-
-For example, any single file checkpoint based on the Stable Diffusion XL base model will use the [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) model repository to configure the pipeline.
+## Converting formats and files
-But if you're working in an environment with restricted internet access, you should download the configuration files with the [`~huggingface_hub.snapshot_download`] function, and the model checkpoint with the [`~huggingface_hub.hf_hub_download`] function. By default, these files are downloaded to the Hugging Face Hub [cache directory](https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache), but you can specify a preferred directory to download the files to with the `local_dir` parameter.
+Diffusers provides scripts and methods to convert format and files to enable broader support across the diffusion ecosystem.
-Pass the configuration and checkpoint paths to the [`~loaders.FromSingleFileMixin.from_single_file`] method to load locally.
-
-
-
-
-```python
-from huggingface_hub import hf_hub_download, snapshot_download
-
-my_local_checkpoint_path = hf_hub_download(
- repo_id="segmind/SSD-1B",
- filename="SSD-1B.safetensors"
-)
+Take a look at the [diffusers/scripts](https://github.com/huggingface/diffusers/tree/main/scripts) folder to find a conversion script. Scripts with `"to_diffusers` appended at the end converts a model to the Diffusers format. Each script has a specific set of arguments for configuring the conversion. Make sure you check what arguments are available.
-my_local_config_path = snapshot_download(
- repo_id="segmind/SSD-1B",
- allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
-)
+The example below converts a model stored in Diffusers format to a single-file format. Provide the path to the model to convert and where to save the converted model. You can optionally specify what file type and data type to save the model as.
-pipeline = StableDiffusionXLPipeline.from_single_file(my_local_checkpoint_path, config=my_local_config_path, local_files_only=True)
+```bash
+python convert_diffusers_to_original_sdxl.py --model_path path/to/model/to/convert --checkpoint_path path/to/save/model/to --use_safetensors
```
-
-
+The [`~DiffusionPipeline.save_pretrained`] method also saves a model in Diffusers format and takes care of creating subfolders for each model. It saves the files as safetensor files by default.
-```python
-from huggingface_hub import hf_hub_download, snapshot_download
-
-my_local_checkpoint_path = hf_hub_download(
- repo_id="segmind/SSD-1B",
- filename="SSD-1B.safetensors"
- local_dir="my_local_checkpoints"
-)
+```py
+from diffusers import DiffusionPipeline
-my_local_config_path = snapshot_download(
- repo_id="segmind/SSD-1B",
- allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
- local_dir="my_local_config"
+pipeline = DiffusionPipeline.from_single_file(
+ "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors",
)
-
-pipeline = StableDiffusionXLPipeline.from_single_file(my_local_checkpoint_path, config=my_local_config_path, local_files_only=True)
+pipeline.save_pretrained()
```
-
-
+Finally, you can use a Space like [SD To Diffusers](https://hf.co/spaces/diffusers/sd-to-diffusers) or [SD-XL To Diffusers](https://hf.co/spaces/diffusers/sdxl-to-diffusers) to convert models to the Diffusers format. It'll open a PR on your model repository with the converted files. This is the easiest way to convert a model, but it may fail for more complicated models. Using a conversion script is more reliable.
-#### Local files without symlink
+## Resources
-> [!TIP]
-> In huggingface_hub>=v0.23.0, the `local_dir_use_symlinks` argument isn't necessary for the [`~huggingface_hub.hf_hub_download`] and [`~huggingface_hub.snapshot_download`] functions.
-
-The [`~loaders.FromSingleFileMixin.from_single_file`] method relies on the [huggingface_hub](https://hf.co/docs/huggingface_hub/index) caching mechanism to fetch and store checkpoints and configuration files for models and pipelines. If you're working with a file system that does not support symlinking, you should download the checkpoint file to a local directory first, and disable symlinking with the `local_dir_use_symlink=False` parameter in the [`~huggingface_hub.hf_hub_download`] function and [`~huggingface_hub.snapshot_download`] functions.
-
-```python
-from huggingface_hub import hf_hub_download, snapshot_download
+- Learn more about the design decisions and why safetensor files are preferred for saving and loading model weights in the [Safetensors audited as really safe and becoming the default](https://blog.eleuther.ai/safetensors-security-audit/) blog post.
-my_local_checkpoint_path = hf_hub_download(
- repo_id="segmind/SSD-1B",
- filename="SSD-1B.safetensors"
- local_dir="my_local_checkpoints",
- local_dir_use_symlinks=False
-)
-print("My local checkpoint: ", my_local_checkpoint_path)
-
-my_local_config_path = snapshot_download(
- repo_id="segmind/SSD-1B",
- allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
- local_dir_use_symlinks=False,
-)
-print("My local config: ", my_local_config_path)
-```
-
-Then you can pass the local paths to the `pretrained_model_link_or_path` and `config` parameters.
-
-```python
-pipeline = StableDiffusionXLPipeline.from_single_file(my_local_checkpoint_path, config=my_local_config_path, local_files_only=True)
-```
diff --git a/docs/source/en/using-diffusers/overview_techniques.md b/docs/source/en/using-diffusers/overview_techniques.md
deleted file mode 100644
index d5ecf77b0fbd..000000000000
--- a/docs/source/en/using-diffusers/overview_techniques.md
+++ /dev/null
@@ -1,18 +0,0 @@
-
-
-# Overview
-
-The inference pipeline supports and enables a wide range of techniques that are divided into two categories:
-
-* Pipeline functionality: these techniques modify the pipeline or extend it for other applications. For example, pipeline callbacks add new features to a pipeline and a pipeline can also be extended for distributed inference.
-* Improve inference quality: these techniques increase the visual quality of the generated images. For example, you can enhance your prompts with GPT2 to create better images with lower effort.
diff --git a/docs/source/en/using-diffusers/pag.md b/docs/source/en/using-diffusers/pag.md
index 26961d959c49..c11a5dc379c8 100644
--- a/docs/source/en/using-diffusers/pag.md
+++ b/docs/source/en/using-diffusers/pag.md
@@ -1,4 +1,4 @@
-
-# Push files to the Hub
-
[[open-in-colab]]
-🤗 Diffusers provides a [`~diffusers.utils.PushToHubMixin`] for uploading your model, scheduler, or pipeline to the Hub. It is an easy way to store your files on the Hub, and also allows you to share your work with others. Under the hood, the [`~diffusers.utils.PushToHubMixin`]:
+# Sharing pipelines and models
+
+Share your pipeline or models and schedulers on the Hub with the [`~diffusers.utils.PushToHubMixin`] class. This class:
1. creates a repository on the Hub
2. saves your model, scheduler, or pipeline files so they can be reloaded later
3. uploads folder containing these files to the Hub
-This guide will show you how to use the [`~diffusers.utils.PushToHubMixin`] to upload your files to the Hub.
+This guide will show you how to upload your files to the Hub with the [`~diffusers.utils.PushToHubMixin`] class.
+
+Log in to your Hugging Face account with your access [token](https://huggingface.co/settings/tokens).
-You'll need to log in to your Hub account with your access [token](https://huggingface.co/settings/tokens) first:
+
+
```py
from huggingface_hub import notebook_login
@@ -30,9 +33,19 @@ from huggingface_hub import notebook_login
notebook_login()
```
+
+
+
+```bash
+hf auth login
+```
+
+
+
+
## Models
-To push a model to the Hub, call [`~diffusers.utils.PushToHubMixin.push_to_hub`] and specify the repository id of the model to be stored on the Hub:
+To push a model to the Hub, call [`~diffusers.utils.PushToHubMixin.push_to_hub`] and specify the repository id of the model.
```py
from diffusers import ControlNetModel
@@ -48,15 +61,9 @@ controlnet = ControlNetModel(
controlnet.push_to_hub("my-controlnet-model")
```
-For models, you can also specify the [*variant*](loading#checkpoint-variants) of the weights to push to the Hub. For example, to push `fp16` weights:
-
-```py
-controlnet.push_to_hub("my-controlnet-model", variant="fp16")
-```
-
-The [`~diffusers.utils.PushToHubMixin.push_to_hub`] function saves the model's `config.json` file and the weights are automatically saved in the `safetensors` format.
+The [`~diffusers.utils.PushToHubMixin.push_to_hub`] method saves the model's `config.json` file and the weights are automatically saved as safetensors files.
-Now you can reload the model from your repository on the Hub:
+Load the model again with [`~DiffusionPipeline.from_pretrained`].
```py
model = ControlNetModel.from_pretrained("your-namespace/my-controlnet-model")
@@ -64,7 +71,7 @@ model = ControlNetModel.from_pretrained("your-namespace/my-controlnet-model")
## Scheduler
-To push a scheduler to the Hub, call [`~diffusers.utils.PushToHubMixin.push_to_hub`] and specify the repository id of the scheduler to be stored on the Hub:
+To push a scheduler to the Hub, call [`~diffusers.utils.PushToHubMixin.push_to_hub`] and specify the repository id of the scheduler.
```py
from diffusers import DDIMScheduler
@@ -81,7 +88,7 @@ scheduler.push_to_hub("my-controlnet-scheduler")
The [`~diffusers.utils.PushToHubMixin.push_to_hub`] function saves the scheduler's `scheduler_config.json` file to the specified repository.
-Now you can reload the scheduler from your repository on the Hub:
+Load the scheduler again with [`~SchedulerMixin.from_pretrained`].
```py
scheduler = DDIMScheduler.from_pretrained("your-namepsace/my-controlnet-scheduler")
@@ -89,7 +96,7 @@ scheduler = DDIMScheduler.from_pretrained("your-namepsace/my-controlnet-schedule
## Pipeline
-You can also push an entire pipeline with all it's components to the Hub. For example, initialize the components of a [`StableDiffusionPipeline`] with the parameters you want:
+To push a pipeline to the Hub, initialize the pipeline components with your desired parameters.
```py
from diffusers import (
@@ -143,7 +150,7 @@ text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
```
-Pass all of the components to the [`StableDiffusionPipeline`] and call [`~diffusers.utils.PushToHubMixin.push_to_hub`] to push the pipeline to the Hub:
+Pass all components to the pipeline and call [`~diffusers.utils.PushToHubMixin.push_to_hub`].
```py
components = {
@@ -160,7 +167,7 @@ pipeline = StableDiffusionPipeline(**components)
pipeline.push_to_hub("my-pipeline")
```
-The [`~diffusers.utils.PushToHubMixin.push_to_hub`] function saves each component to a subfolder in the repository. Now you can reload the pipeline from your repository on the Hub:
+The [`~diffusers.utils.PushToHubMixin.push_to_hub`] method saves each component to a subfolder in the repository. Load the pipeline again with [`~DiffusionPipeline.from_pretrained`].
```py
pipeline = StableDiffusionPipeline.from_pretrained("your-namespace/my-pipeline")
@@ -168,10 +175,10 @@ pipeline = StableDiffusionPipeline.from_pretrained("your-namespace/my-pipeline")
## Privacy
-Set `private=True` in the [`~diffusers.utils.PushToHubMixin.push_to_hub`] function to keep your model, scheduler, or pipeline files private:
+Set `private=True` in [`~diffusers.utils.PushToHubMixin.push_to_hub`] to keep a model, scheduler, or pipeline files private.
```py
controlnet.push_to_hub("my-controlnet-model-private", private=True)
```
-Private repositories are only visible to you, and other users won't be able to clone the repository and your repository won't appear in search results. Even if a user has the URL to your private repository, they'll receive a `404 - Sorry, we can't find the page you are looking for`. You must be [logged in](https://huggingface.co/docs/huggingface_hub/quick-start#login) to load a model from a private repository.
\ No newline at end of file
+Private repositories are only visible to you. Other users won't be able to clone the repository and it won't appear in search results. Even if a user has the URL to your private repository, they'll receive a `404 - Sorry, we can't find the page you are looking for`. You must be [logged in](https://huggingface.co/docs/huggingface_hub/quick-start#login) to load a model from a private repository.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/reusing_seeds.md b/docs/source/en/using-diffusers/reusing_seeds.md
index 1ad2a6b5c657..b4aed0aa6354 100644
--- a/docs/source/en/using-diffusers/reusing_seeds.md
+++ b/docs/source/en/using-diffusers/reusing_seeds.md
@@ -1,4 +1,4 @@
-
-# Reproducible pipelines
+# Reproducibility
-Diffusion models are inherently random which is what allows it to generate different outputs every time it is run. But there are certain times when you want to generate the same output every time, like when you're testing, replicating results, and even [improving image quality](#deterministic-batch-generation). While you can't expect to get identical results across platforms, you can expect reproducible results across releases and platforms within a certain tolerance range (though even this may vary).
+Diffusion is a random process that generates a different output every time. For certain situations like testing and replicating results, you want to generate the same result each time, across releases and platforms within a certain tolerance range.
-This guide will show you how to control randomness for deterministic generation on a CPU and GPU.
+This guide will show you how to control sources of randomness and enable deterministic algorithms.
-> [!TIP]
-> We strongly recommend reading PyTorch's [statement about reproducibility](https://pytorch.org/docs/stable/notes/randomness.html):
->
-> "Completely reproducible results are not guaranteed across PyTorch releases, individual commits, or different platforms. Furthermore, results may not be reproducible between CPU and GPU executions, even when using identical seeds."
-
-## Control randomness
-
-During inference, pipelines rely heavily on random sampling operations which include creating the
-Gaussian noise tensors to denoise and adding noise to the scheduling step.
-
-Take a look at the tensor values in the [`DDIMPipeline`] after two inference steps.
-
-```python
-from diffusers import DDIMPipeline
-import numpy as np
+## Generator
-ddim = DDIMPipeline.from_pretrained( "google/ddpm-cifar10-32", use_safetensors=True)
-image = ddim(num_inference_steps=2, output_type="np").images
-print(np.abs(image).sum())
-```
-
-Running the code above prints one value, but if you run it again you get a different value.
-
-Each time the pipeline is run, [torch.randn](https://pytorch.org/docs/stable/generated/torch.randn.html) uses a different random seed to create the Gaussian noise tensors. This leads to a different result each time it is run and enables the diffusion pipeline to generate a different random image each time.
-
-But if you need to reliably generate the same image, that depends on whether you're running the pipeline on a CPU or GPU.
+Pipelines rely on [torch.randn](https://pytorch.org/docs/stable/generated/torch.randn.html), which uses a different random seed each time, to create the initial noisy tensors. To generate the same output on a CPU or GPU, use a [Generator](https://docs.pytorch.org/docs/stable/generated/torch.Generator.html) to manage how random values are generated.
> [!TIP]
-> It might seem unintuitive to pass `Generator` objects to a pipeline instead of the integer value representing the seed. However, this is the recommended design when working with probabilistic models in PyTorch because a `Generator` is a *random state* that can be passed to multiple pipelines in a sequence. As soon as the `Generator` is consumed, the *state* is changed in place which means even if you passed the same `Generator` to a different pipeline, it won't produce the same result because the state is already changed.
+> If reproducibility is important to your use case, we recommend always using a CPU `Generator`. The performance loss is often negligible and you'll generate more similar values.
-
-
+
+
-To generate reproducible results on a CPU, you'll need to use a PyTorch [Generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) and set a seed. Now when you run the code, it always prints a value of `1491.1711` because the `Generator` object with the seed is passed to all the random functions in the pipeline. You should get a similar, if not the same, result on whatever hardware and PyTorch version you're using.
+The GPU uses a different random number generator than the CPU. Diffusers solves this issue with the [`~utils.torch_utils.randn_tensor`] function to create the random tensor on a CPU and then moving it to the GPU. This function is used everywhere inside the pipeline and you don't need to explicitly call it.
-```python
+Use [manual_seed](https://docs.pytorch.org/docs/stable/generated/torch.manual_seed.html) as shown below to set a seed.
+
+```py
import torch
import numpy as np
from diffusers import DDIMPipeline
-ddim = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32", use_safetensors=True)
-generator = torch.Generator(device="cpu").manual_seed(0)
+ddim = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32", device_map="cuda")
+generator = torch.manual_seed(0)
image = ddim(num_inference_steps=2, output_type="np", generator=generator).images
print(np.abs(image).sum())
```
-
-
-Writing a reproducible pipeline on a GPU is a bit trickier, and full reproducibility across different hardware is not guaranteed because matrix multiplication - which diffusion pipelines require a lot of - is less deterministic on a GPU than a CPU. For example, if you run the same code example from the CPU example, you'll get a different result even though the seed is identical. This is because the GPU uses a different random number generator than the CPU.
-
-```python
-import torch
-import numpy as np
-from diffusers import DDIMPipeline
-
-ddim = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32", use_safetensors=True)
-ddim.to("cuda")
-generator = torch.Generator(device="cuda").manual_seed(0)
-image = ddim(num_inference_steps=2, output_type="np", generator=generator).images
-print(np.abs(image).sum())
-```
+
-To avoid this issue, Diffusers has a [`~utils.torch_utils.randn_tensor`] function for creating random noise on the CPU, and then moving the tensor to a GPU if necessary. The [`~utils.torch_utils.randn_tensor`] function is used everywhere inside the pipeline. Now you can call [torch.manual_seed](https://pytorch.org/docs/stable/generated/torch.manual_seed.html) which automatically creates a CPU `Generator` that can be passed to the pipeline even if it is being run on a GPU.
+Set `device="cpu"` in the `Generator` and use [manual_seed](https://docs.pytorch.org/docs/stable/generated/torch.manual_seed.html) to set a seed for generating random numbers.
-```python
+```py
import torch
import numpy as np
from diffusers import DDIMPipeline
-ddim = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32", use_safetensors=True)
-ddim.to("cuda")
-generator = torch.manual_seed(0)
+ddim = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32")
+generator = torch.Generator(device="cpu").manual_seed(0)
image = ddim(num_inference_steps=2, output_type="np", generator=generator).images
print(np.abs(image).sum())
```
-> [!TIP]
-> If reproducibility is important to your use case, we recommend always passing a CPU `Generator`. The performance loss is often negligible and you'll generate more similar values than if the pipeline had been run on a GPU.
-
-Finally, more complex pipelines such as [`UnCLIPPipeline`], are often extremely
-susceptible to precision error propagation. You'll need to use
-exactly the same hardware and PyTorch version for full reproducibility.
-
-## Deterministic algorithms
-
-You can also configure PyTorch to use deterministic algorithms to create a reproducible pipeline. The downside is that deterministic algorithms may be slower than non-deterministic ones and you may observe a decrease in performance.
-
-Non-deterministic behavior occurs when operations are launched in more than one CUDA stream. To avoid this, set the environment variable [CUBLAS_WORKSPACE_CONFIG](https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility) to `:16:8` to only use one buffer size during runtime.
-
-PyTorch typically benchmarks multiple algorithms to select the fastest one, but if you want reproducibility, you should disable this feature because the benchmark may select different algorithms each time. Set Diffusers [enable_full_determinism](https://github.com/huggingface/diffusers/blob/142f353e1c638ff1d20bd798402b68f72c1ebbdd/src/diffusers/utils/testing_utils.py#L861) to enable deterministic algorithms.
-
-```py
-enable_full_determinism()
-```
-
-Now when you run the same pipeline twice, you'll get identical results.
+The `Generator` object should be passed to the pipeline instead of an integer seed. `Generator` maintains a *random state* that is consumed and modified when used. Once consumed, the same `Generator` object produces different results in subsequent calls, even across different pipelines, because it's *state* has changed.
```py
-import torch
-from diffusers import DDIMScheduler, StableDiffusionPipeline
-
-pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", use_safetensors=True).to("cuda")
-pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
-g = torch.Generator(device="cuda")
-
-prompt = "A bear is playing a guitar on Times Square"
-
-g.manual_seed(0)
-result1 = pipe(prompt=prompt, num_inference_steps=50, generator=g, output_type="latent").images
-
-g.manual_seed(0)
-result2 = pipe(prompt=prompt, num_inference_steps=50, generator=g, output_type="latent").images
+generator = torch.manual_seed(0)
-print("L_inf dist =", abs(result1 - result2).max())
-"L_inf dist = tensor(0., device='cuda:0')"
+for _ in range(5):
+- image = pipeline(prompt, generator=generator)
++ image = pipeline(prompt, generator=torch.manual_seed(0))
```
-## Deterministic batch generation
+## Deterministic algorithms
-A practical application of creating reproducible pipelines is *deterministic batch generation*. You generate a batch of images and select one image to improve with a more detailed prompt. The main idea is to pass a list of [Generator's](https://pytorch.org/docs/stable/generated/torch.Generator.html) to the pipeline and tie each `Generator` to a seed so you can reuse it.
+PyTorch supports [deterministic algorithms](https://docs.pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms) - where available - for certain operations so they produce the same results. Deterministic algorithms may be slower and decrease performance.
-Let's use the [stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) checkpoint and generate a batch of images.
+Use Diffusers' [enable_full_determinism](https://github.com/huggingface/diffusers/blob/142f353e1c638ff1d20bd798402b68f72c1ebbdd/src/diffusers/utils/testing_utils.py#L861) function to enable deterministic algorithms.
```py
import torch
-from diffusers import DiffusionPipeline
-from diffusers.utils import make_image_grid
+from diffusers_utils import enable_full_determinism
-pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
-)
-pipeline = pipeline.to("cuda")
+enable_full_determinism()
```
-Define four different `Generator`s and assign each `Generator` a seed (`0` to `3`). Then generate a batch of images and pick one to iterate on.
+Under the hood, `enable_full_determinism` works by:
-> [!WARNING]
-> Use a list comprehension that iterates over the batch size specified in `range()` to create a unique `Generator` object for each image in the batch. If you multiply the `Generator` by the batch size integer, it only creates *one* `Generator` object that is used sequentially for each image in the batch.
->
-> ```py
-> [torch.Generator().manual_seed(seed)] * 4
-> ```
+- Setting the environment variable [CUBLAS_WORKSPACE_CONFIG](https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility) to `:16:8` to only use one buffer size during rntime. Non-deterministic behavior occurs when operations are used in more than one CUDA stream.
+- Disabling benchmarking to find the fastest convolution operation by setting `torch.backends.cudnn.benchmark=False`. Non-deterministic behavior occurs because the benchmark may select different algorithms each time depending on hardware or benchmarking noise.
+- Disabling TensorFloat32 (TF32) operations in favor of more precise and consistent full-precision operations.
-```python
-generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(4)]
-prompt = "Labrador in the style of Vermeer"
-images = pipeline(prompt, generator=generator, num_images_per_prompt=4).images[0]
-make_image_grid(images, rows=2, cols=2)
-```
-
-
-
-
-Let's improve the first image (you can choose any image you want) which corresponds to the `Generator` with seed `0`. Add some additional text to your prompt and then make sure you reuse the same `Generator` with seed `0`. All the generated images should resemble the first image.
-
-```python
-prompt = [prompt + t for t in [", highly realistic", ", artsy", ", trending", ", colorful"]]
-generator = [torch.Generator(device="cuda").manual_seed(0) for i in range(4)]
-images = pipeline(prompt, generator=generator).images
-make_image_grid(images, rows=2, cols=2)
-```
+## Resources
-
-
-
+We strongly recommend reading PyTorch's developer notes about [Reproducibility](https://docs.pytorch.org/docs/stable/notes/randomness.html). You can try to limit randomness, but it is not *guaranteed* even with an identical seed.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/scheduler_features.md b/docs/source/en/using-diffusers/scheduler_features.md
deleted file mode 100644
index 88be51a5c06e..000000000000
--- a/docs/source/en/using-diffusers/scheduler_features.md
+++ /dev/null
@@ -1,235 +0,0 @@
-
-
-# Scheduler features
-
-The scheduler is an important component of any diffusion model because it controls the entire denoising (or sampling) process. There are many types of schedulers, some are optimized for speed and some for quality. With Diffusers, you can modify the scheduler configuration to use custom noise schedules, sigmas, and rescale the noise schedule. Changing these parameters can have profound effects on inference quality and speed.
-
-This guide will demonstrate how to use these features to improve inference quality.
-
-> [!TIP]
-> Diffusers currently only supports the `timesteps` and `sigmas` parameters for a select list of schedulers and pipelines. Feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you want to extend these parameters to a scheduler and pipeline that does not currently support it!
-
-## Timestep schedules
-
-The timestep or noise schedule determines the amount of noise at each sampling step. The scheduler uses this to generate an image with the corresponding amount of noise at each step. The timestep schedule is generated from the scheduler's default configuration, but you can customize the scheduler to use new and optimized sampling schedules that aren't in Diffusers yet.
-
-For example, [Align Your Steps (AYS)](https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/) is a method for optimizing a sampling schedule to generate a high-quality image in as little as 10 steps. The optimal [10-step schedule](https://github.com/huggingface/diffusers/blob/a7bf77fc284810483f1e60afe34d1d27ad91ce2e/src/diffusers/schedulers/scheduling_utils.py#L51) for Stable Diffusion XL is:
-
-```py
-from diffusers.schedulers import AysSchedules
-
-sampling_schedule = AysSchedules["StableDiffusionXLTimesteps"]
-print(sampling_schedule)
-"[999, 845, 730, 587, 443, 310, 193, 116, 53, 13]"
-```
-
-You can use the AYS sampling schedule in a pipeline by passing it to the `timesteps` parameter.
-
-```py
-pipeline = StableDiffusionXLPipeline.from_pretrained(
- "SG161222/RealVisXL_V4.0",
- torch_dtype=torch.float16,
- variant="fp16",
-).to("cuda")
-pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, algorithm_type="sde-dpmsolver++")
-
-prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
-generator = torch.Generator(device="cpu").manual_seed(2487854446)
-image = pipeline(
- prompt=prompt,
- negative_prompt="",
- generator=generator,
- timesteps=sampling_schedule,
-).images[0]
-```
-
-
-
-
-
AYS timestep schedule 10 steps
-
-
-
-
Linearly-spaced timestep schedule 10 steps
-
-
-
-
Linearly-spaced timestep schedule 25 steps
-
-
-
-## Timestep spacing
-
-The way sample steps are selected in the schedule can affect the quality of the generated image, especially with respect to [rescaling the noise schedule](#rescale-noise-schedule), which can enable a model to generate much brighter or darker images. Diffusers provides three timestep spacing methods:
-
-- `leading` creates evenly spaced steps
-- `linspace` includes the first and last steps and evenly selects the remaining intermediate steps
-- `trailing` only includes the last step and evenly selects the remaining intermediate steps starting from the end
-
-It is recommended to use the `trailing` spacing method because it generates higher quality images with more details when there are fewer sample steps. But the difference in quality is not as obvious for more standard sample step values.
-
-```py
-import torch
-from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
-
-pipeline = StableDiffusionXLPipeline.from_pretrained(
- "SG161222/RealVisXL_V4.0",
- torch_dtype=torch.float16,
- variant="fp16",
-).to("cuda")
-pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, timestep_spacing="trailing")
-
-prompt = "A cinematic shot of a cute little black cat sitting on a pumpkin at night"
-generator = torch.Generator(device="cpu").manual_seed(2487854446)
-image = pipeline(
- prompt=prompt,
- negative_prompt="",
- generator=generator,
- num_inference_steps=5,
-).images[0]
-image
-```
-
-
-
-
-
trailing spacing after 5 steps
-
-
-
-
leading spacing after 5 steps
-
-
-
-## Sigmas
-
-The `sigmas` parameter is the amount of noise added at each timestep according to the timestep schedule. Like the `timesteps` parameter, you can customize the `sigmas` parameter to control how much noise is added at each step. When you use a custom `sigmas` value, the `timesteps` are calculated from the custom `sigmas` value and the default scheduler configuration is ignored.
-
-For example, you can manually pass the [sigmas](https://github.com/huggingface/diffusers/blob/6529ee67ec02fcf58d2fd9242164ea002b351d75/src/diffusers/schedulers/scheduling_utils.py#L55) for something like the 10-step AYS schedule from before to the pipeline.
-
-```py
-import torch
-
-from diffusers import DiffusionPipeline, EulerDiscreteScheduler
-
-model_id = "stabilityai/stable-diffusion-xl-base-1.0"
-pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- torch_dtype=torch.float16,
- variant="fp16",
-).to("cuda")
-pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
-
-sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0]
-prompt = "anthropomorphic capybara wearing a suit and working with a computer"
-generator = torch.Generator(device='cuda').manual_seed(123)
-image = pipeline(
- prompt=prompt,
- num_inference_steps=10,
- sigmas=sigmas,
- generator=generator
-).images[0]
-```
-
-When you take a look at the scheduler's `timesteps` parameter, you'll see that it is the same as the AYS timestep schedule because the `timestep` schedule is calculated from the `sigmas`.
-
-```py
-print(f" timesteps: {pipe.scheduler.timesteps}")
-"timesteps: tensor([999., 845., 730., 587., 443., 310., 193., 116., 53., 13.], device='cuda:0')"
-```
-
-### Karras sigmas
-
-> [!TIP]
-> Refer to the scheduler API [overview](../api/schedulers/overview) for a list of schedulers that support Karras sigmas.
->
-> Karras sigmas should not be used for models that weren't trained with them. For example, the base Stable Diffusion XL model shouldn't use Karras sigmas but the [DreamShaperXL](https://hf.co/Lykon/dreamshaper-xl-1-0) model can since they are trained with Karras sigmas.
-
-Karras scheduler's use the timestep schedule and sigmas from the [Elucidating the Design Space of Diffusion-Based Generative Models](https://hf.co/papers/2206.00364) paper. This scheduler variant applies a smaller amount of noise per step as it approaches the end of the sampling process compared to other schedulers, and can increase the level of details in the generated image.
-
-Enable Karras sigmas by setting `use_karras_sigmas=True` in the scheduler.
-
-```py
-import torch
-from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
-
-pipeline = StableDiffusionXLPipeline.from_pretrained(
- "SG161222/RealVisXL_V4.0",
- torch_dtype=torch.float16,
- variant="fp16",
-).to("cuda")
-pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)
-
-prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
-generator = torch.Generator(device="cpu").manual_seed(2487854446)
-image = pipeline(
- prompt=prompt,
- negative_prompt="",
- generator=generator,
-).images[0]
-```
-
-
-
-
-
Karras sigmas enabled
-
-
-
-
Karras sigmas disabled
-
-
-
-## Rescale noise schedule
-
-In the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://hf.co/papers/2305.08891) paper, the authors discovered that common noise schedules allowed some signal to leak into the last timestep. This signal leakage at inference can cause models to only generate images with medium brightness. By enforcing a zero signal-to-noise ratio (SNR) for the timstep schedule and sampling from the last timestep, the model can be improved to generate very bright or dark images.
-
-> [!TIP]
-> For inference, you need a model that has been trained with *v_prediction*. To train your own model with *v_prediction*, add the following flag to the [train_text_to_image.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) or [train_text_to_image_lora.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py) scripts.
->
-> ```bash
-> --prediction_type="v_prediction"
-> ```
-
-For example, load the [ptx0/pseudo-journey-v2](https://hf.co/ptx0/pseudo-journey-v2) checkpoint which was trained with `v_prediction` and the [`DDIMScheduler`]. Configure the following parameters in the [`DDIMScheduler`]:
-
-* `rescale_betas_zero_snr=True` to rescale the noise schedule to zero SNR
-* `timestep_spacing="trailing"` to start sampling from the last timestep
-
-Set `guidance_rescale` in the pipeline to prevent over-exposure. A lower value increases brightness but some of the details may appear washed out.
-
-```py
-from diffusers import DiffusionPipeline, DDIMScheduler
-
-pipeline = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", use_safetensors=True)
-
-pipeline.scheduler = DDIMScheduler.from_config(
- pipeline.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
-)
-pipeline.to("cuda")
-prompt = "cinematic photo of a snowy mountain at night with the northern lights aurora borealis overhead, 35mm photograph, film, professional, 4k, highly detailed"
-generator = torch.Generator(device="cpu").manual_seed(23)
-image = pipeline(prompt, guidance_rescale=0.7, generator=generator).images[0]
-image
-```
-
-
-
-
-
default Stable Diffusion v2-1 image
-
-
-
-
image with zero SNR and trailing timestep spacing enabled
-
-
diff --git a/docs/source/en/using-diffusers/schedulers.md b/docs/source/en/using-diffusers/schedulers.md
index 6972c6b6a1d9..0e236e4e3e1d 100644
--- a/docs/source/en/using-diffusers/schedulers.md
+++ b/docs/source/en/using-diffusers/schedulers.md
@@ -1,4 +1,4 @@
-
-# Load schedulers and models
-
[[open-in-colab]]
-Diffusion pipelines are a collection of interchangeable schedulers and models that can be mixed and matched to tailor a pipeline to a specific use case. The scheduler encapsulates the entire denoising process such as the number of denoising steps and the algorithm for finding the denoised sample. A scheduler is not parameterized or trained so they don't take very much memory. The model is usually only concerned with the forward pass of going from a noisy input to a less noisy sample.
+# Schedulers
+
+A scheduler is an algorithm that provides instructions to the denoising process such as how much noise to remove at a certain step. It takes the model prediction from step *t* and applies an update for how to compute the next sample at step *t-1*. Different schedulers produce different results; some are faster while others are more accurate.
+
+Diffusers supports many schedulers and allows you to modify their timestep schedules, timestep spacing, and more, to generate high-quality images in fewer steps.
-This guide will show you how to load schedulers and models to customize a pipeline. You'll use the [stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5) checkpoint throughout this guide, so let's load it first.
+This guide will show you how to load and customize schedulers.
+
+## Loading schedulers
+
+Schedulers don't have any parameters and are defined in a configuration file. Access the `.scheduler` attribute of a pipeline to view the configuration.
```py
import torch
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
-).to("cuda")
-```
-
-You can see what scheduler this pipeline uses with the `pipeline.scheduler` attribute.
-
-```py
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, device_map="cuda"
+)
pipeline.scheduler
-PNDMScheduler {
- "_class_name": "PNDMScheduler",
- "_diffusers_version": "0.21.4",
- "beta_end": 0.012,
- "beta_schedule": "scaled_linear",
- "beta_start": 0.00085,
- "clip_sample": false,
- "num_train_timesteps": 1000,
- "set_alpha_to_one": false,
- "skip_prk_steps": true,
- "steps_offset": 1,
- "timestep_spacing": "leading",
- "trained_betas": null
-}
```
-## Load a scheduler
-
-Schedulers are defined by a configuration file that can be used by a variety of schedulers. Load a scheduler with the [`SchedulerMixin.from_pretrained`] method, and specify the `subfolder` parameter to load the configuration file into the correct subfolder of the pipeline repository.
-
-For example, to load the [`DDIMScheduler`]:
+Load a different scheduler with [`~SchedulerMixin.from_pretrained`] and specify the `subfolder` argument to load the configuration file into the correct subfolder of the pipeline repository. Pass the new scheduler to the existing pipeline.
```py
-from diffusers import DDIMScheduler, DiffusionPipeline
+from diffusers import DPMSolverMultistepScheduler
-ddim = DDIMScheduler.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="scheduler")
+dpm = DPMSolverMultistepScheduler.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler"
+)
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ scheduler=dpm,
+ torch_dtype=torch.float16,
+ device_map="cuda"
+)
+pipeline.scheduler
```
-Then you can pass the newly loaded scheduler to the pipeline.
+## Timestep schedules
-```python
-pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", scheduler=ddim, torch_dtype=torch.float16, use_safetensors=True
-).to("cuda")
-```
+Timestep or noise schedule decides how noise is distributed over the denoising process. The schedule can be linear or more concentrated toward the beginning or end. It is a precomputed sequence of noise levels generated from the scheduler's default configuration, but it can be customized to use other schedules.
-## Compare schedulers
+> [!TIP]
+> The `timesteps` argument is only supported for a select list of schedulers and pipelines. Feel free to open a feature request if you want to extend these parameters to a scheduler and pipeline that does not currently support it!
-Schedulers have their own unique strengths and weaknesses, making it difficult to quantitatively compare which scheduler works best for a pipeline. You typically have to make a trade-off between denoising speed and denoising quality. We recommend trying out different schedulers to find one that works best for your use case. Call the `pipeline.scheduler.compatibles` attribute to see what schedulers are compatible with a pipeline.
+The example below uses the [Align Your Steps (AYS)](https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/) schedule which can generate a high-quality image in 10 steps, significantly speeding up generation and reducing computation time.
-Let's compare the [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], and the [`DPMSolverMultistepScheduler`] on the following prompt and seed.
+Import the schedule and pass it to the `timesteps` argument in the pipeline.
```py
import torch
-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+from diffusers.schedulers import AysSchedules
+
+sampling_schedule = AysSchedules["StableDiffusionXLTimesteps"]
+print(sampling_schedule)
+"[999, 845, 730, 587, 443, 310, 193, 116, 53, 13]"
pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
-).to("cuda")
+ "SG161222/RealVisXL_V4.0",
+ torch_dtype=torch.float16,
+ device_map="cuda"
+)
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+ pipeline.scheduler.config, algorithm_type="sde-dpmsolver++"
+)
-prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
-generator = torch.Generator(device="cuda").manual_seed(8)
+prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
+image = pipeline(
+ prompt=prompt,
+ negative_prompt="",
+ timesteps=sampling_schedule,
+).images[0]
```
-To change the pipelines scheduler, use the [`~ConfigMixin.from_config`] method to load a different scheduler's `pipeline.scheduler.config` into the pipeline.
+
+
+
+
AYS timestep schedule 10 steps
+
+
+
+
Linearly-spaced timestep schedule 10 steps
+
+
+
+
Linearly-spaced timestep schedule 25 steps
+
+
+
+### Rescaling schedules
+
+Denoising should begin with pure noise and the signal-to-noise (SNR) ration should be zero. However, some models don't actually start from pure noise which makes it difficult to generate images at brightness extremes.
-
-
+> [!TIP]
+> Train your own model with `v_prediction` by adding the `--prediction_type="v_prediction"` flag to your training script. You can also [search](https://huggingface.co/search/full-text?q=v_prediction&type=model) for existing models trained with `v_prediction`.
-[`LMSDiscreteScheduler`] typically generates higher quality images than the default scheduler.
+To fix this, a model must be trained with `v_prediction`. If a model is trained with `v_prediction`, then enable the following arguments in the scheduler.
+
+- Set `rescale_betas_zero_snr=True` to rescale the noise schedule to the very last timestep with exactly zero SNR
+- Set `timestep_spacing="trailing"` to force sampling from the last timestep with pure noise
```py
-from diffusers import LMSDiscreteScheduler
+from diffusers import DiffusionPipeline, DDIMScheduler
-pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
-image = pipeline(prompt, generator=generator).images[0]
-image
-```
+pipeline = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", device_map="cuda")
-
-
+pipeline.scheduler = DDIMScheduler.from_config(
+ pipeline.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
+)
+```
-[`EulerDiscreteScheduler`] can generate higher quality images in just 30 steps.
+Set `guidance_rescale` in the pipeline to avoid overexposed images. A lower value increases brightness, but some details may appear washed out.
```py
-from diffusers import EulerDiscreteScheduler
-
-pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
-image = pipeline(prompt, generator=generator).images[0]
-image
+prompt = """
+cinematic photo of a snowy mountain at night with the northern lights aurora borealis
+overhead, 35mm photograph, film, professional, 4k, highly detailed
+"""
+image = pipeline(prompt, guidance_rescale=0.7).images[0]
```
-
-
+
+
+
+
default Stable Diffusion v2-1 image
+
+
+
+
image with zero SNR and trailing timestep spacing enabled
+
+
-[`EulerAncestralDiscreteScheduler`] can generate higher quality images in just 30 steps.
+## Timestep spacing
-```py
-from diffusers import EulerAncestralDiscreteScheduler
+Timestep spacing refers to the specific steps *t* to sample from from the schedule. Diffusers provides three spacing types as shown below.
-pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
-image = pipeline(prompt, generator=generator).images[0]
-image
-```
+| spacing strategy | spacing calculation | example timesteps |
+|---|---|---|
+| `leading` | evenly spaced steps | `[900, 800, 700, ..., 100, 0]` |
+| `linspace` | include first and last steps and evenly divide remaining intermediate steps | `[1000, 888.89, 777.78, ..., 111.11, 0]` |
+| `trailing` | include last step and evenly divide remaining intermediate steps beginning from the end | `[999, 899, 799, 699, 599, 499, 399, 299, 199, 99]` |
-
-
+Pass the spacing strategy to the `timestep_spacing` argument in the scheduler.
-[`DPMSolverMultistepScheduler`] provides a balance between speed and quality and can generate higher quality images in just 20 steps.
+> [!TIP]
+> The `trailing` strategy typically produces higher quality images with more details with fewer steps, but the difference in quality is not as obvious for more standard step values.
```py
-from diffusers import DPMSolverMultistepScheduler
+import torch
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "SG161222/RealVisXL_V4.0",
+ torch_dtype=torch.float16,
+ device_map="cuda"
+)
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+ pipeline.scheduler.config, timestep_spacing="trailing"
+)
-pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
-image = pipeline(prompt, generator=generator).images[0]
+prompt = "A cinematic shot of a cute little black cat sitting on a pumpkin at night"
+image = pipeline(
+ prompt=prompt,
+ negative_prompt="",
+ num_inference_steps=5,
+).images[0]
image
```
-
-
-
-
-
LMSDiscreteScheduler
+
+
trailing spacing after 5 steps
-
-
EulerDiscreteScheduler
-
-
-
-
-
-
EulerAncestralDiscreteScheduler
-
-
-
-
DPMSolverMultistepScheduler
+
+
leading spacing after 5 steps
-Most images look very similar and are comparable in quality. Again, it often comes down to your specific use case so a good approach is to run multiple different schedulers and compare the results.
+## Sigmas
-### Flax schedulers
+Sigmas is a measure of how noisy a sample is at a certain step as defined by the schedule. When using custom `sigmas`, the `timesteps` are calculated from these values instead of the default scheduler configuration.
-To compare Flax schedulers, you need to additionally load the scheduler state into the model parameters. For example, let's change the default scheduler in [`FlaxStableDiffusionPipeline`] to use the super fast [`FlaxDPMSolverMultistepScheduler`].
+> [!TIP]
+> The `sigmas` argument is only supported for a select list of schedulers and pipelines. Feel free to open a feature request if you want to extend these parameters to a scheduler and pipeline that does not currently support it!
-> [!WARNING]
-> The [`FlaxLMSDiscreteScheduler`] and [`FlaxDDPMScheduler`] are not compatible with the [`FlaxStableDiffusionPipeline`] yet.
+Pass the custom sigmas to the `sigmas` argument in the pipeline. The example below uses the [sigmas](https://github.com/huggingface/diffusers/blob/6529ee67ec02fcf58d2fd9242164ea002b351d75/src/diffusers/schedulers/scheduling_utils.py#L55) from the 10-step AYS schedule.
```py
-import jax
-import numpy as np
-from flax.jax_utils import replicate
-from flax.training.common_utils import shard
-from diffusers import FlaxStableDiffusionPipeline, FlaxDPMSolverMultistepScheduler
-
-scheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5",
- subfolder="scheduler"
+import torch
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "SG161222/RealVisXL_V4.0",
+ torch_dtype=torch.float16,
+ device_map="cuda"
)
-pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5",
- scheduler=scheduler,
- variant="bf16",
- dtype=jax.numpy.bfloat16,
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+ pipeline.scheduler.config, algorithm_type="sde-dpmsolver++"
)
-params["scheduler"] = scheduler_state
-```
-Then you can take advantage of Flax's compatibility with TPUs to generate a number of images in parallel. You'll need to make a copy of the model parameters for each available device and then split the inputs across them to generate your desired number of images.
-
-```py
-# Generate 1 image per parallel device (8 on TPUv2-8 or TPUv3-8)
-prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
-num_samples = jax.device_count()
-prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)
-
-prng_seed = jax.random.PRNGKey(0)
-num_inference_steps = 25
-
-# shard inputs and rng
-params = replicate(params)
-prng_seed = jax.random.split(prng_seed, jax.device_count())
-prompt_ids = shard(prompt_ids)
-
-images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
-images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
+sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0]
+prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
+image = pipeline(
+ prompt=prompt,
+ negative_prompt="",
+ sigmas=sigmas,
+).images[0]
```
-## Models
+### Karras sigmas
-Models are loaded from the [`ModelMixin.from_pretrained`] method, which downloads and caches the latest version of the model weights and configurations. If the latest files are available in the local cache, [`~ModelMixin.from_pretrained`] reuses files in the cache instead of re-downloading them.
+[Karras sigmas](https://huggingface.co/papers/2206.00364) resamples the noise schedule for more efficient sampling by clustering sigmas more densely in the middle of the sequence where structure reconstruction is critical, while using fewer sigmas at the beginning and end where noise changes have less impact. This can increase the level of details in a generated image.
-Models can be loaded from a subfolder with the `subfolder` argument. For example, the model weights for [stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5) are stored in the [unet](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main/unet) subfolder.
+Set `use_karras_sigmas=True` in the scheduler to enable it.
-```python
-from diffusers import UNet2DConditionModel
+```py
+import torch
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "SG161222/RealVisXL_V4.0",
+ torch_dtype=torch.float16,
+ device_map="cuda"
+)
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+ pipeline.scheduler.config,
+ algorithm_type="sde-dpmsolver++",
+ use_karras_sigmas=True,
+)
-unet = UNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", use_safetensors=True)
+prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
+image = pipeline(
+ prompt=prompt,
+ negative_prompt="",
+ sigmas=sigmas,
+).images[0]
```
-They can also be directly loaded from a [repository](https://huggingface.co/google/ddpm-cifar10-32/tree/main).
+
+
+
+
Karras sigmas enabled
+
+
+
+
Karras sigmas disabled
+
+
-```python
-from diffusers import UNet2DModel
+Refer to the scheduler API [overview](../api/schedulers/overview) for a list of schedulers that support Karras sigmas. It should only be used for models trained with Karras sigmas.
-unet = UNet2DModel.from_pretrained("google/ddpm-cifar10-32", use_safetensors=True)
-```
+## Choosing a scheduler
-To load and save model variants, specify the `variant` argument in [`ModelMixin.from_pretrained`] and [`ModelMixin.save_pretrained`].
+It's important to try different schedulers to find the best one for your use case. Here are a few recommendations to help you get started.
-```python
-from diffusers import UNet2DConditionModel
+- DPM++ 2M SDE Karras is generally a good all-purpose option.
+- [`TCDScheduler`] works well for distilled models.
+- [`FlowMatchEulerDiscreteScheduler`] and [`FlowMatchHeunDiscreteScheduler`] for FlowMatch models.
+- [`EulerDiscreteScheduler`] or [`EulerAncestralDiscreteScheduler`] for generating anime style images.
+- DPM++ 2M paired with [`LCMScheduler`] on SDXL for generating realistic images.
-unet = UNet2DConditionModel.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", variant="non_ema", use_safetensors=True
-)
-unet.save_pretrained("./local-unet", variant="non_ema")
-```
+## Resources
+
+- Read the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) paper for more details about rescaling the noise schedule to enforce zero SNR.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/sdxl.md b/docs/source/en/using-diffusers/sdxl.md
index 9938d561052b..275394a03ca9 100644
--- a/docs/source/en/using-diffusers/sdxl.md
+++ b/docs/source/en/using-diffusers/sdxl.md
@@ -1,4 +1,4 @@
-
-
-# JAX/Flax
-
-[[open-in-colab]]
-
-🤗 Diffusers supports Flax for super fast inference on Google TPUs, such as those available in Colab, Kaggle or Google Cloud Platform. This guide shows you how to run inference with Stable Diffusion using JAX/Flax.
-
-Before you begin, make sure you have the necessary libraries installed:
-
-```py
-# uncomment to install the necessary libraries in Colab
-#!pip install -q jax==0.3.25 jaxlib==0.3.25 flax transformers ftfy
-#!pip install -q diffusers
-```
-
-You should also make sure you're using a TPU backend. While JAX does not run exclusively on TPUs, you'll get the best performance on a TPU because each server has 8 TPU accelerators working in parallel.
-
-If you are running this guide in Colab, select *Runtime* in the menu above, select the option *Change runtime type*, and then select *TPU* under the *Hardware accelerator* setting. Import JAX and quickly check whether you're using a TPU:
-
-```python
-import jax
-import jax.tools.colab_tpu
-jax.tools.colab_tpu.setup_tpu()
-
-num_devices = jax.device_count()
-device_type = jax.devices()[0].device_kind
-
-print(f"Found {num_devices} JAX devices of type {device_type}.")
-assert (
- "TPU" in device_type,
- "Available device is not a TPU, please select TPU from Runtime > Change runtime type > Hardware accelerator"
-)
-# Found 8 JAX devices of type Cloud TPU.
-```
-
-Great, now you can import the rest of the dependencies you'll need:
-
-```python
-import jax.numpy as jnp
-from jax import pmap
-from flax.jax_utils import replicate
-from flax.training.common_utils import shard
-
-from diffusers import FlaxStableDiffusionPipeline
-```
-
-## Load a model
-
-Flax is a functional framework, so models are stateless and parameters are stored outside of them. Loading a pretrained Flax pipeline returns *both* the pipeline and the model weights (or parameters). In this guide, you'll use `bfloat16`, a more efficient half-float type that is supported by TPUs (you can also use `float32` for full precision if you want).
-
-```python
-dtype = jnp.bfloat16
-pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4",
- variant="bf16",
- dtype=dtype,
-)
-```
-
-## Inference
-
-TPUs usually have 8 devices working in parallel, so let's use the same prompt for each device. This means you can perform inference on 8 devices at once, with each device generating one image. As a result, you'll get 8 images in the same amount of time it takes for one chip to generate a single image!
-
-
-
-Learn more details in the [How does parallelization work?](#how-does-parallelization-work) section.
-
-
-
-After replicating the prompt, get the tokenized text ids by calling the `prepare_inputs` function on the pipeline. The length of the tokenized text is set to 77 tokens as required by the configuration of the underlying CLIP text model.
-
-```python
-prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
-prompt = [prompt] * jax.device_count()
-prompt_ids = pipeline.prepare_inputs(prompt)
-prompt_ids.shape
-# (8, 77)
-```
-
-Model parameters and inputs have to be replicated across the 8 parallel devices. The parameters dictionary is replicated with [`flax.jax_utils.replicate`](https://flax.readthedocs.io/en/latest/api_reference/flax.jax_utils.html#flax.jax_utils.replicate) which traverses the dictionary and changes the shape of the weights so they are repeated 8 times. Arrays are replicated using `shard`.
-
-```python
-# parameters
-p_params = replicate(params)
-
-# arrays
-prompt_ids = shard(prompt_ids)
-prompt_ids.shape
-# (8, 1, 77)
-```
-
-This shape means each one of the 8 devices receives as an input a `jnp` array with shape `(1, 77)`, where `1` is the batch size per device. On TPUs with sufficient memory, you could have a batch size larger than `1` if you want to generate multiple images (per chip) at once.
-
-Next, create a random number generator to pass to the generation function. This is standard procedure in Flax, which is very serious and opinionated about random numbers. All functions that deal with random numbers are expected to receive a generator to ensure reproducibility, even when you're training across multiple distributed devices.
-
-The helper function below uses a seed to initialize a random number generator. As long as you use the same seed, you'll get the exact same results. Feel free to use different seeds when exploring results later in the guide.
-
-```python
-def create_key(seed=0):
- return jax.random.PRNGKey(seed)
-```
-
-The helper function, or `rng`, is split 8 times so each device receives a different generator and generates a different image.
-
-```python
-rng = create_key(0)
-rng = jax.random.split(rng, jax.device_count())
-```
-
-To take advantage of JAX's optimized speed on a TPU, pass `jit=True` to the pipeline to compile the JAX code into an efficient representation and to ensure the model runs in parallel across the 8 devices.
-
-
-
-You need to ensure all your inputs have the same shape in subsequent calls, otherwise JAX will need to recompile the code which is slower.
-
-
-
-The first inference run takes more time because it needs to compile the code, but subsequent calls (even with different inputs) are much faster. For example, it took more than a minute to compile on a TPU v2-8, but then it takes about **7s** on a future inference run!
-
-```py
-%%time
-images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
-
-# CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s
-# Wall time: 1min 29s
-```
-
-The returned array has shape `(8, 1, 512, 512, 3)` which should be reshaped to remove the second dimension and get 8 images of `512 × 512 × 3`. Then you can use the [`~utils.numpy_to_pil`] function to convert the arrays into images.
-
-```python
-from diffusers.utils import make_image_grid
-
-images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
-images = pipeline.numpy_to_pil(images)
-make_image_grid(images, rows=2, cols=4)
-```
-
-
-
-## Using different prompts
-
-You don't necessarily have to use the same prompt on all devices. For example, to generate 8 different prompts:
-
-```python
-prompts = [
- "Labrador in the style of Hokusai",
- "Painting of a squirrel skating in New York",
- "HAL-9000 in the style of Van Gogh",
- "Times Square under water, with fish and a dolphin swimming around",
- "Ancient Roman fresco showing a man working on his laptop",
- "Close-up photograph of young black woman against urban background, high quality, bokeh",
- "Armchair in the shape of an avocado",
- "Clown astronaut in space, with Earth in the background",
-]
-
-prompt_ids = pipeline.prepare_inputs(prompts)
-prompt_ids = shard(prompt_ids)
-
-images = pipeline(prompt_ids, p_params, rng, jit=True).images
-images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
-images = pipeline.numpy_to_pil(images)
-
-make_image_grid(images, 2, 4)
-```
-
-
-
-## How does parallelization work?
-
-The Flax pipeline in 🤗 Diffusers automatically compiles the model and runs it in parallel on all available devices. Let's take a closer look at how that process works.
-
-JAX parallelization can be done in multiple ways. The easiest one revolves around using the [`jax.pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) function to achieve single-program multiple-data (SPMD) parallelization. It means running several copies of the same code, each on different data inputs. More sophisticated approaches are possible, and you can go over to the JAX [documentation](https://jax.readthedocs.io/en/latest/index.html) to explore this topic in more detail if you are interested!
-
-`jax.pmap` does two things:
-
-1. Compiles (or "`jit`s") the code which is similar to `jax.jit()`. This does not happen when you call `pmap`, and only the first time the `pmap`ped function is called.
-2. Ensures the compiled code runs in parallel on all available devices.
-
-To demonstrate, call `pmap` on the pipeline's `_generate` method (this is a private method that generates images and may be renamed or removed in future releases of 🤗 Diffusers):
-
-```python
-p_generate = pmap(pipeline._generate)
-```
-
-After calling `pmap`, the prepared function `p_generate` will:
-
-1. Make a copy of the underlying function, `pipeline._generate`, on each device.
-2. Send each device a different portion of the input arguments (this is why it's necessary to call the *shard* function). In this case, `prompt_ids` has shape `(8, 1, 77, 768)` so the array is split into 8 and each copy of `_generate` receives an input with shape `(1, 77, 768)`.
-
-The most important thing to pay attention to here is the batch size (1 in this example), and the input dimensions that make sense for your code. You don't have to change anything else to make the code work in parallel.
-
-The first time you call the pipeline takes more time, but the calls afterward are much faster. The `block_until_ready` function is used to correctly measure inference time because JAX uses asynchronous dispatch and returns control to the Python loop as soon as it can. You don't need to use that in your code; blocking occurs automatically when you want to use the result of a computation that has not yet been materialized.
-
-```py
-%%time
-images = p_generate(prompt_ids, p_params, rng)
-images = images.block_until_ready()
-
-# CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s
-# Wall time: 1min 15s
-```
-
-Check your image dimensions to see if they're correct:
-
-```python
-images.shape
-# (8, 1, 512, 512, 3)
-```
-
-## Resources
-
-To learn more about how JAX works with Stable Diffusion, you may be interested in reading:
-
-* [Accelerating Stable Diffusion XL Inference with JAX on Cloud TPU v5e](https://hf.co/blog/sdxl_jax)
diff --git a/docs/source/en/using-diffusers/svd.md b/docs/source/en/using-diffusers/svd.md
index 7852d81fa209..bd6d5c332c13 100644
--- a/docs/source/en/using-diffusers/svd.md
+++ b/docs/source/en/using-diffusers/svd.md
@@ -1,4 +1,4 @@
-
-# Textual inversion
+# Textual Inversion
-[[open-in-colab]]
+[Textual Inversion](https://huggingface.co/papers/2208.01618) is a method for generating personalized images of a concept. It works by fine-tuning a models word embeddings on 3-5 images of the concept (for example, pixel art) that is associated with a unique token (``). This allows you to use the `` token in your prompt to trigger the model to generate pixel art images.
-The [`StableDiffusionPipeline`] supports textual inversion, a technique that enables a model like Stable Diffusion to learn a new concept from just a few sample images. This gives you more control over the generated images and allows you to tailor the model towards specific concepts. You can get started quickly with a collection of community created concepts in the [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer).
-
-This guide will show you how to run inference with textual inversion using a pre-learned concept from the Stable Diffusion Conceptualizer. If you're interested in teaching a model new concepts with textual inversion, take a look at the [Textual Inversion](../training/text_inversion) training guide.
-
-Import the necessary libraries:
+Textual Inversion weights are very lightweight and typically only a few KBs because they're only word embeddings. However, this also means the word embeddings need to be loaded after loading a model with [`~DiffusionPipeline.from_pretrained`].
```py
import torch
-from diffusers import StableDiffusionPipeline
-from diffusers.utils import make_image_grid
-```
-
-## Stable Diffusion 1 and 2
-
-Pick a Stable Diffusion checkpoint and a pre-learned concept from the [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer):
-
-```py
-pretrained_model_name_or_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
-repo_id_embeds = "sd-concepts-library/cat-toy"
-```
-
-Now you can load a pipeline, and pass the pre-learned concept to it:
+from diffusers import AutoPipelineForText2Image
-```py
-pipeline = StableDiffusionPipeline.from_pretrained(
- pretrained_model_name_or_path, torch_dtype=torch.float16, use_safetensors=True
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ torch_dtype=torch.float16
).to("cuda")
-
-pipeline.load_textual_inversion(repo_id_embeds)
```
-Create a prompt with the pre-learned concept by using the special placeholder token ``, and choose the number of samples and rows of images you'd like to generate:
+Load the word embeddings with [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] and include the unique token in the prompt to activate its generation.
```py
-prompt = "a grafitti in a favela wall with a on it"
-
-num_samples_per_row = 2
-num_rows = 2
-```
-
-Then run the pipeline (feel free to adjust the parameters like `num_inference_steps` and `guidance_scale` to see how they affect image quality), save the generated images and visualize them with the helper function you created at the beginning:
-
-```py
-all_images = []
-for _ in range(num_rows):
- images = pipeline(prompt, num_images_per_prompt=num_samples_per_row, num_inference_steps=50, guidance_scale=7.5).images
- all_images.extend(images)
-
-grid = make_image_grid(all_images, num_rows, num_samples_per_row)
-grid
+pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")
+prompt = "A cute brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration, style"
+pipeline(prompt).images[0]
```
-
+
-## Stable Diffusion XL
+Textual Inversion can also be trained to learn *negative embeddings* to steer generation away from unwanted characteristics such as "blurry" or "ugly". It is useful for improving image quality.
-Stable Diffusion XL (SDXL) can also use textual inversion vectors for inference. In contrast to Stable Diffusion 1 and 2, SDXL has two text encoders so you'll need two textual inversion embeddings - one for each text encoder model.
-
-Let's download the SDXL textual inversion embeddings and have a closer look at it's structure:
+EasyNegative is a widely used negative embedding that contains multiple learned negative concepts. Load the negative embeddings and specify the file name and token associated with the negative embeddings. Pass the token to `negative_prompt` in your pipeline to activate it.
```py
-from huggingface_hub import hf_hub_download
-from safetensors.torch import load_file
-
-file = hf_hub_download("dn118/unaestheticXL", filename="unaestheticXLv31.safetensors")
-state_dict = load_file(file)
-state_dict
-```
-
-```
-{'clip_g': tensor([[ 0.0077, -0.0112, 0.0065, ..., 0.0195, 0.0159, 0.0275],
- ...,
- [-0.0170, 0.0213, 0.0143, ..., -0.0302, -0.0240, -0.0362]],
- 'clip_l': tensor([[ 0.0023, 0.0192, 0.0213, ..., -0.0385, 0.0048, -0.0011],
- ...,
- [ 0.0475, -0.0508, -0.0145, ..., 0.0070, -0.0089, -0.0163]],
-```
-
-There are two tensors, `"clip_g"` and `"clip_l"`.
-`"clip_g"` corresponds to the bigger text encoder in SDXL and refers to
-`pipe.text_encoder_2` and `"clip_l"` refers to `pipe.text_encoder`.
-
-Now you can load each tensor separately by passing them along with the correct text encoder and tokenizer
-to [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`]:
-
-```py
-from diffusers import AutoPipelineForText2Image
import torch
+from diffusers import AutoPipelineForText2Image
-pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", variant="fp16", torch_dtype=torch.float16)
-pipe.to("cuda")
-
-pipe.load_textual_inversion(state_dict["clip_g"], token="unaestheticXLv31", text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
-pipe.load_textual_inversion(state_dict["clip_l"], token="unaestheticXLv31", text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
-
-# the embedding should be used as a negative embedding, so we pass it as a negative prompt
-generator = torch.Generator().manual_seed(33)
-image = pipe("a woman standing in front of a mountain", negative_prompt="unaestheticXLv31", generator=generator).images[0]
-image
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ torch_dtype=torch.float16
+).to("cuda")
+pipeline.load_textual_inversion(
+ "EvilEngine/easynegative",
+ weight_name="easynegative.safetensors",
+ token="easynegative"
+)
+prompt = "A cute brown bear eating a slice of pizza, stunning color scheme, masterpiece, illustration"
+negative_prompt = "easynegative"
+pipeline(prompt, negative_prompt).images[0]
```
+
+
+
+
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/unconditional_image_generation.md b/docs/source/en/using-diffusers/unconditional_image_generation.md
index 8767eab292c0..0add5bab6707 100644
--- a/docs/source/en/using-diffusers/unconditional_image_generation.md
+++ b/docs/source/en/using-diffusers/unconditional_image_generation.md
@@ -1,4 +1,4 @@
-
-# Prompt techniques
-
[[open-in-colab]]
-Prompts are important because they describe what you want a diffusion model to generate. The best prompts are detailed, specific, and well-structured to help the model realize your vision. But crafting a great prompt takes time and effort and sometimes it may not be enough because language and words can be imprecise. This is where you need to boost your prompt with other techniques, such as prompt enhancing and prompt weighting, to get the results you want.
-
-This guide will show you how you can use these prompt techniques to generate high-quality images with lower effort and adjust the weight of certain keywords in a prompt.
-
-## Prompt engineering
-
-> [!TIP]
-> This is not an exhaustive guide on prompt engineering, but it will help you understand the necessary parts of a good prompt. We encourage you to continue experimenting with different prompts and combine them in new ways to see what works best. As you write more prompts, you'll develop an intuition for what works and what doesn't!
-
-New diffusion models do a pretty good job of generating high-quality images from a basic prompt, but it is still important to create a well-written prompt to get the best results. Here are a few tips for writing a good prompt:
-
-1. What is the image *medium*? Is it a photo, a painting, a 3D illustration, or something else?
-2. What is the image *subject*? Is it a person, animal, object, or scene?
-3. What *details* would you like to see in the image? This is where you can get really creative and have a lot of fun experimenting with different words to bring your image to life. For example, what is the lighting like? What is the vibe and aesthetic? What kind of art or illustration style are you looking for? The more specific and precise words you use, the better the model will understand what you want to generate.
-
-
-
-
-
"A photo of a banana-shaped couch in a living room"
-
-
-
-
"A vibrant yellow banana-shaped couch sits in a cozy living room, its curve cradling a pile of colorful cushions. on the wooden floor, a patterned rug adds a touch of eclectic charm, and a potted plant sits in the corner, reaching towards the sunlight filtering through the windows"
-
-
-
-## Prompt enhancing with GPT2
+# Prompting
-Prompt enhancing is a technique for quickly improving prompt quality without spending too much effort constructing one. It uses a model like GPT2 pretrained on Stable Diffusion text prompts to automatically enrich a prompt with additional important keywords to generate high-quality images.
+Prompts describes what a model should generate. Good prompts are detailed, specific, and structured and they generate better images and videos.
-The technique works by curating a list of specific keywords and forcing the model to generate those words to enhance the original prompt. This way, your prompt can be "a cat" and GPT2 can enhance the prompt to "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain quality sharp focus beautiful detailed intricate stunning amazing epic".
+This guide shows you how to write effective prompts and introduces techniques that make them stronger.
-> [!TIP]
-> You should also use a [*offset noise*](https://www.crosslabs.org//blog/diffusion-with-offset-noise) LoRA to improve the contrast in bright and dark images and create better lighting overall. This [LoRA](https://hf.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_offset_example-lora_1.0.safetensors) is available from [stabilityai/stable-diffusion-xl-base-1.0](https://hf.co/stabilityai/stable-diffusion-xl-base-1.0).
+## Writing good prompts
-Start by defining certain styles and a list of words (you can check out a more comprehensive list of [words](https://hf.co/LykosAI/GPT-Prompt-Expansion-Fooocus-v2/blob/main/positive.txt) and [styles](https://github.com/lllyasviel/Fooocus/tree/main/sdxl_styles) used by Fooocus) to enhance a prompt with.
+Every effective prompt needs three core elements.
-```py
-import torch
-from transformers import GenerationConfig, GPT2LMHeadModel, GPT2Tokenizer, LogitsProcessor, LogitsProcessorList
-from diffusers import StableDiffusionXLPipeline
-
-styles = {
- "cinematic": "cinematic film still of {prompt}, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
- "anime": "anime artwork of {prompt}, anime style, key visual, vibrant, studio anime, highly detailed",
- "photographic": "cinematic photo of {prompt}, 35mm photograph, film, professional, 4k, highly detailed",
- "comic": "comic of {prompt}, graphic illustration, comic art, graphic novel art, vibrant, highly detailed",
- "lineart": "line art drawing {prompt}, professional, sleek, modern, minimalist, graphic, line art, vector graphics",
- "pixelart": " pixel-art {prompt}, low-res, blocky, pixel art style, 8-bit graphics",
-}
-
-words = [
- "aesthetic", "astonishing", "beautiful", "breathtaking", "composition", "contrasted", "epic", "moody", "enhanced",
- "exceptional", "fascinating", "flawless", "glamorous", "glorious", "illumination", "impressive", "improved",
- "inspirational", "magnificent", "majestic", "hyperrealistic", "smooth", "sharp", "focus", "stunning", "detailed",
- "intricate", "dramatic", "high", "quality", "perfect", "light", "ultra", "highly", "radiant", "satisfying",
- "soothing", "sophisticated", "stylish", "sublime", "terrific", "touching", "timeless", "wonderful", "unbelievable",
- "elegant", "awesome", "amazing", "dynamic", "trendy",
-]
-```
+1. Subject - what you want to generate. Start your prompt here.
+2. Style - the medium or aesthetic. How should it look?
+3. Context - details about actions, setting, and mood.
-You may have noticed in the `words` list, there are certain words that can be paired together to create something more meaningful. For example, the words "high" and "quality" can be combined to create "high quality". Let's pair these words together and remove the words that can't be paired.
-
-```py
-word_pairs = ["highly detailed", "high quality", "enhanced quality", "perfect composition", "dynamic light"]
-
-def find_and_order_pairs(s, pairs):
- words = s.split()
- found_pairs = []
- for pair in pairs:
- pair_words = pair.split()
- if pair_words[0] in words and pair_words[1] in words:
- found_pairs.append(pair)
- words.remove(pair_words[0])
- words.remove(pair_words[1])
-
- for word in words[:]:
- for pair in pairs:
- if word in pair.split():
- words.remove(word)
- break
- ordered_pairs = ", ".join(found_pairs)
- remaining_s = ", ".join(words)
- return ordered_pairs, remaining_s
-```
+Use these elements as a structured narrative, not a keyword list. Modern models understand language better than keyword matching. Start simple, then add details.
-Next, implement a custom [`~transformers.LogitsProcessor`] class that assigns tokens in the `words` list a value of 0 and assigns tokens not in the `words` list a negative value so they aren't picked during generation. This way, generation is biased towards words in the `words` list. After a word from the list is used, it is also assigned a negative value so it isn't picked again.
-
-```py
-class CustomLogitsProcessor(LogitsProcessor):
- def __init__(self, bias):
- super().__init__()
- self.bias = bias
-
- def __call__(self, input_ids, scores):
- if len(input_ids.shape) == 2:
- last_token_id = input_ids[0, -1]
- self.bias[last_token_id] = -1e10
- return scores + self.bias
-
-word_ids = [tokenizer.encode(word, add_prefix_space=True)[0] for word in words]
-bias = torch.full((tokenizer.vocab_size,), -float("Inf")).to("cuda")
-bias[word_ids] = 0
-processor = CustomLogitsProcessor(bias)
-processor_list = LogitsProcessorList([processor])
-```
-
-Combine the prompt and the `cinematic` style prompt defined in the `styles` dictionary earlier.
-
-```py
-prompt = "a cat basking in the sun on a roof in Turkey"
-style = "cinematic"
-
-prompt = styles[style].format(prompt=prompt)
-prompt
-"cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain"
-```
-
-Load a GPT2 tokenizer and model from the [Gustavosta/MagicPrompt-Stable-Diffusion](https://huggingface.co/Gustavosta/MagicPrompt-Stable-Diffusion) checkpoint (this specific checkpoint is trained to generate prompts) to enhance the prompt.
-
-```py
-tokenizer = GPT2Tokenizer.from_pretrained("Gustavosta/MagicPrompt-Stable-Diffusion")
-model = GPT2LMHeadModel.from_pretrained("Gustavosta/MagicPrompt-Stable-Diffusion", torch_dtype=torch.float16).to(
- "cuda"
-)
-model.eval()
-
-inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
-token_count = inputs["input_ids"].shape[1]
-max_new_tokens = 50 - token_count
-
-generation_config = GenerationConfig(
- penalty_alpha=0.7,
- top_k=50,
- eos_token_id=model.config.eos_token_id,
- pad_token_id=model.config.eos_token_id,
- pad_token=model.config.pad_token_id,
- do_sample=True,
-)
-
-with torch.no_grad():
- generated_ids = model.generate(
- input_ids=inputs["input_ids"],
- attention_mask=inputs["attention_mask"],
- max_new_tokens=max_new_tokens,
- generation_config=generation_config,
- logits_processor=proccesor_list,
- )
-```
-
-Then you can combine the input prompt and the generated prompt. Feel free to take a look at what the generated prompt (`generated_part`) is, the word pairs that were found (`pairs`), and the remaining words (`words`). This is all packed together in the `enhanced_prompt`.
-
-```py
-output_tokens = [tokenizer.decode(generated_id, skip_special_tokens=True) for generated_id in generated_ids]
-input_part, generated_part = output_tokens[0][: len(prompt)], output_tokens[0][len(prompt) :]
-pairs, words = find_and_order_pairs(generated_part, word_pairs)
-formatted_generated_part = pairs + ", " + words
-enhanced_prompt = input_part + ", " + formatted_generated_part
-enhanced_prompt
-["cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain quality sharp focus beautiful detailed intricate stunning amazing epic"]
-```
-
-Finally, load a pipeline and the offset noise LoRA with a *low weight* to generate an image with the enhanced prompt.
-
-```py
-pipeline = StableDiffusionXLPipeline.from_pretrained(
- "RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.float16, variant="fp16"
-).to("cuda")
-
-pipeline.load_lora_weights(
- "stabilityai/stable-diffusion-xl-base-1.0",
- weight_name="sd_xl_offset_example-lora_1.0.safetensors",
- adapter_name="offset",
-)
-pipeline.set_adapters(["offset"], adapter_weights=[0.2])
-
-image = pipeline(
- enhanced_prompt,
- width=1152,
- height=896,
- guidance_scale=7.5,
- num_inference_steps=25,
-).images[0]
-image
-```
+Context is especially important for creating better prompts. Try adding lighting, artistic details, and mood.
-
-
-
"a cat basking in the sun on a roof in Turkey"
+
+
+
A cute cat lounges on a leaf in a pool during a peaceful summer afternoon , in lofi art style, illustration .
-
-
-
"cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain"
+
+
+
A cute cat lounges on a floating leaf in a sparkling pool during a peaceful summer afternoon. Clear reflections ripple across the water, with sunlight casting soft, smooth highlights. The illustration is detailed and polished, with elegant lines and harmonious colors, evoking a relaxing, serene, and whimsical lofi mood, anime-inspired and visually comforting.
-## Prompt weighting
-
-Prompt weighting provides a way to emphasize or de-emphasize certain parts of a prompt, allowing for more control over the generated image. A prompt can include several concepts, which gets turned into contextualized text embeddings. The embeddings are used by the model to condition its cross-attention layers to generate an image (read the Stable Diffusion [blog post](https://huggingface.co/blog/stable_diffusion) to learn more about how it works).
-
-Prompt weighting works by increasing or decreasing the scale of the text embedding vector that corresponds to its concept in the prompt because you may not necessarily want the model to focus on all concepts equally. The easiest way to prepare the prompt embeddings is to use [Stable Diffusion Long Prompt Weighted Embedding](https://github.com/xhinker/sd_embed) (sd_embed). Once you have the prompt-weighted embeddings, you can pass them to any pipeline that has a [prompt_embeds](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.prompt_embeds) (and optionally [negative_prompt_embeds](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.negative_prompt_embeds)) parameter, such as [`StableDiffusionPipeline`], [`StableDiffusionControlNetPipeline`], and [`StableDiffusionXLPipeline`].
-
-
+Be specific and add context. Use photography terms like lens type, focal length, camera angles, and depth of field.
-If your favorite pipeline doesn't have a `prompt_embeds` parameter, please open an [issue](https://github.com/huggingface/diffusers/issues/new/choose) so we can add it!
-
-
+> [!TIP]
+> Try a [prompt enhancer](https://huggingface.co/models?sort=downloads&search=prompt+enhancer) to help improve your prompt structure.
-This guide will show you how to weight your prompts with sd_embed.
+## Prompt weighting
-Before you begin, make sure you have the latest version of sd_embed installed:
+Prompt weighting makes some words stronger and others weaker. It scales attention scores so you control how much influence each concept has.
-```bash
-pip install git+https://github.com/xhinker/sd_embed.git@main
-```
+Diffusers handles this through `prompt_embeds` and `pooled_prompt_embeds` arguments which take scaled text embedding vectors. Use the [sd_embed](https://github.com/xhinker/sd_embed) library to generate these embeddings. It also supports longer prompts.
-For this example, let's use [`StableDiffusionXLPipeline`].
+> [!NOTE]
+> The sd_embed library only supports Stable Diffusion, Stable Diffusion XL, Stable Diffusion 3, Stable Cascade, and Flux. Prompt weighting doesn't necessarily help for newer models like Flux which already has very good prompt adherence.
```py
-from diffusers import StableDiffusionXLPipeline, UniPCMultistepScheduler
-import torch
-
-pipe = StableDiffusionXLPipeline.from_pretrained("Lykon/dreamshaper-xl-1-0", torch_dtype=torch.float16)
-pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
-pipe.to("cuda")
+!uv pip install git+https://github.com/xhinker/sd_embed.git@main
```
-To upweight or downweight a concept, surround the text with parentheses. More parentheses applies a heavier weight on the text. You can also append a numerical multiplier to the text to indicate how much you want to increase or decrease its weights by.
+Format weighted text with numerical multipliers or parentheses. More parentheses mean stronger weighting.
| format | multiplier |
|---|---|
-| `(hippo)` | increase by 1.1x |
-| `((hippo))` | increase by 1.21x |
-| `(hippo:1.5)` | increase by 1.5x |
-| `(hippo:0.5)` | decrease by 4x |
-
-Create a prompt and use a combination of parentheses and numerical multipliers to upweight various text.
-
-```py
-from sd_embed.embedding_funcs import get_weighted_text_embeddings_sdxl
-
-prompt = """A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus.
-This imaginative creature features the distinctive, bulky body of a hippo,
-but with a texture and appearance resembling a golden-brown, crispy waffle.
-The creature might have elements like waffle squares across its skin and a syrup-like sheen.
-It's set in a surreal environment that playfully combines a natural water habitat of a hippo with elements of a breakfast table setting,
-possibly including oversized utensils or plates in the background.
-The image should evoke a sense of playful absurdity and culinary fantasy.
-"""
-
-neg_prompt = """\
-skin spots,acnes,skin blemishes,age spot,(ugly:1.2),(duplicate:1.2),(morbid:1.21),(mutilated:1.2),\
-(tranny:1.2),mutated hands,(poorly drawn hands:1.5),blurry,(bad anatomy:1.2),(bad proportions:1.3),\
-extra limbs,(disfigured:1.2),(missing arms:1.2),(extra legs:1.2),(fused fingers:1.5),\
-(too many fingers:1.5),(unclear eyes:1.2),lowers,bad hands,missing fingers,extra digit,\
-bad hands,missing fingers,(extra arms and legs),(worst quality:2),(low quality:2),\
-(normal quality:2),lowres,((monochrome)),((grayscale))
-"""
-```
-
-Use the `get_weighted_text_embeddings_sdxl` function to generate the prompt embeddings and the negative prompt embeddings. It'll also generated the pooled and negative pooled prompt embeddings since you're using the SDXL model.
+| `(cat)` | increase by 1.1x |
+| `((cat))` | increase by 1.21x |
+| `(cat:1.5)` | increase by 1.5x |
+| `(cat:0.5)` | decrease by 4x |
-> [!TIP]
-> You can safely ignore the error message below about the token index length exceeding the models maximum sequence length. All your tokens will be used in the embedding process.
->
-> ```
-> Token indices sequence length is longer than the specified maximum sequence length for this model
-> ```
-
-```py
-(
- prompt_embeds,
- prompt_neg_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds
-) = get_weighted_text_embeddings_sdxl(
- pipe,
- prompt=prompt,
- neg_prompt=neg_prompt
-)
-
-image = pipe(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=prompt_neg_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- num_inference_steps=30,
- height=1024,
- width=1024 + 512,
- guidance_scale=4.0,
- generator=torch.Generator("cuda").manual_seed(2)
-).images[0]
-image
-```
-
-
-
-
+Create a weighted prompt and pass it to [get_weighted_text_embeddings_sdxl](https://github.com/xhinker/sd_embed/blob/4a47f71150a22942fa606fb741a1c971d95ba56f/src/sd_embed/embedding_funcs.py#L405) to generate embeddings.
> [!TIP]
-> Refer to the [sd_embed](https://github.com/xhinker/sd_embed) repository for additional details about long prompt weighting for FLUX.1, Stable Cascade, and Stable Diffusion 1.5.
-
-### Textual inversion
-
-[Textual inversion](../training/text_inversion) is a technique for learning a specific concept from some images which you can use to generate new images conditioned on that concept.
-
-Create a pipeline and use the [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] function to load the textual inversion embeddings (feel free to browse the [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer) for 100+ trained concepts):
+> You could also pass negative prompts to `negative_prompt_embeds` and `negative_pooled_prompt_embeds`.
```py
import torch
-from diffusers import StableDiffusionPipeline
-
-pipe = StableDiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5",
- torch_dtype=torch.float16,
-).to("cuda")
-pipe.load_textual_inversion("sd-concepts-library/midjourney-style")
-```
+from diffusers import DiffusionPipeline
+from sd_embed.embedding_funcs import get_weighted_text_embeddings_sdxl
-Add the `
` text to the prompt to trigger the textual inversion.
+pipeline = DiffusionPipeline.from_pretrained(
+ "Lykon/dreamshaper-xl-1-0", torch_dtype=torch.bfloat16, device_map="cuda"
+)
-```py
-from sd_embed.embedding_funcs import get_weighted_text_embeddings_sd15
-
-prompt = """ A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus.
-This imaginative creature features the distinctive, bulky body of a hippo,
-but with a texture and appearance resembling a golden-brown, crispy waffle.
-The creature might have elements like waffle squares across its skin and a syrup-like sheen.
-It's set in a surreal environment that playfully combines a natural water habitat of a hippo with elements of a breakfast table setting,
-possibly including oversized utensils or plates in the background.
-The image should evoke a sense of playful absurdity and culinary fantasy.
+prompt = """
+A (cute cat:1.4) lounges on a (floating leaf:1.2) in a (sparkling pool:1.1) during a peaceful summer afternoon.
+Gentle ripples reflect pastel skies, while (sunlight:1.1) casts soft highlights. The illustration is smooth and polished
+with elegant, sketchy lines and subtle gradients, evoking a ((whimsical, nostalgic, dreamy lofi atmosphere:2.0)),
+(anime-inspired:1.6), calming, comforting, and visually serene.
"""
-neg_prompt = """\
-skin spots,acnes,skin blemishes,age spot,(ugly:1.2),(duplicate:1.2),(morbid:1.21),(mutilated:1.2),\
-(tranny:1.2),mutated hands,(poorly drawn hands:1.5),blurry,(bad anatomy:1.2),(bad proportions:1.3),\
-extra limbs,(disfigured:1.2),(missing arms:1.2),(extra legs:1.2),(fused fingers:1.5),\
-(too many fingers:1.5),(unclear eyes:1.2),lowers,bad hands,missing fingers,extra digit,\
-bad hands,missing fingers,(extra arms and legs),(worst quality:2),(low quality:2),\
-(normal quality:2),lowres,((monochrome)),((grayscale))
-"""
+prompt_embeds, _, pooled_prompt_embeds, *_ = get_weighted_text_embeddings_sdxl(pipeline, prompt=prompt)
```
-Use the `get_weighted_text_embeddings_sd15` function to generate the prompt embeddings and the negative prompt embeddings.
+Pass the embeddings to `prompt_embeds` and `pooled_prompt_embeds` to generate your image.
```py
-(
- prompt_embeds,
- prompt_neg_embeds,
-) = get_weighted_text_embeddings_sd15(
- pipe,
- prompt=prompt,
- neg_prompt=neg_prompt
-)
-
-image = pipe(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=prompt_neg_embeds,
- height=768,
- width=896,
- guidance_scale=4.0,
- generator=torch.Generator("cuda").manual_seed(2)
-).images[0]
-image
+image = pipeline(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds).images[0]
```
-
+
-### DreamBooth
-
-[DreamBooth](../training/dreambooth) is a technique for generating contextualized images of a subject given just a few images of the subject to train on. It is similar to textual inversion, but DreamBooth trains the full model whereas textual inversion only fine-tunes the text embeddings. This means you should use [`~DiffusionPipeline.from_pretrained`] to load the DreamBooth model (feel free to browse the [Stable Diffusion Dreambooth Concepts Library](https://huggingface.co/sd-dreambooth-library) for 100+ trained models):
-
-```py
-import torch
-from diffusers import DiffusionPipeline, UniPCMultistepScheduler
-
-pipe = DiffusionPipeline.from_pretrained("sd-dreambooth-library/dndcoverart-v1", torch_dtype=torch.float16).to("cuda")
-pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
-```
-
-Depending on the model you use, you'll need to incorporate the model's unique identifier into your prompt. For example, the `dndcoverart-v1` model uses the identifier `dndcoverart`:
-
-```py
-from sd_embed.embedding_funcs import get_weighted_text_embeddings_sd15
-
-prompt = """dndcoverart of A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus.
-This imaginative creature features the distinctive, bulky body of a hippo,
-but with a texture and appearance resembling a golden-brown, crispy waffle.
-The creature might have elements like waffle squares across its skin and a syrup-like sheen.
-It's set in a surreal environment that playfully combines a natural water habitat of a hippo with elements of a breakfast table setting,
-possibly including oversized utensils or plates in the background.
-The image should evoke a sense of playful absurdity and culinary fantasy.
-"""
-
-neg_prompt = """\
-skin spots,acnes,skin blemishes,age spot,(ugly:1.2),(duplicate:1.2),(morbid:1.21),(mutilated:1.2),\
-(tranny:1.2),mutated hands,(poorly drawn hands:1.5),blurry,(bad anatomy:1.2),(bad proportions:1.3),\
-extra limbs,(disfigured:1.2),(missing arms:1.2),(extra legs:1.2),(fused fingers:1.5),\
-(too many fingers:1.5),(unclear eyes:1.2),lowers,bad hands,missing fingers,extra digit,\
-bad hands,missing fingers,(extra arms and legs),(worst quality:2),(low quality:2),\
-(normal quality:2),lowres,((monochrome)),((grayscale))
-"""
-
-(
- prompt_embeds
- , prompt_neg_embeds
-) = get_weighted_text_embeddings_sd15(
- pipe
- , prompt = prompt
- , neg_prompt = neg_prompt
-)
-```
-
-
-
-
+Prompt weighting works with [Textual inversion](./textual_inversion_inference) and [DreamBooth](./dreambooth) adapters too.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/write_own_pipeline.md b/docs/source/en/using-diffusers/write_own_pipeline.md
index 283397ff3e9d..e34727b5da25 100644
--- a/docs/source/en/using-diffusers/write_own_pipeline.md
+++ b/docs/source/en/using-diffusers/write_own_pipeline.md
@@ -1,4 +1,4 @@
-
-# Habana Gaudi에서 Stable Diffusion을 사용하는 방법
+# Intel Gaudi에서 Stable Diffusion을 사용하는 방법
🤗 Diffusers는 🤗 [Optimum Habana](https://huggingface.co/docs/optimum/habana/usage_guides/stable_diffusion)를 통해서 Habana Gaudi와 호환됩니다.
diff --git a/docs/source/ko/optimization/mps.md b/docs/source/ko/optimization/mps.md
index c314cdcdfc57..004374c4af03 100644
--- a/docs/source/ko/optimization/mps.md
+++ b/docs/source/ko/optimization/mps.md
@@ -1,4 +1,4 @@
-
+
+[[open-in-colab]]
+
+# Desempenho básico
+
+Difusão é um processo aleatório que demanda muito processamento. Você pode precisar executar o [`DiffusionPipeline`] várias vezes antes de obter o resultado desejado. Por isso é importante equilibrar cuidadosamente a velocidade de geração e o uso de memória para iterar mais rápido.
+
+Este guia recomenda algumas dicas básicas de desempenho para usar o [`DiffusionPipeline`]. Consulte a seção de documentação sobre Otimização de Inferência, como [Acelerar inferência](./optimization/fp16) ou [Reduzir uso de memória](./optimization/memory) para guias de desempenho mais detalhados.
+
+## Uso de memória
+
+Reduzir a quantidade de memória usada indiretamente acelera a geração e pode ajudar um modelo a caber no dispositivo.
+
+O método [`~DiffusionPipeline.enable_model_cpu_offload`] move um modelo para a CPU quando não está em uso para economizar memória da GPU.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
+)
+pipeline.enable_model_cpu_offload()
+
+prompt = """
+cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+pipeline(prompt).images[0]
+print(f"Memória máxima reservada: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
+```
+
+## Velocidade de inferência
+
+O processo de remoção de ruído é o mais exigente computacionalmente durante a difusão. Métodos que otimizam este processo aceleram a velocidade de inferência. Experimente os seguintes métodos para acelerar.
+
+- Adicione `device_map="cuda"` para colocar o pipeline em uma GPU. Colocar um modelo em um acelerador, como uma GPU, aumenta a velocidade porque realiza computações em paralelo.
+- Defina `torch_dtype=torch.bfloat16` para executar o pipeline em meia-precisão. Reduzir a precisão do tipo de dado aumenta a velocidade porque leva menos tempo para realizar computações em precisão mais baixa.
+
+```py
+import torch
+import time
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
+)
+```
+
+- Use um agendador mais rápido, como [`DPMSolverMultistepScheduler`], que requer apenas ~20-25 passos.
+- Defina `num_inference_steps` para um valor menor. Reduzir o número de passos de inferência reduz o número total de computações. No entanto, isso pode resultar em menor qualidade de geração.
+
+```py
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+
+prompt = """
+cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+
+start_time = time.perf_counter()
+image = pipeline(prompt).images[0]
+end_time = time.perf_counter()
+
+print(f"Geração de imagem levou {end_time - start_time:.3f} segundos")
+```
+
+## Qualidade de geração
+
+Muitos modelos de difusão modernos entregam imagens de alta qualidade imediatamente. No entanto, você ainda pode melhorar a qualidade de geração experimentando o seguinte.
+
+- Experimente um prompt mais detalhado e descritivo. Inclua detalhes como o meio da imagem, assunto, estilo e estética. Um prompt negativo também pode ajudar, guiando um modelo para longe de características indesejáveis usando palavras como baixa qualidade ou desfocado.
+
+ ```py
+ import torch
+ from diffusers import DiffusionPipeline
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
+ )
+
+ prompt = """
+ cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+ highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+ """
+ negative_prompt = "low quality, blurry, ugly, poor details"
+ pipeline(prompt, negative_prompt=negative_prompt).images[0]
+ ```
+
+ Para mais detalhes sobre como criar prompts melhores, consulte a documentação sobre [Técnicas de prompt](./using-diffusers/weighted_prompts).
+
+- Experimente um agendador diferente, como [`HeunDiscreteScheduler`] ou [`LMSDiscreteScheduler`], que sacrifica velocidade de geração por qualidade.
+
+ ```py
+ import torch
+ from diffusers import DiffusionPipeline, HeunDiscreteScheduler
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
+ )
+ pipeline.scheduler = HeunDiscreteScheduler.from_config(pipeline.scheduler.config)
+
+ prompt = """
+ cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+ highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+ """
+ negative_prompt = "low quality, blurry, ugly, poor details"
+ pipeline(prompt, negative_prompt=negative_prompt).images[0]
+ ```
+
+## Próximos passos
+
+Diffusers oferece otimizações mais avançadas e poderosas, como [group-offloading](./optimization/memory#group-offloading) e [compilação regional](./optimization/fp16#regional-compilation). Para saber mais sobre como maximizar o desempenho, consulte a seção sobre Otimização de Inferência.
diff --git a/docs/source/zh/_toctree.yml b/docs/source/zh/_toctree.yml
index 6416c468a8e9..337d010fc74d 100644
--- a/docs/source/zh/_toctree.yml
+++ b/docs/source/zh/_toctree.yml
@@ -1,12 +1,150 @@
-- sections:
+- title: 开始Diffusers
+ sections:
- local: index
- title: 🧨 Diffusers
+ title: Diffusers
+ - local: installation
+ title: 安装
- local: quicktour
title: 快速入门
- local: stable_diffusion
title: 有效和高效的扩散
- - local: consisid
- title: 身份保持的文本到视频生成
- - local: installation
- title: 安装
- title: 开始
+
+- title: DiffusionPipeline
+ isExpanded: false
+ sections:
+ - local: using-diffusers/schedulers
+ title: Load schedulers and models
+
+- title: Inference
+ isExpanded: false
+ sections:
+ - local: training/distributed_inference
+ title: Distributed inference
+
+- title: Inference optimization
+ isExpanded: false
+ sections:
+ - local: optimization/fp16
+ title: Accelerate inference
+ - local: optimization/cache
+ title: Caching
+ - local: optimization/memory
+ title: Reduce memory usage
+ - local: optimization/speed-memory-optims
+ title: Compile and offloading quantized models
+ - title: Community optimizations
+ sections:
+ - local: optimization/pruna
+ title: Pruna
+ - local: optimization/xformers
+ title: xFormers
+ - local: optimization/tome
+ title: Token merging
+ - local: optimization/deepcache
+ title: DeepCache
+ - local: optimization/tgate
+ title: TGATE
+ - local: optimization/xdit
+ title: xDiT
+ - local: optimization/para_attn
+ title: ParaAttention
+
+- title: Hybrid Inference
+ isExpanded: false
+ sections:
+ - local: hybrid_inference/overview
+ title: Overview
+ - local: hybrid_inference/vae_encode
+ title: VAE Encode
+ - local: hybrid_inference/api_reference
+ title: API Reference
+
+- title: Modular Diffusers
+ isExpanded: false
+ sections:
+ - local: modular_diffusers/overview
+ title: Overview
+ - local: modular_diffusers/quickstart
+ title: Quickstart
+ - local: modular_diffusers/modular_diffusers_states
+ title: States
+ - local: modular_diffusers/pipeline_block
+ title: ModularPipelineBlocks
+ - local: modular_diffusers/sequential_pipeline_blocks
+ title: SequentialPipelineBlocks
+ - local: modular_diffusers/loop_sequential_pipeline_blocks
+ title: LoopSequentialPipelineBlocks
+ - local: modular_diffusers/auto_pipeline_blocks
+ title: AutoPipelineBlocks
+ - local: modular_diffusers/modular_pipeline
+ title: ModularPipeline
+ - local: modular_diffusers/components_manager
+ title: ComponentsManager
+ - local: modular_diffusers/guiders
+ title: Guiders
+
+- title: Training
+ isExpanded: false
+ sections:
+ - local: training/overview
+ title: Overview
+ - local: training/adapt_a_model
+ title: Adapt a model to a new task
+ - title: Models
+ sections:
+ - local: training/text2image
+ title: Text-to-image
+ - local: training/kandinsky
+ title: Kandinsky 2.2
+ - local: training/wuerstchen
+ title: Wuerstchen
+ - local: training/controlnet
+ title: ControlNet
+ - local: training/instructpix2pix
+ title: InstructPix2Pix
+ - title: Methods
+ sections:
+ - local: training/text_inversion
+ title: Textual Inversion
+ - local: training/dreambooth
+ title: DreamBooth
+ - local: training/lora
+ title: LoRA
+
+- title: Model accelerators and hardware
+ isExpanded: false
+ sections:
+ - local: optimization/onnx
+ title: ONNX
+ - local: optimization/open_vino
+ title: OpenVINO
+ - local: optimization/coreml
+ title: Core ML
+ - local: optimization/mps
+ title: Metal Performance Shaders (MPS)
+ - local: optimization/habana
+ title: Intel Gaudi
+ - local: optimization/neuron
+ title: AWS Neuron
+
+- title: Specific pipeline examples
+ isExpanded: false
+ sections:
+ - local: using-diffusers/consisid
+ title: ConsisID
+
+- title: Resources
+ isExpanded: false
+ sections:
+ - title: Task recipes
+ sections:
+ - local: community_projects
+ title: Projects built with Diffusers
+ - local: conceptual/philosophy
+ title: Philosophy
+ - local: conceptual/contribution
+ title: How to contribute?
+ - local: conceptual/ethical_guidelines
+ title: Diffusers' Ethical Guidelines
+ - local: conceptual/evaluation
+ title: Evaluating Diffusion Models
diff --git a/docs/source/zh/community_projects.md b/docs/source/zh/community_projects.md
new file mode 100644
index 000000000000..0440142452f1
--- /dev/null
+++ b/docs/source/zh/community_projects.md
@@ -0,0 +1,89 @@
+
+
+# 社区项目
+
+欢迎来到社区项目。这个空间致力于展示我们充满活力的社区使用`diffusers`库创建的令人难以置信的工作和创新应用。
+
+本节旨在:
+
+- 突出使用`diffusers`构建的多样化和鼓舞人心的项目
+- 促进我们社区内的知识共享
+- 提供如何利用`diffusers`的实际例子
+
+探索愉快,感谢您成为Diffusers社区的一部分!
+
+
diff --git a/docs/source/zh/conceptual/contribution.md b/docs/source/zh/conceptual/contribution.md
new file mode 100644
index 000000000000..0f9743882523
--- /dev/null
+++ b/docs/source/zh/conceptual/contribution.md
@@ -0,0 +1,485 @@
+
+
+# 如何为Diffusers 🧨做贡献
+
+我们❤️来自开源社区的贡献!欢迎所有人参与,所有类型的贡献——不仅仅是代码——都受到重视和赞赏。回答问题、帮助他人、主动交流以及改进文档对社区都极具价值,所以如果您愿意参与,请不要犹豫!
+
+我们鼓励每个人先在公开Discord频道里打招呼👋。在那里我们讨论扩散模型的最新趋势、提出问题、展示个人项目、互相协助贡献,或者只是闲聊☕。
+
+无论您选择以何种方式贡献,我们都致力于成为一个开放、友好、善良的社区。请阅读我们的[行为准则](https://github.com/huggingface/diffusers/blob/main/CODE_OF_CONDUCT.md),并在互动时注意遵守。我们也建议您了解指导本项目的[伦理准则](https://huggingface.co/docs/diffusers/conceptual/ethical_guidelines),并请您遵循同样的透明度和责任原则。
+
+我们高度重视社区的反馈,所以如果您认为自己有能帮助改进库的有价值反馈,请不要犹豫说出来——每条消息、评论、issue和拉取请求(PR)都会被阅读和考虑。
+
+## 概述
+
+您可以通过多种方式做出贡献,从在issue和讨论区回答问题,到向核心库添加新的diffusion模型。
+
+下面我们按难度升序列出不同的贡献方式,所有方式对社区都很有价值:
+
+* 1. 在[Diffusers讨论论坛](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers)或[Discord](https://discord.gg/G7tWnz98XR)上提问和回答问题
+* 2. 在[GitHub Issues标签页](https://github.com/huggingface/diffusers/issues/new/choose)提交新issue,或在[GitHub Discussions标签页](https://github.com/huggingface/diffusers/discussions/new/choose)发起新讨论
+* 3. 在[GitHub Issues标签页](https://github.com/huggingface/diffusers/issues)解答issue,或在[GitHub Discussions标签页](https://github.com/huggingface/diffusers/discussions)参与讨论
+* 4. 解决标记为"Good first issue"的简单问题,详见[此处](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)
+* 5. 参与[文档](https://github.com/huggingface/diffusers/tree/main/docs/source)建设
+* 6. 贡献[社区Pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3Acommunity-examples)
+* 7. 完善[示例代码](https://github.com/huggingface/diffusers/tree/main/examples)
+* 8. 解决标记为"Good second issue"的中等难度问题,详见[此处](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22)
+* 9. 添加新pipeline/模型/调度器,参见["New Pipeline/Model"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22)和["New scheduler"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22)类issue。此类贡献请先阅读[设计哲学](https://github.com/huggingface/diffusers/blob/main/PHILOSOPHY.md)
+
+重申:**所有贡献对社区都具有重要价值。**下文将详细说明各类贡献方式。
+
+对于4-9类贡献,您需要提交PR(拉取请求),具体操作详见[如何提交PR](#how-to-open-a-pr)章节。
+
+### 1. 在Diffusers讨论区或Discord提问与解答
+
+任何与Diffusers库相关的问题或讨论都可以发布在[官方论坛](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/)或[Discord频道](https://discord.gg/G7tWnz98XR),包括但不限于:
+- 分享训练/推理实验报告
+- 展示个人项目
+- 咨询非官方训练示例
+- 项目提案
+- 通用反馈
+- 论文解读
+- 基于Diffusers库的个人项目求助
+- 一般性问题
+- 关于diffusion模型的伦理讨论
+- ...
+
+论坛/Discord上的每个问题都能促使社区公开分享知识,很可能帮助未来遇到相同问题的初学者。请务必提出您的疑问。
+同样地,通过回答问题您也在为社区创造公共知识文档,这种贡献极具价值。
+
+**请注意**:提问/回答时投入的精力越多,产生的公共知识质量就越高。精心构建的问题与专业解答能形成高质量知识库,而表述不清的问题则可能降低讨论价值。
+
+低质量的问题或回答会降低公共知识库的整体质量。
+简而言之,高质量的问题或回答应具备*精确性*、*简洁性*、*相关性*、*易于理解*、*可访问性*和*格式规范/表述清晰*等特质。更多详情请参阅[如何提交优质议题](#how-to-write-a-good-issue)章节。
+
+**关于渠道的说明**:
+[*论坛*](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63)的内容能被谷歌等搜索引擎更好地收录,且帖子按热度而非时间排序,便于查找历史问答。此外,论坛内容更容易被直接链接引用。
+而*Discord*采用即时聊天模式,适合快速交流。虽然在Discord上可能更快获得解答,但信息会随时间淹没,且难以回溯历史讨论。因此我们强烈建议在论坛发布优质问答,以构建可持续的社区知识库。若Discord讨论产生有价值结论,建议将成果整理发布至论坛以惠及更多读者。
+
+### 2. 在GitHub议题页提交新议题
+
+🧨 Diffusers库的稳健性离不开用户的问题反馈,感谢您的报错。
+
+请注意:GitHub议题仅限处理与Diffusers库代码直接相关的技术问题、错误报告、功能请求或库设计反馈。
+简言之,**与Diffusers库代码(含文档)无关**的内容应发布至[论坛](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63)或[Discord](https://discord.gg/G7tWnz98XR)。
+
+**提交新议题时请遵循以下准则**:
+- 确认是否已有类似议题(使用GitHub议题页的搜索栏)
+- 请勿在现有议题下追加新问题。若存在高度关联议题,应新建议题并添加相关链接
+- 确保使用英文提交。非英语用户可通过[DeepL](https://www.deepl.com/translator)等免费工具翻译
+- 检查升级至最新Diffusers版本是否能解决问题。提交前请确认`python -c "import diffusers; print(diffusers.__version__)"`显示的版本号不低于最新版本
+- 记请记住,你在提交新issue时投入的精力越多,得到的回答质量就越高,Diffusers项目的整体issue质量也会越好。
+
+新issue通常包含以下内容:
+
+#### 2.1 可复现的最小化错误报告
+
+错误报告应始终包含可复现的代码片段,并尽可能简洁明了。具体而言:
+- 尽量缩小问题范围,**不要直接粘贴整个代码文件**
+- 规范代码格式
+- 除Diffusers依赖库外,不要包含其他外部库
+- **务必**提供环境信息:可在终端运行`diffusers-cli env`命令,然后将显示的信息复制到issue中
+- 详细说明问题。如果读者不清楚问题所在及其影响,就无法解决问题
+- **确保**读者能以最小成本复现问题。如果代码片段因缺少库或未定义变量而无法运行,读者将无法提供帮助。请确保提供的可复现代码尽可能精简,可直接复制到Python shell运行
+- 如需特定模型/数据集复现问题,请确保读者能获取这些资源。可将模型/数据集上传至[Hub](https://huggingface.co)便于下载。尽量保持模型和数据集体积最小化,降低复现难度
+
+更多信息请参阅[如何撰写优质issue](#how-to-write-a-good-issue)章节。
+
+提交错误报告请点击[此处](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&projects=&template=bug-report.yml)。
+
+#### 2.2 功能请求
+
+优质的功能请求应包含以下要素:
+
+1. 首先说明动机:
+* 是否与库的使用痛点相关?若是,请解释原因,最好提供演示问题的代码片段
+* 是否因项目需求产生?我们很乐意了解详情!
+* 是否是你已实现且认为对社区有价值的功能?请说明它为你解决了什么问题
+2. 用**完整段落**描述功能特性
+3. 提供**代码片段**演示预期用法
+4. 如涉及论文,请附上链接
+5. 可补充任何有助于理解的辅助材料(示意图、截图等)
+
+提交功能请求请点击[此处](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=)。
+
+#### 2.3 设计反馈
+
+关于库设计的反馈(无论正面还是负面)能极大帮助核心维护者打造更友好的库。要了解当前设计理念,请参阅[此文档](https://huggingface.co/docs/diffusers/conceptual/philosophy)如果您认为某个设计选择与当前理念不符,请说明原因及改进建议。如果某个设计选择因过度遵循理念而限制了使用场景,也请解释原因并提出调整方案。
+若某个设计对您特别实用,请同样留下备注——这对未来的设计决策极具参考价值。
+
+您可通过[此链接](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=)提交设计反馈。
+
+#### 2.4 技术问题
+
+技术问题主要涉及库代码的实现逻辑或特定功能模块的作用。提问时请务必:
+- 附上相关代码链接
+- 详细说明难以理解的具体原因
+
+技术问题提交入口:[点击此处](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&template=bug-report.yml)
+
+#### 2.5 新模型/调度器/pipeline提案
+
+若diffusion模型社区发布了您希望集成到Diffusers库的新模型、pipeline或调度器,请提供以下信息:
+* 简要说明并附论文或发布链接
+* 开源实现链接(如有)
+* 模型权重下载链接(如已公开)
+
+若您愿意参与开发,请告知我们以便指导。另请尝试通过GitHub账号标记原始组件作者。
+
+提案提交地址:[新建请求](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=New+model%2Fpipeline%2Fscheduler&template=new-model-addition.yml)
+
+### 3. 解答GitHub问题
+
+回答GitHub问题可能需要Diffusers的技术知识,但我们鼓励所有人尝试参与——即使您对答案不完全正确。高质量回答的建议:
+- 保持简洁精炼
+- 严格聚焦问题本身
+- 提供代码/论文等佐证材料
+- 优先用代码说话:若代码片段能解决问题,请提供完整可复现代码
+
+许多问题可能存在离题、重复或无关情况。您可以通过以下方式协助维护者:
+- 引导提问者精确描述问题
+- 标记重复issue并附原链接
+- 推荐用户至[论坛](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63)或[Discord](https://discord.gg/G7tWnz98XR)
+
+在确认提交的Bug报告正确且需要修改源代码后,请继续阅读以下章节内容。
+
+以下所有贡献都需要提交PR(拉取请求)。具体操作步骤详见[如何提交PR](#how-to-open-a-pr)章节。
+
+### 4. 修复"Good first issue"类问题
+
+标有[Good first issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)标签的问题通常已说明解决方案建议,便于修复。若该问题尚未关闭且您想尝试解决,只需留言"我想尝试解决这个问题"。通常有三种情况:
+- a.) 问题描述已提出解决方案。若您认可该方案,可直接提交PR或草稿PR进行修复
+- b.) 问题描述未提出解决方案。您可询问修复建议,Diffusers团队会尽快回复。若有成熟解决方案,也可直接提交PR
+- c.) 已有PR但问题未关闭。若原PR停滞,可新开PR并关联原PR(开源社区常见现象)。若PR仍活跃,您可通过建议、审查或协作等方式帮助原作者
+
+### 5. 文档贡献
+
+优秀库**必然**拥有优秀文档!官方文档是新用户的首要接触点,因此文档贡献具有**极高价值**。贡献形式包括:
+- 修正拼写/语法错误
+- 修复文档字符串格式错误(如显示异常或链接失效)
+- 修正文档字符串中张量的形状/维度描述
+- 优化晦涩或错误的说明
+- 更新过时代码示例
+- 文档翻译
+
+[官方文档页面](https://huggingface.co/docs/diffusers/index)所有内容均属可修改范围,对应[文档源文件](https://github.com/huggingface/diffusers/tree/main/docs/source)可进行编辑。修改前请查阅[验证说明](https://github.com/huggingface/diffusers/tree/main/docs)。
+
+### 6. 贡献社区流程
+
+> [!TIP]
+> 阅读[社区流程](../using-diffusers/custom_pipeline_overview#community-pipelines)指南了解GitHub与Hugging Face Hub社区流程的区别。若想了解我们设立社区流程的原因,请查看GitHub Issue [#841](https://github.com/huggingface/diffusers/issues/841)(简而言之,我们无法维护diffusion模型所有可能的推理使用方式,但也不希望限制社区构建这些流程)。
+
+贡献社区流程是向社区分享创意与成果的绝佳方式。您可以在[`DiffusionPipeline`]基础上构建流程,任何人都能通过设置`custom_pipeline`参数加载使用。本节将指导您创建一个简单的"单步"流程——UNet仅执行单次前向传播并调用调度器一次。
+
+1. 为社区流程创建one_step_unet.py文件。只要用户已安装相关包,该文件可包含任意所需包。确保仅有一个继承自[`DiffusionPipeline`]的流程类,用于从Hub加载模型权重和调度器配置。在`__init__`函数中添加UNet和调度器。
+
+ 同时添加`register_modules`函数,确保您的流程及其组件可通过[`~DiffusionPipeline.save_pretrained`]保存。
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
+ def __init__(self, unet, scheduler):
+ super().__init__()
+
+ self.register_modules(unet=unet, scheduler=scheduler)
+```
+
+2. 在前向传播中(建议定义为`__call__`),可添加任意功能。对于"单步"流程,创建随机图像并通过设置`timestep=1`调用UNet和调度器一次。
+
+```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
+ def __init__(self, unet, scheduler):
+ super().__init__()
+
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ def __call__(self):
+ image = torch.randn(
+ (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
+ )
+ timestep = 1
+
+ model_output = self.unet(image, timestep).sample
+ scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample
+
+ return scheduler_output
+```
+
+现在您可以通过传入UNet和调度器来运行流程,若流程结构相同也可加载预训练权重。
+
+```python
+from diffusers import DDPMScheduler, UNet2DModel
+
+scheduler = DDPMScheduler()
+unet = UNet2DModel()
+
+pipeline = UnetSchedulerOneForwardPipeline(unet=unet, scheduler=scheduler)
+output = pipeline()
+# 加载预训练权重
+pipeline = UnetSchedulerOneForwardPipeline.from_pretrained("google/ddpm-cifar10-32", use_safetensors=True)
+output = pipeline()
+```
+
+您可以选择将pipeline作为GitHub社区pipeline或Hub社区pipeline进行分享。
+
+
+
+
+通过向Diffusers[代码库](https://github.com/huggingface/diffusers)提交拉取请求来分享GitHub pipeline,将one_step_unet.py文件添加到[examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community)子文件夹中。
+
+
+
+
+通过在Hub上创建模型仓库并上传one_step_unet.py文件来分享Hub pipeline。
+
+
+
+
+### 7. 贡献训练示例
+
+Diffusers训练示例是位于[examples](https://github.com/huggingface/diffusers/tree/main/examples)目录下的训练脚本集合。
+
+我们支持两种类型的训练示例:
+
+- 官方训练示例
+- 研究型训练示例
+
+研究型训练示例位于[examples/research_projects](https://github.com/huggingface/diffusers/tree/main/examples/research_projects),而官方训练示例包含[examples](https://github.com/huggingface/diffusers/tree/main/examples)目录下除`research_projects`和`community`外的所有文件夹。
+官方训练示例由Diffusers核心维护者维护,研究型训练示例则由社区维护。
+这与[6. 贡献社区pipeline](#6-contribute-a-community-pipeline)中关于官方pipeline与社区pipeline的原因相同:核心维护者不可能维护diffusion模型的所有可能训练方法。
+如果Diffusers核心维护者和社区认为某种训练范式过于实验性或不够普及,相应训练代码应放入`research_projects`文件夹并由作者维护。
+
+官方训练和研究型示例都包含一个目录,其中含有一个或多个训练脚本、`requirements.txt`文件和`README.md`文件。用户使用时需要先克隆代码库:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+```
+
+并安装训练所需的所有额外依赖:
+
+```bash
+cd diffusers
+pip install -r examples//requirements.txt
+```
+
+因此添加示例时,`requirements.txt`文件应定义训练示例所需的所有pip依赖项,安装完成后用户即可运行示例训练脚本。可参考[DreamBooth的requirements.txt文件](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/requirements.txt)。
+- 运行示例所需的所有代码应集中在单个Python文件中
+- 用户应能通过命令行`python .py --args`直接运行示例
+- **示例**应保持简洁,主要展示如何使用Diffusers进行训练。示例脚本的目的**不是**创建最先进的diffusion模型,而是复现已知训练方案,避免添加过多自定义逻辑。因此,这些示例也力求成为优质的教学材料。
+
+提交示例时,强烈建议参考现有示例(如[dreambooth](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py))来了解规范格式。
+我们强烈建议贡献者使用[Accelerate库](https://github.com/huggingface/accelerate),因其与Diffusers深度集成。
+当示例脚本完成后,请确保添加详细的`README.md`说明使用方法,包括:
+- 运行示例的具体命令(示例参见[此处](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#running-locally-with-pytorch))
+- 训练结果链接(日志/模型等),展示用户可预期的效果(示例参见[此处](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5))
+- 若添加非官方/研究性训练示例,**必须注明**维护者信息(含Git账号),格式参照[此处](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/intel_opts#diffusers-examples-with-intel-optimizations)
+
+贡献官方训练示例时,还需在对应目录添加测试文件(如[examples/dreambooth/test_dreambooth.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/test_dreambooth.py)),非官方示例无需此步骤。
+
+### 8. 处理"Good second issue"类问题
+
+标有[Good second issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22)标签的问题通常比[Good first issues](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)更复杂。
+这类问题的描述通常不会提供详细解决指引,需要贡献者对库有较深理解。
+若您想解决此类问题,可直接提交PR并关联对应issue。若已有未合并的PR,请分析原因后提交改进版。需注意,Good second issue类PR的合并难度通常高于good first issues。在需要帮助的时候请不要犹豫,大胆的向核心维护者询问。
+
+### 9. 添加管道、模型和调度器
+
+管道(pipelines)、模型(models)和调度器(schedulers)是Diffusers库中最重要的组成部分。它们提供了对最先进diffusion技术的便捷访问,使得社区能够构建强大的生成式AI应用。
+
+通过添加新的模型、管道或调度器,您可能为依赖Diffusers的任何用户界面开启全新的强大用例,这对整个生成式AI生态系统具有巨大价值。
+
+Diffusers针对这三类组件都有一些开放的功能请求——如果您还不确定要添加哪个具体组件,可以浏览以下链接:
+- [模型或管道](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22)
+- [调度器](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22)
+
+在添加任何组件之前,强烈建议您阅读[设计哲学指南](philosophy),以更好地理解这三类组件的设计理念。请注意,如果添加的模型、调度器或管道与我们的设计理念存在严重分歧,我们将无法合并,因为这会导致API不一致。如果您从根本上不同意某个设计选择,请改为提交[反馈问题](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=),以便讨论是否应该更改库中的特定设计模式/选择,以及是否更新我们的设计哲学。保持库内的一致性对我们非常重要。
+
+请确保在PR中添加原始代码库/论文的链接,并最好直接在PR中@原始作者,以便他们可以跟踪进展并在有疑问时提供帮助。
+
+如果您在PR过程中遇到不确定或卡住的情况,请随时留言请求初步审查或帮助。
+
+#### 复制机制(Copied from)
+
+在添加任何管道、模型或调度器代码时,理解`# Copied from`机制是独特且重要的。您会在整个Diffusers代码库中看到这种机制,我们使用它的原因是为了保持代码库易于理解和维护。用`# Copied from`机制标记代码会强制标记的代码与复制来源的代码完全相同。这使得每当您运行`make fix-copies`时,可以轻松更新并将更改传播到多个文件。
+
+例如,在下面的代码示例中,[`~diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`]是原始代码,而`AltDiffusionPipelineOutput`使用`# Copied from`机制来复制它。唯一的区别是将类前缀从`Stable`改为`Alt`。
+
+```py
+# 从 diffusers.pipelines.stable_diffusion.pipeline_output.StableDiffusionPipelineOutput 复制并将 Stable 替换为 Alt
+class AltDiffusionPipelineOutput(BaseOutput):
+ """
+ Output class for Alt Diffusion pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
+ num_channels)`.
+ nsfw_content_detected (`List[bool]`)
+ List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
+ `None` if safety checking could not be performed.
+ """
+```
+
+要了解更多信息,请阅读[~不要~重复自己*](https://huggingface.co/blog/transformers-design-philosophy#4-machine-learning-models-are-static)博客文章的相应部分。
+
+## 如何撰写优质问题
+
+**问题描述越清晰,被快速解决的可能性就越高。**
+
+1. 确保使用了正确的issue模板。您可以选择*错误报告*、*功能请求*、*API设计反馈*、*新模型/流水线/调度器添加*、*论坛*或空白issue。在[新建issue](https://github.com/huggingface/diffusers/issues/new/choose)时务必选择正确的模板。
+2. **精确描述**:为issue起一个恰当的标题。尽量用最简练的语言描述问题。提交issue时越精确,理解问题和潜在解决方案所需的时间就越少。确保一个issue只针对一个问题,不要将多个问题放在同一个issue中。如果发现多个问题,请分别创建多个issue。如果是错误报告,请尽可能精确描述错误类型——不应只写"diffusers出错"。
+3. **可复现性**:无法复现的代码片段 == 无法解决问题。如果遇到错误,维护人员必须能够**复现**它。确保包含一个可以复制粘贴到Python解释器中复现问题的代码片段。确保您的代码片段是可运行的,即没有缺少导入或图像链接等问题。issue应包含错误信息和可直接复制粘贴以复现相同错误的代码片段。如果issue涉及本地模型权重或无法被读者访问的本地数据,则问题无法解决。如果无法共享数据或模型,请尝试创建虚拟模型或虚拟数据。
+4. **最小化原则**:通过尽可能简洁的描述帮助读者快速理解问题。删除所有与问题无关的代码/信息。如果发现错误,请创建最简单的代码示例来演示问题,不要一发现错误就把整个工作流程都转储到issue中。例如,如果在训练模型时某个阶段出现错误或训练过程中遇到问题时,应首先尝试理解训练代码的哪部分导致了错误,并用少量代码尝试复现。建议使用模拟数据替代完整数据集进行测试。
+5. 添加引用链接。当提及特定命名、方法或模型时,请务必提供引用链接以便读者理解。若涉及具体PR或issue,请确保添加对应链接。不要假设读者了解你所指内容。issue中引用链接越丰富越好。
+6. 规范格式。请确保规范格式化issue内容:Python代码使用代码语法块,错误信息使用标准代码语法。详见[GitHub官方格式文档](https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax)。
+7. 请将issue视为百科全书的精美词条,而非待解决的工单。每个规范撰写的issue不仅是向维护者有效传递问题的方式,更是帮助社区深入理解库特性的公共知识贡献。
+
+## 优质PR编写规范
+
+1. 保持风格统一。理解现有设计模式和语法规范,确保新增代码与代码库现有结构无缝衔接。显著偏离现有设计模式或用户界面的PR将不予合并。
+2. 聚焦单一问题。每个PR应当只解决一个明确问题,避免"顺手修复其他问题"的陷阱。包含多个无关修改的PR会极大增加审查难度。
+3. 如适用,建议添加代码片段演示新增功能的使用方法。
+4. PR标题应准确概括其核心贡献。
+5. 若PR针对某个issue,请在描述中注明issue编号以建立关联(也让关注该issue的用户知晓有人正在处理);
+6. 进行中的PR请在标题添加`[WIP]`前缀。这既能避免重复劳动,也可与待合并PR明确区分;
+7. 文本表述与格式要求请参照[优质issue编写规范](#how-to-write-a-good-issue);
+8. 确保现有测试用例全部通过;
+9. 必须添加高覆盖率测试。未经充分测试的代码不予合并。
+- 若新增`@slow`测试,请使用`RUN_SLOW=1 python -m pytest tests/test_my_new_model.py`确保通过。
+CircleCI不执行慢速测试,但GitHub Actions会每日夜间运行!
+10. 所有公开方法必须包含格式规范、兼容markdown的说明文档。可参考[`pipeline_latent_diffusion.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py)
+11. 由于代码库快速增长,必须确保不会添加明显增加仓库体积的文件(如图片、视频等非文本文件)。建议优先使用托管在hf.co的`dataset`(例如[`hf-internal-testing`](https://huggingface.co/hf-internal-testing)或[huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images))存放这类文件。若为外部贡献,可将图片添加到PR中并请Hugging Face成员将其迁移至该数据集。
+
+## 提交PR流程
+
+编写代码前,强烈建议先搜索现有PR或issue,确认没有重复工作。如有疑问,建议先创建issue获取反馈。
+
+贡献至🧨 Diffusers需要基本的`git`技能。虽然`git`学习曲线较高,但其拥有最完善的手册。在终端输入`git --help`即可查阅,或参考书籍[Pro Git](https://git-scm.com/book/en/v2)。
+
+请按以下步骤操作([支持的Python版本](https://github.com/huggingface/diffusers/blob/83bc6c94eaeb6f7704a2a428931cf2d9ad973ae9/setup.py#L270)):
+
+1. 在[仓库页面](https://github.com/huggingface/diffusers)点击"Fork"按钮创建代码副本至您的GitHub账户
+
+2. 克隆fork到本地,并添加主仓库为远程源:
+ ```bash
+ $ git clone git@github.com:<您的GitHub账号>/diffusers.git
+ $ cd diffusers
+ $ git remote add upstream https://github.com/huggingface/diffusers.git
+ ```
+
+3. 创建新分支进行开发:
+ ```bash
+ $ git checkout -b 您的开发分支名称
+ ```
+**禁止**直接在`main`分支上修改
+
+4. 在虚拟环境中运行以下命令配置开发环境:
+ ```bash
+ $ pip install -e ".[dev]"
+ ```
+若已克隆仓库,可能需要先执行`git pull`获取最新代码
+
+5. 在您的分支上开发功能
+
+开发过程中应确保测试通过。可运行受影响测试:
+ ```bash
+ $ pytest tests/<待测文件>.py
+ ```
+执行测试前请安装测试依赖:
+ ```bash
+ $ pip install -e ".[test]"
+ ```
+也可运行完整测试套件(需高性能机器):
+ ```bash
+ $ make test
+ ```
+
+🧨 Diffusers使用`black`和`isort`工具保持代码风格统一。修改后请执行自动化格式校正与代码验证,以下内容无法通过以下命令一次性自动化完成:
+
+```bash
+$ make style
+```
+
+🧨 Diffusers 还使用 `ruff` 和一些自定义脚本来检查代码错误。虽然质量控制流程会在 CI 中运行,但您也可以通过以下命令手动执行相同的检查:
+
+```bash
+$ make quality
+```
+
+当您对修改满意后,使用 `git add` 添加更改的文件,并通过 `git commit` 在本地记录这些更改:
+
+```bash
+$ git add modified_file.py
+$ git commit -m "关于您所做更改的描述性信息。"
+```
+
+定期将您的代码副本与原始仓库同步是一个好习惯。这样可以快速适应上游变更:
+
+```bash
+$ git pull upstream main
+```
+
+使用以下命令将更改推送到您的账户:
+
+```bash
+$ git push -u origin 此处替换为您的描述性分支名称
+```
+
+6. 确认无误后,请访问您 GitHub 账户中的派生仓库页面。点击「Pull request」将您的更改提交给项目维护者审核。
+
+7. 如果维护者要求修改,这很正常——核心贡献者也会遇到这种情况!为了让所有人能在 Pull request 中看到变更,请在本地分支继续工作并将修改推送到您的派生仓库,这些变更会自动出现在 Pull request 中。
+
+### 测试
+
+我们提供了全面的测试套件来验证库行为和多个示例。库测试位于 [tests 文件夹](https://github.com/huggingface/diffusers/tree/main/tests)。
+
+我们推荐使用 `pytest` 和 `pytest-xdist`,因为它们速度更快。在仓库根目录下运行以下命令执行库测试:
+
+```bash
+$ python -m pytest -n auto --dist=loadfile -s -v ./tests/
+```
+
+实际上,这就是 `make test` 的实现方式!
+
+您可以指定更小的测试范围来仅验证您正在开发的功能。
+
+默认情况下会跳过耗时测试。设置 `RUN_SLOW` 环境变量为 `yes` 可运行这些测试。注意:这将下载数十 GB 的模型文件——请确保您有足够的磁盘空间、良好的网络连接或充足的耐心!
+
+```bash
+$ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./tests/
+```
+
+我们也完全支持 `unittest`,运行方式如下:
+
+```bash
+$ python -m unittest discover -s tests -t . -v
+$ python -m unittest discover -s examples -t examples -v
+```
+
+### 将派生仓库的 main 分支与上游(HuggingFace)main 分支同步
+
+为避免向上游仓库发送引用通知(这会给相关 PR 添加注释并向开发者发送不必要的通知),在同步派生仓库的 main 分支时,请遵循以下步骤:
+1. 尽可能避免通过派生仓库的分支和 PR 来同步上游,而是直接合并到派生仓库的 main 分支
+2. 如果必须使用 PR,请在检出分支后执行以下操作:
+```bash
+$ git checkout -b 您的同步分支名称
+$ git pull --squash --no-commit upstream main
+$ git commit -m '提交信息(不要包含 GitHub 引用)'
+$ git push --set-upstream origin 您的分支名称
+```
+
+### 风格指南
+
+对于文档字符串,🧨 Diffusers 遵循 [Google 风格指南](https://google.github.io/styleguide/pyguide.html)。
diff --git a/docs/source/zh/conceptual/ethical_guidelines.md b/docs/source/zh/conceptual/ethical_guidelines.md
new file mode 100644
index 000000000000..535cc86e5f0c
--- /dev/null
+++ b/docs/source/zh/conceptual/ethical_guidelines.md
@@ -0,0 +1,56 @@
+
+
+# 🧨 Diffusers伦理准则
+
+## 前言
+
+[Diffusers](https://huggingface.co/docs/diffusers/index)不仅提供预训练的diffusion模型,还是一个模块化工具箱,支持推理和训练功能。
+
+鉴于该技术在实际场景中的应用及其可能对社会产生的负面影响,我们认为有必要制定项目伦理准则,以指导Diffusers库的开发、用户贡献和使用规范。
+
+该技术涉及的风险仍在持续评估中,主要包括但不限于:艺术家版权问题、深度伪造滥用、不当情境下的色情内容生成、非自愿的人物模仿、以及加剧边缘群体压迫的有害社会偏见。我们将持续追踪风险,并根据社区反馈动态调整本准则。
+
+## 适用范围
+
+Diffusers社区将在项目开发中贯彻以下伦理准则,并协调社区贡献的整合方式,特别是在涉及伦理敏感议题的技术决策时。
+
+## 伦理准则
+
+以下准则具有普遍适用性,但我们主要在处理涉及伦理敏感问题的技术决策时实施。同时,我们承诺将根据技术发展带来的新兴风险持续调整这些原则:
+
+- **透明度**:我们承诺以透明方式管理PR(拉取请求),向用户解释决策依据,并公开技术选择过程。
+
+- **一致性**:我们承诺为用户提供统一标准的项目管理,保持技术稳定性和连贯性。
+
+- **简洁性**:为了让Diffusers库更易使用和开发,我们承诺保持项目目标精简且逻辑自洽。
+
+- **可及性**:本项目致力于降低贡献门槛,即使非技术人员也能参与运营,从而使研究资源更广泛地服务于社区。
+
+- **可复现性**:对于通过Diffusers库发布的上游代码、模型和数据集,我们将明确说明其可复现性。
+
+- **责任性**:作为社区和团队,我们共同承担用户责任,通过风险预判和缓解措施来应对技术潜在危害。
+
+## 实施案例:安全功能与机制
+
+团队持续开发技术和非技术工具,以应对diffusion技术相关的伦理与社会风险。社区反馈对于功能实施和风险意识提升具有不可替代的价值:
+
+- [**社区讨论区**](https://huggingface.co/docs/hub/repositories-pull-requests-discussions):促进社区成员就项目开展协作讨论。
+
+- **偏见探索与评估**:Hugging Face团队提供[交互空间](https://huggingface.co/spaces/society-ethics/DiffusionBiasExplorer)展示Stable Diffusion中的偏见。我们支持并鼓励此类偏见探索与评估工作。
+
+- **部署安全强化**:
+
+ - [**Safe Stable Diffusion**](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_safe):解决Stable Diffusion等基于未过滤网络爬取数据训练的模型容易产生不当内容的问题。相关论文:[Safe Latent Diffusion:缓解diffusion模型中的不当退化](https://huggingface.co/papers/2211.05105)。
+
+ - [**安全检测器**](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py):通过比对图像生成后嵌入空间中硬编码有害概念集的类别概率进行检测。有害概念列表经特殊处理以防逆向工程。
+
+- **分阶段模型发布**:对于高度敏感的仓库,采用分级访问控制。这种阶段性发布机制让作者能更好地管控使用场景。
+
+- **许可证制度**:采用新型[OpenRAILs](https://huggingface.co/blog/open_rail)许可协议,在保障开放访问的同时设置使用限制以确保更负责任的应用。
diff --git a/docs/source/zh/conceptual/evaluation.md b/docs/source/zh/conceptual/evaluation.md
new file mode 100644
index 000000000000..770d197be041
--- /dev/null
+++ b/docs/source/zh/conceptual/evaluation.md
@@ -0,0 +1,546 @@
+
+
+# Diffusion模型评估指南
+
+
+
+
+
+> [!TIP]
+> 鉴于当前已出现针对图像生成Diffusion模型的成熟评估框架(如[HEIM](https://crfm.stanford.edu/helm/heim/latest/)、[T2I-Compbench](https://huggingface.co/papers/2307.06350)、[GenEval](https://huggingface.co/papers/2310.11513)),本文档部分内容已过时。
+
+像 [Stable Diffusion](https://huggingface.co/docs/diffusers/stable_diffusion) 这类生成模型的评估本质上是主观的。但作为开发者和研究者,我们经常需要在众多可能性中做出审慎选择。那么当面对不同生成模型(如 GANs、Diffusion 等)时,该如何决策?
+
+定性评估容易产生偏差,可能导致错误结论;而定量指标又未必能准确反映图像质量。因此,通常需要结合定性与定量评估来获得更可靠的模型选择依据。
+
+本文档将系统介绍扩散模型的定性与定量评估方法(非穷尽列举)。对于定量方法,我们将重点演示如何结合 `diffusers` 库实现这些评估。
+
+文档所示方法同样适用于评估不同[噪声调度器](https://huggingface.co/docs/diffusers/main/en/api/schedulers/overview)在固定生成模型下的表现差异。
+
+## 评估场景
+
+我们涵盖以下Diffusion模型管线的评估:
+
+- 文本引导图像生成(如 [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/text2img))
+- 基于文本和输入图像的引导生成(如 [`StableDiffusionImg2ImgPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/img2img) 和 [`StableDiffusionInstructPix2PixPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/pix2pix))
+- 类别条件图像生成模型(如 [`DiTPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipe))
+
+## 定性评估
+
+定性评估通常涉及对生成图像的人工评判。评估维度包括构图质量、图文对齐度和空间关系等方面。标准化的提示词能为这些主观指标提供统一基准。DrawBench和PartiPrompts是常用的定性评估提示词数据集,分别由[Imagen](https://imagen.research.google/)和[Parti](https://parti.research.google/)团队提出。
+
+根据[Parti官方网站](https://parti.research.google/)说明:
+
+> PartiPrompts (P2)是我们发布的包含1600多个英文提示词的丰富集合,可用于测量模型在不同类别和挑战维度上的能力。
+
+
+
+PartiPrompts包含以下字段:
+- Prompt(提示词)
+- Category(类别,如"抽象"、"世界知识"等)
+- Challenge(难度等级,如"基础"、"复杂"、"文字与符号"等)
+
+这些基准测试支持对不同图像生成模型进行并排人工对比评估。为此,🧨 Diffusers团队构建了**Open Parti Prompts**——一个基于Parti Prompts的社区驱动型定性评估基准,用于比较顶尖开源diffusion模型:
+- [Open Parti Prompts游戏](https://huggingface.co/spaces/OpenGenAI/open-parti-prompts):展示10个parti提示词对应的4张生成图像,用户选择最符合提示的图片
+- [Open Parti Prompts排行榜](https://huggingface.co/spaces/OpenGenAI/parti-prompts-leaderboard):对比当前最优开源diffusion模型的性能榜单
+
+为进行手动图像对比,我们演示如何使用`diffusers`处理部分PartiPrompts提示词。
+
+以下是从不同挑战维度(基础、复杂、语言结构、想象力、文字与符号)采样的提示词示例(使用[PartiPrompts作为数据集](https://huggingface.co/datasets/nateraw/parti-prompts)):
+
+```python
+from datasets import load_dataset
+
+# prompts = load_dataset("nateraw/parti-prompts", split="train")
+# prompts = prompts.shuffle()
+# sample_prompts = [prompts[i]["Prompt"] for i in range(5)]
+
+# Fixing these sample prompts in the interest of reproducibility.
+sample_prompts = [
+ "a corgi",
+ "a hot air balloon with a yin-yang symbol, with the moon visible in the daytime sky",
+ "a car with no windows",
+ "a cube made of porcupine",
+ 'The saying "BE EXCELLENT TO EACH OTHER" written on a red brick wall with a graffiti image of a green alien wearing a tuxedo. A yellow fire hydrant is on a sidewalk in the foreground.',
+]
+```
+
+现在我们可以使用Stable Diffusion([v1-4 checkpoint](https://huggingface.co/CompVis/stable-diffusion-v1-4))生成这些提示词对应的图像:
+
+```python
+import torch
+
+seed = 0
+generator = torch.manual_seed(seed)
+
+images = sd_pipeline(sample_prompts, num_images_per_prompt=1, generator=generator).images
+```
+
+
+
+我们也可以通过设置`num_images_per_prompt`参数来比较同一提示词生成的不同图像。使用不同检查点([v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5))运行相同流程后,结果如下:
+
+
+
+当使用多个待评估模型为所有提示词生成若干图像后,这些结果将提交给人类评估员进行打分。有关DrawBench和PartiPrompts基准测试的更多细节,请参阅各自的论文。
+
+> [!TIP]
+> 在模型训练过程中查看推理样本有助于评估训练进度。我们的[训练脚本](https://github.com/huggingface/diffusers/tree/main/examples/)支持此功能,并额外提供TensorBoard和Weights & Biases日志记录功能。
+
+## 定量评估
+
+本节将指导您如何评估三种不同的扩散流程,使用以下指标:
+- CLIP分数
+- CLIP方向相似度
+- FID(弗雷歇起始距离)
+
+### 文本引导图像生成
+
+[CLIP分数](https://huggingface.co/papers/2104.08718)用于衡量图像-标题对的匹配程度。CLIP分数越高表明匹配度越高🔼。该分数是对"匹配度"这一定性概念的量化测量,也可以理解为图像与标题之间的语义相似度。研究发现CLIP分数与人类判断具有高度相关性。
+
+首先加载[`StableDiffusionPipeline`]:
+
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+model_ckpt = "CompVis/stable-diffusion-v1-4"
+sd_pipeline = StableDiffusionPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16).to("cuda")
+```
+
+使用多个提示词生成图像:
+
+```python
+prompts = [
+ "a photo of an astronaut riding a horse on mars",
+ "A high tech solarpunk utopia in the Amazon rainforest",
+ "A pikachu fine dining with a view to the Eiffel Tower",
+ "A mecha robot in a favela in expressionist style",
+ "an insect robot preparing a delicious meal",
+ "A small cabin on top of a snowy mountain in the style of Disney, artstation",
+]
+
+images = sd_pipeline(prompts, num_images_per_prompt=1, output_type="np").images
+
+print(images.shape)
+# (6, 512, 512, 3)
+```
+
+然后计算CLIP分数:
+
+```python
+from torchmetrics.functional.multimodal import clip_score
+from functools import partial
+
+clip_score_fn = partial(clip_score, model_name_or_path="openai/clip-vit-base-patch16")
+
+def calculate_clip_score(images, prompts):
+ images_int = (images * 255).astype("uint8")
+ clip_score = clip_score_fn(torch.from_numpy(images_int).permute(0, 3, 1, 2), prompts).detach()
+ return round(float(clip_score), 4)
+
+sd_clip_score = calculate_clip_score(images, prompts)
+print(f"CLIP分数: {sd_clip_score}")
+# CLIP分数: 35.7038
+```
+
+上述示例中,我们为每个提示生成一张图像。如果为每个提示生成多张图像,则需要计算每个提示生成图像的平均分数。
+
+当需要比较两个兼容[`StableDiffusionPipeline`]的检查点时,应在调用管道时传入生成器。首先使用[v1-4 Stable Diffusion检查点](https://huggingface.co/CompVis/stable-diffusion-v1-4)以固定种子生成图像:
+
+```python
+seed = 0
+generator = torch.manual_seed(seed)
+
+images = sd_pipeline(prompts, num_images_per_prompt=1, generator=generator, output_type="np").images
+```
+
+然后加载[v1-5检查点](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)生成图像:
+
+```python
+model_ckpt_1_5 = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+sd_pipeline_1_5 = StableDiffusionPipeline.from_pretrained(model_ckpt_1_5, torch_dtype=torch.float16).to("cuda")
+
+images_1_5 = sd_pipeline_1_5(prompts, num_images_per_prompt=1, generator=generator, output_type="np").images
+```
+
+最后比较两者的CLIP分数:
+
+```python
+sd_clip_score_1_4 = calculate_clip_score(images, prompts)
+print(f"v-1-4版本的CLIP分数: {sd_clip_score_1_4}")
+# v-1-4版本的CLIP分数: 34.9102
+
+sd_clip_score_1_5 = calculate_clip_score(images_1_5, prompts)
+print(f"v-1-5版本的CLIP分数: {sd_clip_score_1_5}")
+# v-1-5版本的CLIP分数: 36.2137
+```
+
+结果表明[v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)检查点性能优于前代。但需注意,我们用于计算CLIP分数的提示词数量较少。实际评估时应使用更多样化且数量更大的提示词集。
+
+> [!WARNING]
+> 该分数存在固有局限性:训练数据中的标题是从网络爬取,并提取自图片关联的`alt`等标签。这些描述未必符合人类描述图像的方式,因此我们需要人工"设计"部分提示词。
+
+### 图像条件式文本生成图像
+
+这种情况下,生成管道同时接受输入图像和文本提示作为条件。以[`StableDiffusionInstructPix2PixPipeline`]为例,该管道接收编辑指令作为输入提示,并接受待编辑的输入图像。
+
+示例图示:
+
+
+
+评估此类模型的策略之一是测量两幅图像间变化的连贯性(通过[CLIP](https://huggingface.co/docs/transformers/model_doc/clip)定义)中两个图像之间的变化与两个图像描述之间的变化的一致性(如论文[《CLIP-Guided Domain Adaptation of Image Generators》](https://huggingface.co/papers/2108.00946)所示)。这被称为“**CLIP方向相似度**”。
+
+- **描述1**对应输入图像(图像1),即待编辑的图像。
+- **描述2**对应编辑后的图像(图像2),应反映编辑指令。
+
+以下是示意图:
+
+
+
+我们准备了一个小型数据集来实现该指标。首先加载数据集:
+
+```python
+from datasets import load_dataset
+
+dataset = load_dataset("sayakpaul/instructpix2pix-demo", split="train")
+dataset.features
+```
+
+```bash
+{'input': Value(dtype='string', id=None),
+ 'edit': Value(dtype='string', id=None),
+ 'output': Value(dtype='string', id=None),
+ 'image': Image(decode=True, id=None)}
+```
+
+数据字段说明:
+
+- `input`:与`image`对应的原始描述。
+- `edit`:编辑指令。
+- `output`:反映`edit`指令的修改后描述。
+
+查看一个样本:
+
+```python
+idx = 0
+print(f"Original caption: {dataset[idx]['input']}")
+print(f"Edit instruction: {dataset[idx]['edit']}")
+print(f"Modified caption: {dataset[idx]['output']}")
+```
+
+```bash
+Original caption: 2. FAROE ISLANDS: An archipelago of 18 mountainous isles in the North Atlantic Ocean between Norway and Iceland, the Faroe Islands has 'everything you could hope for', according to Big 7 Travel. It boasts 'crystal clear waterfalls, rocky cliffs that seem to jut out of nowhere and velvety green hills'
+Edit instruction: make the isles all white marble
+Modified caption: 2. WHITE MARBLE ISLANDS: An archipelago of 18 mountainous white marble isles in the North Atlantic Ocean between Norway and Iceland, the White Marble Islands has 'everything you could hope for', according to Big 7 Travel. It boasts 'crystal clear waterfalls, rocky cliffs that seem to jut out of nowhere and velvety green hills'
+```
+
+对应的图像:
+
+```python
+dataset[idx]["image"]
+```
+
+
+
+我们将根据编辑指令修改数据集中的图像,并计算方向相似度。
+
+首先加载[`StableDiffusionInstructPix2PixPipeline`]:
+
+```python
+from diffusers import StableDiffusionInstructPix2PixPipeline
+
+instruct_pix2pix_pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
+ "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
+).to("cuda")
+```
+
+执行编辑操作:
+
+```python
+import numpy as np
+
+
+def edit_image(input_image, instruction):
+ image = instruct_pix2pix_pipeline(
+ instruction,
+ image=input_image,
+ output_type="np",
+ generator=generator,
+ ).images[0]
+ return image
+
+input_images = []
+original_captions = []
+modified_captions = []
+edited_images = []
+
+for idx in range(len(dataset)):
+ input_image = dataset[idx]["image"]
+ edit_instruction = dataset[idx]["edit"]
+ edited_image = edit_image(input_image, edit_instruction)
+
+ input_images.append(np.array(input_image))
+ original_captions.append(dataset[idx]["input"])
+ modified_captions.append(dataset[idx]["output"])
+ edited_images.append(edited_image)
+```
+
+为测量方向相似度,我们首先加载CLIP的图像和文本编码器:
+
+```python
+from transformers import (
+ CLIPTokenizer,
+ CLIPTextModelWithProjection,
+ CLIPVisionModelWithProjection,
+ CLIPImageProcessor,
+)
+
+clip_id = "openai/clip-vit-large-patch14"
+tokenizer = CLIPTokenizer.from_pretrained(clip_id)
+text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_id).to("cuda")
+image_processor = CLIPImageProcessor.from_pretrained(clip_id)
+image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_id).to("cuda")
+```
+
+注意我们使用的是特定CLIP检查点——`openai/clip-vit-large-patch14`,因为Stable Diffusion预训练正是基于此CLIP变体。详见[文档](https://huggingface.co/docs/transformers/model_doc/clip)。
+
+接着准备计算方向相似度的PyTorch `nn.Module`:
+
+```python
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class DirectionalSimilarity(nn.Module):
+ def __init__(self, tokenizer, text_encoder, image_processor, image_encoder):
+ super().__init__()
+ self.tokenizer = tokenizer
+ self.text_encoder = text_encoder
+ self.image_processor = image_processor
+ self.image_encoder = image_encoder
+
+ def preprocess_image(self, image):
+ image = self.image_processor(image, return_tensors="pt")["pixel_values"]
+ return {"pixel_values": image.to("cuda")}
+
+ def tokenize_text(self, text):
+ inputs = self.tokenizer(
+ text,
+ max_length=self.tokenizer.model_max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt",
+ )
+ return {"input_ids": inputs.input_ids.to("cuda")}
+
+ def encode_image(self, image):
+ preprocessed_image = self.preprocess_image(image)
+ image_features = self.image_encoder(**preprocessed_image).image_embeds
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
+ return image_features
+
+ def encode_text(self, text):
+ tokenized_text = self.tokenize_text(text)
+ text_features = self.text_encoder(**tokenized_text).text_embeds
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
+ return text_features
+
+ def compute_directional_similarity(self, img_feat_one, img_feat_two, text_feat_one, text_feat_two):
+ sim_direction = F.cosine_similarity(img_feat_two - img_feat_one, text_feat_two - text_feat_one)
+ return sim_direction
+
+ def forward(self, image_one, image_two, caption_one, caption_two):
+ img_feat_one = self.encode_image(image_one)
+ img_feat_two = self.encode_image(image_two)
+ text_feat_one = self.encode_text(caption_one)
+ text_feat_two = self.encode_text(caption_two)
+ directional_similarity = self.compute_directional_similarity(
+ img_feat_one, img_feat_two, text_feat_one, text_feat_two
+ )
+ return directional_similarity
+```
+
+现在让我们使用`DirectionalSimilarity`模块:
+
+```python
+dir_similarity = DirectionalSimilarity(tokenizer, text_encoder, image_processor, image_encoder)
+scores = []
+
+for i in range(len(input_images)):
+ original_image = input_images[i]
+ original_caption = original_captions[i]
+ edited_image = edited_images[i]
+ modified_caption = modified_captions[i]
+
+ similarity_score = dir_similarity(original_image, edited_image, original_caption, modified_caption)
+ scores.append(float(similarity_score.detach().cpu()))
+
+print(f"CLIP方向相似度: {np.mean(scores)}")
+# CLIP方向相似度: 0.0797976553440094
+```
+
+与CLIP分数类似,CLIP方向相似度数值越高越好。
+
+需要注意的是,`StableDiffusionInstructPix2PixPipeline`提供了两个控制参数`image_guidance_scale`和`guidance_scale`来调节最终编辑图像的质量。建议您尝试调整这两个参数,观察它们对方向相似度的影响。
+
+我们可以扩展这个度量标准来评估原始图像与编辑版本的相似度,只需计算`F.cosine_similarity(img_feat_two, img_feat_one)`。对于这类编辑任务,我们仍希望尽可能保留图像的主要语义特征(即保持较高的相似度分数)。
+
+该度量方法同样适用于类似流程,例如[`StableDiffusionPix2PixZeroPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/pix2pix_zero#diffusers.StableDiffusionPix2PixZeroPipeline)。
+
+> [!TIP]
+> CLIP分数和CLIP方向相似度都依赖CLIP模型,可能导致评估结果存在偏差。
+
+***扩展IS、FID(后文讨论)或KID等指标存在困难***,当被评估模型是在大型图文数据集(如[LAION-5B数据集](https://laion.ai/blog/laion-5b/))上预训练时。因为这些指标的底层都使用了在ImageNet-1k数据集上预训练的InceptionNet来提取图像特征。Stable Diffusion的预训练数据集与InceptionNet的预训练数据集可能重叠有限,因此不适合作为特征提取器。
+
+***上述指标更适合评估类别条件模型***,例如[DiT](https://huggingface.co/docs/diffusers/main/en/api/pipelines/dit)。该模型是在ImageNet-1k类别条件下预训练的。
+这是9篇文档中的第8部分。
+
+### 基于类别的图像生成
+
+基于类别的生成模型通常是在带有类别标签的数据集(如[ImageNet-1k](https://huggingface.co/datasets/imagenet-1k))上进行预训练的。评估这些模型的常用指标包括Fréchet Inception Distance(FID)、Kernel Inception Distance(KID)和Inception Score(IS)。本文档重点介绍FID([Heusel等人](https://huggingface.co/papers/1706.08500)),并展示如何使用[`DiTPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/dit)计算该指标,该管道底层使用了[DiT模型](https://huggingface.co/papers/2212.09748)。
+
+FID旨在衡量两组图像数据集的相似程度。根据[此资源](https://mmgeneration.readthedocs.io/en/latest/quick_run.html#fid):
+
+> Fréchet Inception Distance是衡量两组图像数据集相似度的指标。研究表明其与人类对视觉质量的主观判断高度相关,因此最常用于评估生成对抗网络(GAN)生成样本的质量。FID通过计算Inception网络特征表示所拟合的两个高斯分布之间的Fréchet距离来实现。
+
+这两个数据集本质上是真实图像数据集和生成图像数据集(本例中为人工生成的图像)。FID通常基于两个大型数据集计算,但本文档将使用两个小型数据集进行演示。
+
+首先下载ImageNet-1k训练集中的部分图像:
+
+```python
+from zipfile import ZipFile
+import requests
+
+
+def download(url, local_filepath):
+ r = requests.get(url)
+ with open(local_filepath, "wb") as f:
+ f.write(r.content)
+ return local_filepath
+
+dummy_dataset_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/sample-imagenet-images.zip"
+local_filepath = download(dummy_dataset_url, dummy_dataset_url.split("/")[-1])
+
+with ZipFile(local_filepath, "r") as zipper:
+ zipper.extractall(".")
+```
+
+```python
+from PIL import Image
+import os
+import numpy as np
+
+dataset_path = "sample-imagenet-images"
+image_paths = sorted([os.path.join(dataset_path, x) for x in os.listdir(dataset_path)])
+
+real_images = [np.array(Image.open(path).convert("RGB")) for path in image_paths]
+```
+
+这些是来自以下ImageNet-1k类别的10张图像:"cassette_player"、"chain_saw"(2张)、"church"、"gas_pump"(3张)、"parachute"(2张)和"tench"。
+
+
+
+ 真实图像
+
+
+加载图像后,我们对其进行轻量级预处理以便用于FID计算:
+
+```python
+from torchvision.transforms import functional as F
+import torch
+
+
+def preprocess_image(image):
+ image = torch.tensor(image).unsqueeze(0)
+ image = image.permute(0, 3, 1, 2) / 255.0
+ return F.center_crop(image, (256, 256))
+
+real_images = torch.stack([dit_pipeline.preprocess_image(image) for image in real_images])
+print(real_images.shape)
+# torch.Size([10, 3, 256, 256])
+```
+
+我们现在加载[`DiTPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/dit)来生成基于上述类别的条件图像。
+
+```python
+from diffusers import DiTPipeline, DPMSolverMultistepScheduler
+
+dit_pipeline = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
+dit_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(dit_pipeline.scheduler.config)
+dit_pipeline = dit_pipeline.to("cuda")
+
+seed = 0
+generator = torch.manual_seed(seed)
+
+
+words = [
+ "cassette player",
+ "chainsaw",
+ "chainsaw",
+ "church",
+ "gas pump",
+ "gas pump",
+ "gas pump",
+ "parachute",
+ "parachute",
+ "tench",
+]
+
+class_ids = dit_pipeline.get_label_ids(words)
+output = dit_pipeline(class_labels=class_ids, generator=generator, output_type="np")
+
+fake_images = output.images
+fake_images = torch.tensor(fake_images)
+fake_images = fake_images.permute(0, 3, 1, 2)
+print(fake_images.shape)
+# torch.Size([10, 3, 256, 256])
+```
+
+现在,我们可以使用[`torchmetrics`](https://torchmetrics.readthedocs.io/)计算FID分数。
+
+```python
+from torchmetrics.image.fid import FrechetInceptionDistance
+
+fid = FrechetInceptionDistance(normalize=True)
+fid.update(real_images, real=True)
+fid.update(fake_images, real=False)
+
+print(f"FID分数: {float(fid.compute())}")
+# FID分数: 177.7147216796875
+```
+
+FID分数越低越好。以下因素会影响FID结果:
+
+- 图像数量(包括真实图像和生成图像)
+- 扩散过程中引入的随机性
+- 扩散过程的推理步数
+- 扩散过程中使用的调度器
+
+对于最后两点,最佳实践是使用不同的随机种子和推理步数进行多次评估,然后报告平均结果。
+
+> [!WARNING]
+> FID结果往往具有脆弱性,因为它依赖于许多因素:
+>
+> * 计算过程中使用的特定Inception模型
+> * 计算实现的准确性
+> * 图像格式(PNG和JPG的起点不同)
+>
+> 需要注意的是,FID通常在比较相似实验时最有用,但除非作者仔细公开FID测量代码,否则很难复现论文结果。
+>
+> 这些注意事项同样适用于其他相关指标,如KID和IS。
+
+最后,让我们可视化检查这些`fake_images`。
+
+
+
+ 生成图像示例
+
diff --git a/docs/source/zh/conceptual/philosophy.md b/docs/source/zh/conceptual/philosophy.md
new file mode 100644
index 000000000000..581e582bba56
--- /dev/null
+++ b/docs/source/zh/conceptual/philosophy.md
@@ -0,0 +1,104 @@
+
+
+# 设计哲学
+
+🧨 Diffusers 提供**最先进**的预训练扩散模型支持多模态任务。
+其目标是成为推理和训练通用的**模块化工具箱**。
+
+我们致力于构建一个经得起时间考验的库,因此对API设计极为重视。
+
+简而言之,Diffusers 被设计为 PyTorch 的自然延伸。因此,我们的多数设计决策都基于 [PyTorch 设计原则](https://pytorch.org/docs/stable/community/design.html#pytorch-design-philosophy)。以下是核心原则:
+
+## 可用性优先于性能
+
+- 尽管 Diffusers 包含众多性能优化特性(参见[内存与速度优化](https://huggingface.co/docs/diffusers/optimization/fp16)),模型默认总是以最高精度和最低优化级别加载。因此除非用户指定,扩散流程(pipeline)默认在CPU上以float32精度初始化。这确保了跨平台和加速器的可用性,意味着运行本库无需复杂安装。
+- Diffusers 追求**轻量化**,仅有少量必需依赖,但提供诸多可选依赖以提升性能(如`accelerate`、`safetensors`、`onnx`等)。我们竭力保持库的轻量级特性,使其能轻松作为其他包的依赖项。
+- Diffusers 偏好简单、自解释的代码而非浓缩的"魔法"代码。这意味着lambda函数等简写语法和高级PyTorch操作符通常不被采用。
+
+## 简洁优于简易
+
+正如PyTorch所言:**显式优于隐式**,**简洁优于复杂**。这一哲学体现在库的多个方面:
+- 我们遵循PyTorch的API设计,例如使用[`DiffusionPipeline.to`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.to)让用户自主管理设备。
+- 明确的错误提示优于静默纠正错误输入。Diffusers 旨在教育用户,而非单纯降低使用难度。
+- 暴露复杂的模型与调度器(scheduler)交互逻辑而非内部魔法处理。调度器/采样器与扩散模型分离且相互依赖最小化,迫使用户编写展开的去噪循环。但这种分离便于调试,并赋予用户更多控制权来调整去噪过程或切换模型/调度器。
+- 扩散流程中独立训练的组件(如文本编码器、UNet、变分自编码器)各有专属模型类。这要求用户处理组件间交互,且序列化格式将组件分存不同文件。但此举便于调试和定制,得益于组件分离,DreamBooth或Textual Inversion训练变得极为简单。
+
+## 可定制与贡献友好优于抽象
+
+库的大部分沿用了[Transformers库](https://github.com/huggingface/transformers)的重要设计原则:宁要重复代码,勿要仓促抽象。这一原则与[DRY原则](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself)形成鲜明对比。
+
+简言之,正如Transformers对建模文件的做法,Diffusers对流程(pipeline)和调度器(scheduler)保持极低抽象度与高度自包含代码。函数、长代码块甚至类可能在多文件中重复,初看像是糟糕的松散设计。但该设计已被Transformers证明极其成功,对社区驱动的开源机器学习库意义重大:
+- 机器学习领域发展迅猛,范式、模型架构和算法快速迭代,难以定义长效代码抽象。
+- ML从业者常需快速修改现有代码进行研究,因此偏好自包含代码而非多重抽象。
+- 开源库依赖社区贡献,必须构建易于参与的代码库。抽象度越高、依赖越复杂、可读性越差,贡献难度越大。过度抽象的库会吓退贡献者。若贡献不会破坏核心功能,不仅吸引新贡献者,也更便于并行审查和修改。
+
+Hugging Face称此设计为**单文件政策**——即某个类的几乎所有代码都应写在单一自包含文件中。更多哲学探讨可参阅[此博文](https://huggingface.co/blog/transformers-design-philosophy)。
+
+Diffusers对流程和调度器完全遵循该哲学,但对diffusion模型仅部分适用。原因在于多数扩散流程(如[DDPM](https://huggingface.co/docs/diffusers/api/pipelines/ddpm)、[Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview#stable-diffusion-pipelines)、[unCLIP (DALL·E 2)](https://huggingface.co/docs/diffusers/api/pipelines/unclip)和[Imagen](https://imagen.research.google/))都基于相同扩散模型——[UNet](https://huggingface.co/docs/diffusers/api/models/unet2d-cond)。
+
+现在您应已理解🧨 Diffusers的设计理念🤗。我们力求在全库贯彻这些原则,但仍存在少数例外或欠佳设计。如有反馈,我们❤️欢迎在[GitHub提交](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=)。
+
+## 设计哲学细节
+
+现在深入探讨设计细节。Diffusers主要包含三类:[流程(pipeline)](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines)、[模型](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models)和[调度器(scheduler)](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers)。以下是各类的具体设计决策。
+
+### 流程(Pipelines)
+
+流程设计追求易用性(因此不完全遵循[*简洁优于简易*](#简洁优于简易)),不要求功能完备,应视为使用[模型](#模型)和[调度器](#调度器schedulers)进行推理的示例。
+
+遵循原则:
+- 采用单文件政策。所有流程位于src/diffusers/pipelines下的独立目录。一个流程文件夹对应一篇扩散论文/项目/发布。如[`src/diffusers/pipelines/stable-diffusion`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion)可包含多个流程文件。若流程功能相似,可使用[# Copied from机制](https://github.com/huggingface/diffusers/blob/125d783076e5bd9785beb05367a2d2566843a271/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L251)。
+- 所有流程继承[`DiffusionPipeline`]。
+- 每个流程由不同模型和调度器组件构成,这些组件记录于[`model_index.json`文件](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/model_index.json),可通过同名属性访问,并可用[`DiffusionPipeline.components`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.components)在流程间共享。
+- 所有流程应能通过[`DiffusionPipeline.from_pretrained`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained)加载。
+- 流程**仅**用于推理。
+- 流程代码应具备高可读性、自解释性和易修改性。
+- 流程应设计为可相互构建,便于集成到高层API。
+- 流程**非**功能完备的用户界面。完整UI推荐[InvokeAI](https://github.com/invoke-ai/InvokeAI)、[Diffuzers](https://github.com/abhishekkrthakur/diffuzers)或[lama-cleaner](https://github.com/Sanster/lama-cleaner)。
+- 每个流程应通过唯一的`__call__`方法运行,且参数命名应跨流程统一。
+- 流程应以其解决的任务命名。
+- 几乎所有新diffusion流程都应在新文件夹/文件中实现。
+
+### 模型
+
+模型设计为可配置的工具箱,是[PyTorch Module类](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)的自然延伸,仅部分遵循**单文件政策**。
+
+遵循原则:
+- 模型对应**特定架构类型**。如[`UNet2DConditionModel`]类适用于所有需要2D图像输入且受上下文调节的UNet变体。
+- 所有模型位于[`src/diffusers/models`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models),每种架构应有独立文件,如[`unets/unet_2d_condition.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_condition.py)、[`transformers/transformer_2d.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_2d.py)等。
+- 模型**不**采用单文件政策,应使用小型建模模块如[`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py)、[`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py)、[`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py)等。**注意**:这与Transformers的建模文件截然不同,表明模型未完全遵循单文件政策。
+- 模型意图暴露复杂度(类似PyTorch的`Module`类),并提供明确错误提示。
+- 所有模型继承`ModelMixin`和`ConfigMixin`。
+- 当不涉及重大代码变更、保持向后兼容性且显著提升内存/计算效率时,可对模型进行性能优化。
+- 模型默认应具备最高精度和最低性能设置。
+- 若新模型检查点可归类为现有架构,应适配现有架构而非新建文件。仅当架构根本性不同时才创建新文件。
+- 模型设计应便于未来扩展。可通过限制公开函数参数、配置参数和"预见"变更实现。例如:优先采用可扩展的`string`类型参数而非布尔型`is_..._type`参数。对现有架构的修改应保持最小化。
+- 模型设计需在代码可读性与多检查点支持间权衡。多数情况下应适配现有类,但某些例外(如[UNet块](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_blocks.py)和[注意力处理器](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py))需新建类以保证长期可读性。
+
+### 调度器(Schedulers)
+
+调度器负责引导推理去噪过程及定义训练噪声计划。它们设计为独立的可加载配置类,严格遵循**单文件政策**。
+
+遵循原则:
+- 所有调度器位于[`src/diffusers/schedulers`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers)。
+- 调度器**禁止**从大型工具文件导入,必须保持高度自包含。
+- 一个调度器Python文件对应一种算法(如论文定义的算法)。
+- 若调度器功能相似,可使用`# Copied from`机制。
+- 所有调度器继承`SchedulerMixin`和`ConfigMixin`。
+- 调度器可通过[`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config)轻松切换(详见[此处](../using-diffusers/schedulers))。
+- 每个调度器必须包含`set_num_inference_steps`和`step`函数。在每次去噪过程前(即调用`step(...)`前)必须调用`set_num_inference_steps(...)`。
+- 每个调度器通过`timesteps`属性暴露需要"循环"的时间步,这是模型将被调用的时间步数组。
+- `step(...)`函数接收模型预测输出和"当前"样本(x_t),返回"前一个"略去噪的样本(x_t-1)。
+- 鉴于扩散调度器的复杂性,`step`函数不暴露全部细节,可视为"黑盒"。
+- 几乎所有新调度器都应在新文件中实现。
\ No newline at end of file
diff --git a/docs/source/zh/hybrid_inference/api_reference.md b/docs/source/zh/hybrid_inference/api_reference.md
new file mode 100644
index 000000000000..74f6a35ec2a1
--- /dev/null
+++ b/docs/source/zh/hybrid_inference/api_reference.md
@@ -0,0 +1,9 @@
+# 混合推理 API 参考
+
+## 远程解码
+
+[[autodoc]] utils.remote_utils.remote_decode
+
+## 远程编码
+
+[[autodoc]] utils.remote_utils.remote_encode
\ No newline at end of file
diff --git a/docs/source/zh/hybrid_inference/overview.md b/docs/source/zh/hybrid_inference/overview.md
new file mode 100644
index 000000000000..4d77d0abc26d
--- /dev/null
+++ b/docs/source/zh/hybrid_inference/overview.md
@@ -0,0 +1,55 @@
+
+
+# 混合推理
+
+**通过混合推理赋能本地 AI 构建者**
+
+> [!TIP]
+> 混合推理是一项[实验性功能](https://huggingface.co/blog/remote_vae)。
+> 可以在此处提供反馈[此处](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml)。
+
+## 为什么使用混合推理?
+
+混合推理提供了一种快速简单的方式来卸载本地生成需求。
+
+- 🚀 **降低要求:** 无需昂贵硬件即可访问强大模型。
+- 💎 **无妥协:** 在不牺牲性能的情况下实现最高质量。
+- 💰 **成本效益高:** 它是免费的!🤑
+- 🎯 **多样化用例:** 与 Diffusers � 和更广泛的社区完全兼容。
+- 🔧 **开发者友好:** 简单请求,快速响应。
+
+---
+
+## 可用模型
+
+* **VAE 解码 🖼️:** 快速将潜在表示解码为高质量图像,不影响性能或工作流速度。
+* **VAE 编码 🔢:** 高效将图像编码为潜在表示,用于生成和训练。
+* **文本编码器 📃(即将推出):** 快速准确地计算提示的文本嵌入,确保流畅高质量的工作流。
+
+---
+
+## 集成
+
+* **[SD.Next](https://github.com/vladmandic/sdnext):** 一体化 UI,直接支持混合推理。
+* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** 用于混合推理的 ComfyUI 节点。
+
+## 更新日志
+
+- 2025 年 3 月 10 日:添加了 VAE 编码
+- 2025 年 3 月 2 日:初始发布,包含 VAE 解码
+
+## 内容
+
+文档分为三个部分:
+
+* **VAE 解码** 学习如何使用混合推理进行 VAE 解码的基础知识。
+* **VAE 编码** 学习如何使用混合推理进行 VAE 编码的基础知识。
+* **API 参考** 深入了解任务特定设置和参数。
\ No newline at end of file
diff --git a/docs/source/zh/hybrid_inference/vae_encode.md b/docs/source/zh/hybrid_inference/vae_encode.md
new file mode 100644
index 000000000000..30aee9a6bfa4
--- /dev/null
+++ b/docs/source/zh/hybrid_inference/vae_encode.md
@@ -0,0 +1,184 @@
+# 入门:使用混合推理进行 VAE 编码
+
+VAE 编码用于训练、图像到图像和图像到视频——将图像或视频转换为潜在表示。
+
+## 内存
+
+这些表格展示了在不同 GPU 上使用 SD v1 和 SD XL 进行 VAE 编码的 VRAM 需求。
+
+对于这些 GPU 中的大多数,内存使用百分比决定了其他模型(文本编码器、UNet/Transformer)必须被卸载,或者必须使用分块编码,这会增加时间并影响质量。
+
+SD v1.5
+
+| GPU | 分辨率 | 时间(秒) | 内存(%) | 分块时间(秒) | 分块内存(%) |
+|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|
+| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 |
+| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 |
+| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 |
+| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 |
+| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 |
+| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 |
+| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 |
+| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 |
+| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 |
+| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 |
+| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 |
+| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 |
+| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 |
+| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 |
+| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 |
+| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 |
+| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 |
+| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 |
+| 3.99743 | 0.015 | 3.99743 |
+| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 |
+| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 |
+
+
+
+SDXL
+
+| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
+|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|
+| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 |
+| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 |
+| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 |
+| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 |
+| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 |
+| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 |
+| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 |
+| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 |
+| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 |
+| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 |
+| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 |
+| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 |
+| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 |
+| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 |
+| NVIDIA GeForce RTX 3080 | 2048x2048 | 内存不足 (OOM) | 内存不足 (OOM) | 1.866 | 44.4717 |
+| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 |
+| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 |
+| NVIDIA GeForce R
+| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 |
+| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 |
+| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |
+
+
+
+## 可用 VAE
+
+| | **端点** | **模型** |
+|:-:|:-----------:|:--------:|
+| **Stable Diffusion v1** | [https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud](https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
+| **Stable Diffusion XL** | [https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud](https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
+| **Flux** | [https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud](https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
+
+
+> [!TIP]
+> 模型支持可以在此处请求:[这里](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml)。
+
+
+## 代码
+
+> [!TIP]
+> 从 `main` 安装 `diffusers` 以运行代码:`pip install git+https://github.com/huggingface/diffusers@main`
+
+
+一个辅助方法简化了与混合推理的交互。
+
+```python
+from diffusers.utils.remote_utils import remote_encode
+```
+
+### 基本示例
+
+让我们编码一张图像,然后解码以演示。
+
+
+
+
+
+代码
+
+```python
+from diffusers.utils import load_image
+from diffusers.utils.remote_utils import remote_decode
+
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true")
+
+latent = remote_encode(
+ endpoint="https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/",
+ scaling_factor=0.3611,
+ shift_factor=0.1159,
+)
+
+decoded = remote_decode(
+ endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
+ tensor=latent,
+ scaling_factor=0.3611,
+ shift_factor=0.1159,
+)
+```
+
+
+
+
+
+
+
+
+### 生成
+
+现在让我们看一个生成示例,我们将编码图像,生成,然后远程解码!
+
+代码
+
+```python
+import torch
+from diffusers import StableDiffusionImg2ImgPip
+from diffusers.utils import load_image
+from diffusers.utils.remote_utils import remote_decode, remote_encode
+
+pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+ variant="fp16",
+ vae=None,
+).to("cuda")
+
+init_image = load_image(
+ "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+)
+init_image = init_image.resize((768, 512))
+
+init_latent = remote_encode(
+ endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
+ image=init_image,
+ scaling_factor=0.18215,
+)
+
+prompt = "A fantasy landscape, trending on artstation"
+latent = pipe(
+ prompt=prompt,
+ image=init_latent,
+ strength=0.75,
+ output_type="latent",
+).images
+
+image = remote_decode(
+ endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
+ tensor=latent,
+ scaling_factor=0.18215,
+)
+image.save("fantasy_landscape.jpg")
+```
+
+
+
+
+
+
+
+## 集成
+
+* **[SD.Next](https://github.com/vladmandic/sdnext):** 具有直接支持混合推理功能的一体化用户界面。
+* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** 用于混合推理的 ComfyUI 节点。
\ No newline at end of file
diff --git a/docs/source/zh/index.md b/docs/source/zh/index.md
index 92c52bc1c1a9..299941e8227f 100644
--- a/docs/source/zh/index.md
+++ b/docs/source/zh/index.md
@@ -1,4 +1,4 @@
-
+
+# AutoPipelineBlocks
+
+[`~modular_pipelines.AutoPipelineBlocks`] 是一种包含支持不同工作流程的块的多块类型。它根据运行时提供的输入自动选择要运行的子块。这通常用于将多个工作流程(文本到图像、图像到图像、修复)打包到一个管道中以便利。
+
+本指南展示如何创建 [`~modular_pipelines.AutoPipelineBlocks`]。
+
+创建三个 [`~modular_pipelines.ModularPipelineBlocks`] 用于文本到图像、图像到图像和修复。这些代表了管道中可用的不同工作流程。
+
+
+
+
+```py
+import torch
+from diffusers.modular_pipelines import ModularPipelineBlocks, InputParam, OutputParam
+
+class TextToImageBlock(ModularPipelineBlocks):
+ model_name = "text2img"
+
+ @property
+ def inputs(self):
+ return [InputParam(name="prompt")]
+
+ @property
+ def intermediate_outputs(self):
+ return []
+
+ @property
+ def description(self):
+ return "我是一个文本到图像的工作流程!"
+
+ def __call__(self, components, state):
+ block_state = self.get_block_state(state)
+ print("运行文本到图像工作流程")
+ # 在这里添加你的文本到图像逻辑
+ # 例如:根据提示生成图像
+ self.set_block_state(state, block_state)
+ return components, state
+```
+
+
+
+
+
+```py
+class ImageToImageBlock(ModularPipelineBlocks):
+ model_name = "img2img"
+
+ @property
+ def inputs(self):
+ return [InputParam(name="prompt"), InputParam(name="image")]
+
+ @property
+ def intermediate_outputs(self):
+ return []
+
+ @property
+ def description(self):
+ return "我是一个图像到图像的工作流程!"
+
+ def __call__(self, components, state):
+ block_state = self.get_block_state(state)
+ print("运行图像到图像工作流程")
+ # 在这里添加你的图像到图像逻辑
+ # 例如:根据提示转换输入图像
+ self.set_block_state(state, block_state)
+ return components, state
+```
+
+
+
+
+
+```py
+class InpaintBlock(ModularPipelineBlocks):
+ model_name = "inpaint"
+
+ @property
+ def inputs(self):
+ return [InputParam(name="prompt"), InputParam(name="image"), InputParam(name="mask")]
+
+ @property
+
+ def intermediate_outputs(self):
+ return []
+
+ @property
+ def description(self):
+ return "我是一个修复工作流!"
+
+ def __call__(self, components, state):
+ block_state = self.get_block_state(state)
+ print("运行修复工作流")
+ # 在这里添加你的修复逻辑
+ # 例如:根据提示填充被遮罩的区域
+ self.set_block_state(state, block_state)
+ return components, state
+```
+
+
+
+
+创建一个包含子块类及其对应块名称列表的[`~modular_pipelines.AutoPipelineBlocks`]类。
+
+你还需要包括`block_trigger_inputs`,一个触发相应块的输入名称列表。如果在运行时提供了触发输入,则选择该块运行。使用`None`来指定如果未检测到触发输入时运行的默认块。
+
+最后,重要的是包括一个`description`,清楚地解释哪些输入触发哪些工作流。这有助于用户理解如何运行特定的工作流。
+
+```py
+from diffusers.modular_pipelines import AutoPipelineBlocks
+
+class AutoImageBlocks(AutoPipelineBlocks):
+ # 选择子块类的列表
+ block_classes = [block_inpaint_cls, block_i2i_cls, block_t2i_cls]
+ # 每个块的名称,顺序相同
+ block_names = ["inpaint", "img2img", "text2img"]
+ # 决定运行哪个块的触发输入
+ # - "mask" 触发修复工作流
+ # - "image" 触发img2img工作流(但仅在未提供mask时)
+ # - 如果以上都没有,运行text2img工作流(默认)
+ block_trigger_inputs = ["mask", "image", None]
+ # 对于AutoPipelineBlocks来说,描述极其重要
+
+ def description(self):
+ return (
+ "Pipeline generates images given different types of conditions!\n"
+ + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\n"
+ + " - inpaint workflow is run when `mask` is provided.\n"
+ + " - img2img workflow is run when `image` is provided (but only when `mask` is not provided).\n"
+ + " - text2img workflow is run when neither `image` nor `mask` is provided.\n"
+ )
+```
+
+包含`description`以避免任何关于如何运行块和需要什么输入的混淆**非常**重要。虽然[`~modular_pipelines.AutoPipelineBlocks`]很方便,但如果它没有正确解释,其条件逻辑可能难以理解。
+
+创建`AutoImageBlocks`的一个实例。
+
+```py
+auto_blocks = AutoImageBlocks()
+```
+
+对于更复杂的组合,例如在更大的管道中作为子块使用的嵌套[`~modular_pipelines.AutoPipelineBlocks`]块,使用[`~modular_pipelines.SequentialPipelineBlocks.get_execution_blocks`]方法根据你的输入提取实际运行的块。
+
+```py
+auto_blocks.get_execution_blocks("mask")
+```
diff --git a/docs/source/zh/modular_diffusers/components_manager.md b/docs/source/zh/modular_diffusers/components_manager.md
new file mode 100644
index 000000000000..39fef0651dd8
--- /dev/null
+++ b/docs/source/zh/modular_diffusers/components_manager.md
@@ -0,0 +1,188 @@
+
+
+# 组件管理器
+
+[`ComponentsManager`] 是 Modular Diffusers 的模型注册和管理系统。它添加和跟踪模型,存储有用的元数据(模型大小、设备放置、适配器),防止重复模型实例,并支持卸载。
+
+本指南将展示如何使用 [`ComponentsManager`] 来管理组件和设备内存。
+
+## 添加组件
+
+[`ComponentsManager`] 应与 [`ModularPipeline`] 一起创建,在 [`~ModularPipeline.from_pretrained`] 或 [`~ModularPipelineBlocks.init_pipeline`] 中。
+
+> [!TIP]
+> `collection` 参数是可选的,但可以更轻松地组织和管理组件。
+
+
+
+
+```py
+from diffusers import ModularPipeline, ComponentsManager
+
+comp = ComponentsManager()
+pipe = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test1")
+```
+
+
+
+
+```py
+from diffusers import ComponentsManager
+from diffusers.modular_pipelines import SequentialPipelineBlocks
+from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
+
+t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
+
+modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
+components = ComponentsManager()
+t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id, components_manager=components)
+```
+
+
+
+
+组件仅在调用 [`~ModularPipeline.load_components`] 或 [`~ModularPipeline.load_components`] 时加载和注册。以下示例使用 [`~ModularPipeline.load_components`] 创建第二个管道,重用第一个管道的所有组件,并将其分配到不同的集合。
+
+```py
+pipe.load_components()
+pipe2 = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test2")
+```
+
+使用 [`~ModularPipeline.null_component_names`] 属性来识别需要加载的任何组件,使用 [`~ComponentsManager.get_components_by_names`] 检索它们,然后调用 [`~ModularPipeline.update_components`] 来添加缺失的组件。
+
+```py
+pipe2.null_component_names
+['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'image_encoder', 'unet', 'vae', 'scheduler', 'controlnet']
+
+comp_dict = comp.get_components_by_names(names=pipe2.null_component_names)
+pipe2.update_components(**comp_dict)
+```
+
+要添加单个组件,请使用 [`~ComponentsManager.add`] 方法。这会使用唯一 id 注册一个组件。
+
+```py
+from diffusers import AutoModel
+
+text_encoder = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
+component_id = comp.add("text_encoder", text_encoder)
+comp
+```
+
+使用 [`~ComponentsManager.remove`] 通过其 id 移除一个组件。
+
+```py
+comp.remove("text_encoder_139917733042864")
+```
+
+## 检索组件
+
+[`ComponentsManager`] 提供了几种方法来检索已注册的组件。
+
+### get_one
+
+[`~ComponentsManager.get_one`] 方法返回单个组件,并支持对 `name` 参数进行模式匹配。如果多个组件匹配,[`~ComponentsManager.get_one`] 会返回错误。
+
+| 模式 | 示例 | 描述 |
+|-------------|----------------------------------|-------------------------------------------|
+| exact | `comp.get_one(name="unet")` | 精确名称匹配 |
+| wildcard | `comp.get_one(name="unet*")` | 名称以 "unet" 开头 |
+| exclusion | `comp.get_one(name="!unet")` | 排除名为 "unet" 的组件 |
+| or | `comp.get_one(name="unet|vae")` | 名称为 "unet" 或 "vae" |
+
+[`~ComponentsManager.get_one`] 还通过 `collection` 参数或 `load_id` 参数过滤组件。
+
+```py
+comp.get_one(name="unet", collection="sdxl")
+```
+
+### get_components_by_names
+
+[`~ComponentsManager.get_components_by_names`] 方法接受一个名称列表,并返回一个将名称映射到组件的字典。这在 [`ModularPipeline`] 中特别有用,因为它们提供了所需组件名称的列表,并且返回的字典可以直接传递给 [`~ModularPipeline.update_components`]。
+
+```py
+component_dict = comp.get_components_by_names(names=["text_encoder", "unet", "vae"])
+{"text_encoder": component1, "unet": component2, "vae": component3}
+```
+
+## 重复检测
+
+建议使用 [`ComponentSpec`] 加载模型组件,以分配具有唯一 id 的组件,该 id 编码了它们的加载参数。这允许 [`ComponentsManager`] 自动检测并防止重复的模型实例,即使不同的对象代表相同的底层检查点。
+
+```py
+from diffusers import ComponentSpec, ComponentsManager
+from transformers import CLIPTextModel
+
+comp = ComponentsManager()
+
+# 为第一个文本编码器创建 ComponentSpec
+spec = ComponentSpec(name="text_encoder", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=AutoModel)
+# 为重复的文本编码器创建 ComponentSpec(它是相同的检查点,来自相同的仓库/子文件夹)
+spec_duplicated = ComponentSpec(name="text_encoder_duplicated", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", ty
+pe_hint=CLIPTextModel)
+
+# 加载并添加两个组件 - 管理器会检测到它们是同一个模型
+comp.add("text_encoder", spec.load())
+comp.add("text_encoder_duplicated", spec_duplicated.load())
+```
+
+这会返回一个警告,附带移除重复项的说明。
+
+```py
+ComponentsManager: adding component 'text_encoder_duplicated_139917580682672', but it has duplicate load_id 'stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null' with existing components: text_encoder_139918506246832. To remove a duplicate, call `components_manager.remove('')`.
+'text_encoder_duplicated_139917580682672'
+```
+
+您也可以不使用 [`ComponentSpec`] 添加组件,并且在大多数情况下,即使您以不同名称添加相同组件,重复检测仍然有效。
+
+然而,当您将相同组件加载到不同对象时,[`ComponentManager`] 无法检测重复项。在这种情况下,您应该使用 [`ComponentSpec`] 加载模型。
+
+```py
+text_encoder_2 = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
+comp.add("text_encoder", text_encoder_2)
+'text_encoder_139917732983664'
+```
+
+## 集合
+
+集合是为组件分配的标签,用于更好的组织和管理。使用 [`~ComponentsManager.add`] 中的 `collection` 参数将组件添加到集合中。
+
+每个集合中只允许每个名称有一个组件。添加第二个同名组件会自动移除第一个组件。
+
+```py
+from diffusers import ComponentSpec, ComponentsManager
+
+comp = ComponentsManager()
+# 为第一个 UNet 创建 ComponentSpec
+spec = ComponentSpec(name="unet", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", type_hint=AutoModel)
+# 为另一个 UNet 创建 ComponentSpec
+spec2 = ComponentSpec(name="unet", repo="RunDiffusion/Juggernaut-XL-v9", subfolder="unet", type_hint=AutoModel, variant="fp16")
+
+# 将两个 UNet 添加到同一个集合 - 第二个将替换第一个
+comp.add("unet", spec.load(), collection="sdxl")
+comp.add("unet", spec2.load(), collection="sdxl")
+```
+
+这使得在基于节点的系统中工作变得方便,因为您可以:
+
+- 使用 `collection` 标签标记所有从一个节点加载的模型。
+- 当新检查点以相同名称加载时自动替换模型。
+- 当节点被移除时批量删除集合中的所有模型。
+
+## 卸载
+
+[`~ComponentsManager.enable_auto_cpu_offload`] 方法是一种全局卸载策略,适用于所有模型,无论哪个管道在使用它们。一旦启用,您无需担心设备放置,如果您添加或移除组件。
+
+```py
+comp.enable_auto_cpu_offload(device="cuda")
+```
+
+所有模型开始时都在 CPU 上,[`ComponentsManager`] 在需要它们之前将它们移动到适当的设备,并在 GPU 内存不足时将其他模型移回 CPU。
+
+您可以设置自己的规则来决定哪些模型要卸载。
diff --git a/docs/source/zh/modular_diffusers/guiders.md b/docs/source/zh/modular_diffusers/guiders.md
new file mode 100644
index 000000000000..50436f90c4a5
--- /dev/null
+++ b/docs/source/zh/modular_diffusers/guiders.md
@@ -0,0 +1,173 @@
+
+
+# 引导器
+
+[Classifier-free guidance](https://huggingface.co/papers/2207.12598) 引导模型生成更好地匹配提示,通常用于提高生成质量、控制和提示的遵循度。有不同类型的引导方法,在 Diffusers 中,它们被称为*引导器*。与块类似,可以轻松切换和使用不同的引导器以适应不同的用例,而无需重写管道。
+
+本指南将向您展示如何切换引导器、调整引导器参数,以及将它们加载并共享到 Hub。
+
+## 切换引导器
+
+[`ClassifierFreeGuidance`] 是默认引导器,在使用 [`~ModularPipelineBlocks.init_pipeline`] 初始化管道时创建。它通过 `from_config` 创建,这意味着它不需要从模块化存储库加载规范。引导器不会列在 `modular_model_index.json` 中。
+
+使用 [`~ModularPipeline.get_component_spec`] 来检查引导器。
+
+```py
+t2i_pipeline.get_component_spec("guider")
+ComponentSpec(name='guider', type_hint=, description=None, config=FrozenDict([('guidance_scale', 7.5), ('guidance_rescale', 0.0), ('use_original_formulation', False), ('start', 0.0), ('stop', 1.0), ('_use_default_values', ['start', 'guidance_rescale', 'stop', 'use_original_formulation'])]), repo=None, subfolder=None, variant=None, revision=None, default_creation_method='from_config')
+```
+
+通过将新引导器传递给 [`~ModularPipeline.update_components`] 来切换到不同的引导器。
+
+> [!TIP]
+> 更改引导器将返回文本,让您知道您正在更改引导器类型。
+> ```bash
+> ModularPipeline.update_components: 添加具有新类型的引导器: PerturbedAttentionGuidance, 先前类型: ClassifierFreeGuidance
+> ```
+
+```py
+from diffusers import LayerSkipConfig, PerturbedAttentionGuidance
+
+config = LayerSkipConfig(indices=[2, 9], fqn="mid_block.attentions.0.transformer_blocks", skip_attention=False, skip_attention_scores=True, skip_ff=False)
+guider = PerturbedAttentionGuidance(
+ guidance_scale=5.0, perturbed_guidance_scale=2.5, perturbed_guidance_config=config
+)
+t2i_pipeline.update_components(guider=guider)
+```
+
+再次使用 [`~ModularPipeline.get_component_spec`] 来验证引导器类型是否不同。
+
+```py
+t2i_pipeline.get_component_spec("guider")
+ComponentSpec(name='guider', type_hint=, description=None, config=FrozenDict([('guidance_scale', 5.0), ('perturbed_guidance_scale', 2.5), ('perturbed_guidance_start', 0.01), ('perturbed_guidance_stop', 0.2), ('perturbed_guidance_layers', None), ('perturbed_guidance_config', LayerSkipConfig(indices=[2, 9], fqn='mid_block.attentions.0.transformer_blocks', skip_attention=False, skip_attention_scores=True, skip_ff=False, dropout=1.0)), ('guidance_rescale', 0.0), ('use_original_formulation', False), ('start', 0.0), ('stop', 1.0), ('_use_default_values', ['perturbed_guidance_start', 'use_original_formulation', 'perturbed_guidance_layers', 'stop', 'start', 'guidance_rescale', 'perturbed_guidance_stop']), ('_class_name', 'PerturbedAttentionGuidance'), ('_diffusers_version', '0.35.0.dev0')]), repo=None, subfolder=None, variant=None, revision=None, default_creation_method='from_config')
+```
+
+## 加载自定义引导器
+
+已经在 Hub 上保存并带有 `modular_model_index.json` 文件的引导器现在被视为 `from_pretrained` 组件,而不是 `from_config` 组件。
+
+```json
+{
+ "guider": [
+ null,
+ null,
+ {
+ "repo": "YiYiXu/modular-loader-t2i-guider",
+ "revision": null,
+ "subfolder": "pag_guider",
+ "type_hint": [
+ "diffusers",
+ "PerturbedAttentionGuidance"
+ ],
+ "variant": null
+ }
+ ]
+}
+```
+
+引导器只有在调用 [`~ModularPipeline.load_components`] 之后才会创建,基于 `modular_model_index.json` 中的加载规范。
+
+```py
+t2i_pipeline = t2i_blocks.init_pipeline("YiYiXu/modular-doc-guider")
+# 在初始化时未创建
+assert t2i_pipeline.guider is None
+t2i_pipeline.load_components()
+# 加载为 PAG 引导器
+t2i_pipeline.guider
+```
+
+## 更改引导器参数
+
+引导器参数可以通过 [`~ComponentSpec.create`] 方法或 [`~ModularPipeline.update_components`] 方法进行调整。下面的示例更改了 `guidance_scale` 值。
+
+
+
+
+```py
+guider_spec = t2i_pipeline.get_component_spec("guider")
+guider = guider_spec.create(guidance_scale=10)
+t2i_pipeline.update_components(guider=guider)
+```
+
+
+
+
+```py
+guider_spec = t2i_pipeline.get_component_spec("guider")
+guider_spec.config["guidance_scale"] = 10
+t2i_pipeline.update_components(guider=guider_spec)
+```
+
+
+
+
+## 上传自定义引导器
+
+在自定义引导器上调用 [`~utils.PushToHubMixin.push_to_hub`] 方法,将其分享到 Hub。
+
+```py
+guider.push_to_hub("YiYiXu/modular-loader-t2i-guider", subfolder="pag_guider")
+```
+
+要使此引导器可用于管道,可以修改 `modular_model_index.json` 文件或使用 [`~ModularPipeline.update_components`] 方法。
+
+
+
+
+编辑 `modular_model_index.json` 文件,并添加引导器的加载规范,指向包含引导器配置的文件夹
+例如。
+
+```json
+{
+ "guider": [
+ "diffusers",
+ "PerturbedAttentionGuidance",
+ {
+ "repo": "YiYiXu/modular-loader-t2i-guider",
+ "revision": null,
+ "subfolder": "pag_guider",
+ "type_hint": [
+ "diffusers",
+ "PerturbedAttentionGuidance"
+ ],
+ "variant": null
+ }
+ ],
+```
+
+
+
+
+将 [`~ComponentSpec.default_creation_method`] 更改为 `from_pretrained` 并使用 [`~ModularPipeline.update_components`] 来更新引导器和组件规范以及管道配置。
+
+> [!TIP]
+> 更改创建方法将返回文本,告知您正在将创建类型更改为 `from_pretrained`。
+> ```bash
+> ModularPipeline.update_components: 将引导器的 default_creation_method 从 from_config 更改为 from_pretrained。
+> ```
+
+```py
+guider_spec = t2i_pipeline.get_component_spec("guider")
+guider_spec.default_creation_method="from_pretrained"
+guider_spec.pretrained_model_name_or_path="YiYiXu/modular-loader-t2i-guider"
+guider_spec.subfolder="pag_guider"
+pag_guider = guider_spec.load()
+t2i_pipeline.update_components(guider=pag_guider)
+```
+
+要使其成为管道的默认引导器,请调用 [`~utils.PushToHubMixin.push_to_hub`]。这是一个可选步骤,如果您仅在本地进行实验,则不需要。
+
+```py
+t2i_pipeline.push_to_hub("YiYiXu/modular-doc-guider")
+```
+
+
+
diff --git a/docs/source/zh/modular_diffusers/loop_sequential_pipeline_blocks.md b/docs/source/zh/modular_diffusers/loop_sequential_pipeline_blocks.md
new file mode 100644
index 000000000000..aa9dfc1d7e46
--- /dev/null
+++ b/docs/source/zh/modular_diffusers/loop_sequential_pipeline_blocks.md
@@ -0,0 +1,93 @@
+
+
+# LoopSequentialPipelineBlocks
+
+[`~modular_pipelines.LoopSequentialPipelineBlocks`] 是一种多块类型,它将其他 [`~modular_pipelines.ModularPipelineBlocks`] 以循环方式组合在一起。数据循环流动,使用 `intermediate_inputs` 和 `intermediate_outputs`,并且每个块都是迭代运行的。这通常用于创建一个默认是迭代的去噪循环。
+
+本指南向您展示如何创建 [`~modular_pipelines.LoopSequentialPipelineBlocks`]。
+
+## 循环包装器
+
+[`~modular_pipelines.LoopSequentialPipelineBlocks`],也被称为 *循环包装器*,因为它定义了循环结构、迭代变量和配置。在循环包装器内,您需要以下变量。
+
+- `loop_inputs` 是用户提供的值,等同于 [`~modular_pipelines.ModularPipelineBlocks.inputs`]。
+- `loop_intermediate_inputs` 是来自 [`~modular_pipelines.PipelineState`] 的中间变量,等同于 [`~modular_pipelines.ModularPipelineBlocks.intermediate_inputs`]。
+- `loop_intermediate_outputs` 是由块创建并添加到 [`~modular_pipelines.PipelineState`] 的新中间变量。它等同于 [`~modular_pipelines.ModularPipelineBlocks.intermediate_outputs`]。
+- `__call__` 方法定义了循环结构和迭代逻辑。
+
+```py
+import torch
+from diffusers.modular_pipelines import LoopSequentialPipelineBlocks, ModularPipelineBlocks, InputParam, OutputParam
+
+class LoopWrapper(LoopSequentialPipelineBlocks):
+ model_name = "test"
+ @property
+ def description(self):
+ return "I'm a loop!!"
+ @property
+ def loop_inputs(self):
+ return [InputParam(name="num_steps")]
+ @torch.no_grad()
+ def __call__(self, components, state):
+ block_state = self.get_block_state(state)
+ # 循环结构 - 可以根据您的需求定制
+ for i in range(block_state.num_steps):
+ # loop_step 按顺序执行所有注册的块
+ components, block_state = self.loop_step(components, block_state, i=i)
+ self.set_block_state(state, block_state)
+ return components, state
+```
+
+循环包装器可以传递额外的参数,如当前迭代索引,到循环块。
+
+## 循环块
+
+循环块是一个 [`~modular_pipelines.ModularPipelineBlocks`],但 `__call__` 方法的行为不同。
+
+- 它从循环包装器。
+- 它直接与[`~modular_pipelines.BlockState`]一起工作,而不是[`~modular_pipelines.PipelineState`]。
+- 它不需要检索或更新[`~modular_pipelines.BlockState`]。
+
+循环块共享相同的[`~modular_pipelines.BlockState`],以允许值在循环的每次迭代中累积和变化。
+
+```py
+class LoopBlock(ModularPipelineBlocks):
+ model_name = "test"
+ @property
+ def inputs(self):
+ return [InputParam(name="x")]
+ @property
+ def intermediate_outputs(self):
+ # 这个块产生的输出
+ return [OutputParam(name="x")]
+ @property
+ def description(self):
+ return "我是一个在`LoopWrapper`类内部使用的块"
+ def __call__(self, components, block_state, i: int):
+ block_state.x += 1
+ return components, block_state
+```
+
+## LoopSequentialPipelineBlocks
+
+使用[`~modular_pipelines.LoopSequentialPipelineBlocks.from_blocks_dict`]方法将循环块添加到循环包装器中,以创建[`~modular_pipelines.LoopSequentialPipelineBlocks`]。
+
+```py
+loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock})
+```
+
+添加更多的循环块以在每次迭代中运行,使用[`~modular_pipelines.LoopSequentialPipelineBlocks.from_blocks_dict`]。这允许您在不改变循环逻辑本身的情况下修改块。
+
+```py
+loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock(), "block2": LoopBlock})
+```
diff --git a/docs/source/zh/modular_diffusers/modular_diffusers_states.md b/docs/source/zh/modular_diffusers/modular_diffusers_states.md
new file mode 100644
index 000000000000..99503c6387f1
--- /dev/null
+++ b/docs/source/zh/modular_diffusers/modular_diffusers_states.md
@@ -0,0 +1,74 @@
+
+
+# 状态
+
+块依赖于[`~modular_pipelines.PipelineState`]和[`~modular_pipelines.BlockState`]数据结构进行通信和数据共享。
+
+| 状态 | 描述 |
+|-------|-------------|
+| [`~modular_pipelines.PipelineState`] | 维护管道执行所需的整体数据,并允许块读取和更新其数据。 |
+| [`~modular_pipelines.BlockState`] | 允许每个块使用来自`inputs`的必要数据执行其计算 |
+
+本指南解释了状态如何工作以及它们如何连接块。
+
+## PipelineState
+
+[`~modular_pipelines.PipelineState`]是所有块的全局状态容器。它维护管道的完整运行时状态,并为块提供了一种结构化的方式来读取和写入共享数据。
+
+[`~modular_pipelines.PipelineState`]中有两个字典用于结构化数据。
+
+- `values`字典是一个**可变**状态,包含用户提供的输入值的副本和由块生成的中间输出值。如果一个块修改了一个`input`,它将在调用`set_block_state`后反映在`values`字典中。
+
+```py
+PipelineState(
+ values={
+ 'prompt': 'a cat'
+ 'guidance_scale': 7.0
+ 'num_inference_steps': 25
+ 'prompt_embeds': Tensor(dtype=torch.float32, shape=torch.Size([1, 1, 1, 1]))
+ 'negative_prompt_embeds': None
+ },
+)
+```
+
+## BlockState
+
+[`~modular_pipelines.BlockState`]是[`~modular_pipelines.PipelineState`]中相关变量的局部视图,单个块需要这些变量来执行其计算。
+
+直接作为属性访问这些变量,如`block_state.image`。
+
+```py
+BlockState(
+ image:
+)
+```
+
+当一个块的`__call__`方法被执行时,它用`self.get_block_state(state)`检索[`BlockState`],执行其操作,并用`self.set_block_state(state, block_state)`更新[`~modular_pipelines.PipelineState`]。
+
+```py
+def __call__(self, components, state):
+ # 检索BlockState
+ block_state = self.get_block_state(state)
+
+ # 对输入进行计算的逻辑
+
+ # 更新PipelineState
+ self.set_block_state(state, block_state)
+ return components, state
+```
+
+## 状态交互
+
+[`~modular_pipelines.PipelineState`]和[`~modular_pipelines.BlockState`]的交互由块的`inputs`和`intermediate_outputs`定义。
+
+- `inputs`,
+一个块可以修改输入 - 比如 `block_state.image` - 并且这个改变可以通过调用 `set_block_state` 全局传播到 [`~modular_pipelines.PipelineState`]。
+- `intermediate_outputs`,是一个块创建的新变量。它被添加到 [`~modular_pipelines.PipelineState`] 的 `values` 字典中,并且可以作为后续块的可用变量,或者由用户作为管道的最终输出访问。
diff --git a/docs/source/zh/modular_diffusers/modular_pipeline.md b/docs/source/zh/modular_diffusers/modular_pipeline.md
new file mode 100644
index 000000000000..a57fdf227506
--- /dev/null
+++ b/docs/source/zh/modular_diffusers/modular_pipeline.md
@@ -0,0 +1,358 @@
+
+
+# 模块化管道
+
+[`ModularPipeline`] 将 [`~modular_pipelines.ModularPipelineBlocks`] 转换为可执行的管道,加载模型并执行块中定义的计算步骤。它是运行管道的主要接口,与 [`DiffusionPipeline`] API 非常相似。
+
+主要区别在于在管道中包含了一个预期的 `output` 参数。
+
+
+
+
+```py
+import torch
+from diffusers.modular_pipelines import SequentialPipelineBlocks
+from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
+
+blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
+
+modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
+pipeline = blocks.init_pipeline(modular_repo_id)
+
+pipeline.load_components(torch_dtype=torch.float16)
+pipeline.to("cuda")
+
+image = pipeline(prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", output="images")[0]
+image.save("modular_t2i_out.png")
+```
+
+
+
+
+```py
+import torch
+from diffusers.modular_pipelines import SequentialPipelineBlocks
+from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS
+
+blocks = SequentialPipelineBlocks.from_blocks_dict(IMAGE2IMAGE_BLOCKS)
+
+modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
+pipeline = blocks.init_pipeline(modular_repo_id)
+
+pipeline.load_components(torch_dtype=torch.float16)
+pipeline.to("cuda")
+
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
+init_image = load_image(url)
+prompt = "a dog catching a frisbee in the jungle"
+image = pipeline(prompt=prompt, image=init_image, strength=0.8, output="images")[0]
+image.save("modular_i2i_out.png")
+```
+
+
+
+
+```py
+import torch
+from diffusers.modular_pipelines import SequentialPipelineBlocks
+from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
+from diffusers.utils import load_image
+
+blocks = SequentialPipelineBlocks.from_blocks_dict(INPAINT_BLOCKS)
+
+modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
+pipeline = blocks.init_pipeline(modular_repo_id)
+
+pipeline.load_components(torch_dtype=torch.float16)
+pipeline.to("cuda")
+
+img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
+mask_url = "h
+ttps://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-inpaint-mask.png"
+
+init_image = load_image(img_url)
+mask_image = load_image(mask_url)
+
+prompt = "A deep sea diver floating"
+image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.85, output="images")[0]
+image.save("moduar_inpaint_out.png")
+```
+
+
+
+
+本指南将向您展示如何创建一个[`ModularPipeline`]并管理其中的组件。
+
+## 添加块
+
+块是[`InsertableDict`]对象,可以在特定位置插入,提供了一种灵活的方式来混合和匹配块。
+
+使用[`~modular_pipelines.modular_pipeline_utils.InsertableDict.insert`]在块类或`sub_blocks`属性上添加一个块。
+
+```py
+# BLOCKS是块类的字典,您需要向其中添加类
+BLOCKS.insert("block_name", BlockClass, index)
+# sub_blocks属性包含实例,向该属性添加一个块实例
+t2i_blocks.sub_blocks.insert("block_name", block_instance, index)
+```
+
+使用[`~modular_pipelines.modular_pipeline_utils.InsertableDict.pop`]在块类或`sub_blocks`属性上移除一个块。
+
+```py
+# 从预设中移除一个块类
+BLOCKS.pop("text_encoder")
+# 分离出一个块实例
+text_encoder_block = t2i_blocks.sub_blocks.pop("text_encoder")
+```
+
+通过将现有块设置为新块来交换块。
+
+```py
+# 在预设中替换块类
+BLOCKS["prepare_latents"] = CustomPrepareLatents
+# 使用块实例在sub_blocks属性中替换
+t2i_blocks.sub_blocks["prepare_latents"] = CustomPrepareLatents()
+```
+
+## 创建管道
+
+有两种方法可以创建一个[`ModularPipeline`]。从[`ModularPipelineBlocks`]组装并创建管道,或使用[`~ModularPipeline.from_pretrained`]加载现有管道。
+
+您还应该初始化一个[`ComponentsManager`]来处理设备放置和内存以及组件管理。
+
+> [!TIP]
+> 有关它如何帮助管理不同工作流中的组件的更多详细信息,请参阅[ComponentsManager](./components_manager)文档。
+
+
+
+
+使用[`~ModularPipelineBlocks.init_pipeline`]方法从组件和配置规范创建一个[`ModularPipeline`]。此方法从`modular_model_index.json`文件加载*规范*,但尚未加载*模型*。
+
+```py
+from diffusers import ComponentsManager
+from diffusers.modular_pipelines import SequentialPipelineBlocks
+from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
+
+t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
+
+modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
+components = ComponentsManager()
+t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id, components_manager=components)
+```
+
+
+
+
+[`~ModularPipeline.from_pretrained`]方法创建一个[`ModularPipeline`]从Hub上的模块化仓库加载。
+
+```py
+from diffusers import ModularPipeline, ComponentsManager
+
+components = ComponentsManager()
+pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-loader-t2i-0704", components_manager=components)
+```
+
+添加`trust_remote_code`参数以加载自定义的[`ModularPipeline`]。
+
+```py
+from diffusers import ModularPipeline, ComponentsManager
+
+components = ComponentsManager()
+modular_repo_id = "YiYiXu/modular-diffdiff-0704"
+diffdiff_pipeline = ModularPipeline.from_pretrained(modular_repo_id, trust_remote_code=True, components_manager=components)
+```
+
+
+
+
+## 加载组件
+
+一个[`ModularPipeline`]不会自动实例化组件。它只加载配置和组件规范。您可以使用[`~ModularPipeline.load_components`]加载所有组件,或仅使用[`~ModularPipeline.load_components`]加载特定组件。
+
+
+
+
+```py
+import torch
+
+t2i_pipeline.load_components(torch_dtype=torch.float16)
+t2i_pipeline.to("cuda")
+```
+
+
+
+
+下面的例子仅加载UNet和VAE。
+
+```py
+import torch
+
+t2i_pipeline.load_components(names=["unet", "vae"], torch_dtype=torch.float16)
+```
+
+
+
+
+打印管道以检查加载的预训练组件。
+
+```py
+t2i_pipeline
+```
+
+这应该与管道初始化自的模块化仓库中的`modular_model_index.json`文件匹配。如果管道不需要某个组件,即使它在模块化仓库中存在,也不会被包含。
+
+要修改组件加载的来源,编辑仓库中的`modular_model_index.json`文件,并将其更改为您希望的加载路径。下面的例子从不同的仓库加载UNet。
+
+```json
+# 原始
+"unet": [
+ null, null,
+ {
+ "repo": "stabilityai/stable-diffusion-xl-base-1.0",
+ "subfolder": "unet",
+ "variant": "fp16"
+ }
+]
+
+# 修改后
+"unet": [
+ null, null,
+ {
+ "repo": "RunDiffusion/Juggernaut-XL-v9",
+ "subfolder": "unet",
+ "variant": "fp16"
+ }
+]
+```
+
+### 组件加载状态
+
+下面的管道属性提供了关于哪些组件被加载的更多信息。
+
+使用`component_names`返回所有预期的组件。
+
+```py
+t2i_pipeline.component_names
+['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'guider', 'scheduler', 'unet', 'vae', 'image_processor']
+```
+
+使用`null_component_names`返回尚未加载的组件。使用[`~ModularPipeline.from_pretrained`]加载这些组件。
+
+```py
+t2i_pipeline.null_component_names
+['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler']
+```
+
+使用`pretrained_component_names`返回将从预训练模型加载的组件。
+
+```py
+t2i_pipeline.pretrained_component_names
+['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler', 'unet', 'vae']
+```
+
+使用 `config_component_names` 返回那些使用默认配置创建的组件(不是从模块化仓库加载的)。来自配置的组件不包括在内,因为它们已经在管道创建期间初始化。这就是为什么它们没有列在 `null_component_names` 中。
+
+```py
+t2i_pipeline.config_component_names
+['guider', 'image_processor']
+```
+
+## 更新组件
+
+根据组件是*预训练组件*还是*配置组件*,组件可能会被更新。
+
+> [!WARNING]
+> 在更新组件时,组件可能会从预训练变为配置。组件类型最初是在块的 `expected_components` 字段中定义的。
+
+预训练组件通过 [`ComponentSpec`] 更新,而配置组件则通过直接传递对象或使用 [`ComponentSpec`] 更新。
+
+[`ComponentSpec`] 对于预训练组件显示 `default_creation_method="from_pretrained"`,对于配置组件显示 `default_creation_method="from_config`。
+
+要更新预训练组件,创建一个 [`ComponentSpec`],指定组件的名称和从哪里加载它。使用 [`~ComponentSpec.load`] 方法来加载组件。
+
+```py
+from diffusers import ComponentSpec, UNet2DConditionModel
+
+unet_spec = ComponentSpec(name="unet",type_hint=UNet2DConditionModel, repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", variant="fp16")
+unet = unet_spec.load(torch_dtype=torch.float16)
+```
+
+[`~ModularPipeline.update_components`] 方法用一个新的组件替换原来的组件。
+
+```py
+t2i_pipeline.update_components(unet=unet2)
+```
+
+当组件被更新时,加载规范也会在管道配置中更新。
+
+### 组件提取和修改
+
+当你使用 [`~ComponentSpec.load`] 时,新组件保持其加载规范。这使得提取规范并重新创建组件成为可能。
+
+```py
+spec = ComponentSpec.from_component("unet", unet2)
+spec
+ComponentSpec(name='unet', type_hint=, description=None, config=None, repo='stabilityai/stable-diffusion-xl-base-1.0', subfolder='unet', variant='fp16', revision=None, default_creation_method='from_pretrained')
+unet2_recreated = spec.load(torch_dtype=torch.float16)
+```
+
+[`~ModularPipeline.get_component_spec`] 方法获取当前组件规范的副本以进行修改或更新。
+
+```py
+unet_spec = t2i_pipeline.get_component_spec("unet")
+unet_spec
+ComponentSpec(
+ name='unet',
+ type_hint=,
+ pretrained_model_name_or_path='RunDiffusion/Juggernaut-XL-v9',
+ subfolder='unet',
+ variant='fp16',
+ default_creation_method='from_pretrained'
+)
+
+# 修改以从不同的仓库加载
+unet_spec.pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
+
+# 使用修改后的规范加载组件
+unet = unet_spec.load(torch_dtype=torch.float16)
+```
+
+## 模块化仓库
+一个仓库
+如果管道块使用*预训练组件*,则需要y。该存储库提供了加载规范和元数据。
+
+[`ModularPipeline`]特别需要*模块化存储库*(参见[示例存储库](https://huggingface.co/YiYiXu/modular-diffdiff)),这比典型的存储库更灵活。它包含一个`modular_model_index.json`文件,包含以下3个元素。
+
+- `library`和`class`显示组件是从哪个库加载的及其类。如果是`null`,则表示组件尚未加载。
+- `loading_specs_dict`包含加载组件所需的信息,例如从中加载的存储库和子文件夹。
+
+与标准存储库不同,模块化存储库可以根据`loading_specs_dict`从不同的存储库获取组件。组件不需要存在于同一个存储库中。
+
+模块化存储库可能包含用于加载[`ModularPipeline`]的自定义代码。这允许您使用不是Diffusers原生的专用块。
+
+```
+modular-diffdiff-0704/
+├── block.py # 自定义管道块实现
+├── config.json # 管道配置和auto_map
+└── modular_model_index.json # 组件加载规范
+```
+
+[config.json](https://huggingface.co/YiYiXu/modular-diffdiff-0704/blob/main/config.json)文件包含一个`auto_map`键,指向`block.py`中定义自定义块的位置。
+
+```json
+{
+ "_class_name": "DiffDiffBlocks",
+ "auto_map": {
+ "ModularPipelineBlocks": "block.DiffDiffBlocks"
+ }
+}
+```
diff --git a/docs/source/zh/modular_diffusers/overview.md b/docs/source/zh/modular_diffusers/overview.md
new file mode 100644
index 000000000000..07021cad2757
--- /dev/null
+++ b/docs/source/zh/modular_diffusers/overview.md
@@ -0,0 +1,38 @@
+
+
+# 概述
+
+> [!WARNING]
+> 模块化Diffusers正在积极开发中,其API可能会发生变化。
+
+模块化Diffusers是一个统一的管道系统,通过*管道块*简化您的工作流程。
+
+- 块是可重用的,您只需要为您的管道创建独特的块。
+- 块可以混合搭配,以适应或为特定工作流程或多个工作流程创建管道。
+
+模块化Diffusers文档的组织如下所示。
+
+## 快速开始
+
+- 一个[快速开始](./quickstart)演示了如何使用模块化Diffusers实现一个示例工作流程。
+
+## ModularPipelineBlocks
+
+- [States](./modular_diffusers_states)解释了数据如何在块和[`ModularPipeline`]之间共享和通信。
+- [ModularPipelineBlocks](./pipeline_block)是[`ModularPipeline`]最基本的单位,本指南向您展示如何创建一个。
+- [SequentialPipelineBlocks](./sequential_pipeline_blocks)是一种类型的块,它将多个块链接起来,使它们一个接一个地运行,沿着链传递数据。本指南向您展示如何创建[`~modular_pipelines.SequentialPipelineBlocks`]以及它们如何连接和一起工作。
+- [LoopSequentialPipelineBlocks](./loop_sequential_pipeline_blocks)是一种类型的块,它在循环中运行一系列块。本指南向您展示如何创建[`~modular_pipelines.LoopSequentialPipelineBlocks`]。
+- [AutoPipelineBlocks](./auto_pipeline_blocks)是一种类型的块,它根据输入自动选择要运行的块。本指南向您展示如何创建[`~modular_pipelines.AutoPipelineBlocks`]。
+
+## ModularPipeline
+
+- [ModularPipeline](./modular_pipeline)向您展示如何创建并将管道块转换为可执行的[`ModularPipeline`]。
+- [ComponentsManager](./components_manager)向您展示如何跨多个管道管理和重用组件。
+- [Guiders](./guiders)向您展示如何在管道中使用不同的指导方法。
diff --git a/docs/source/zh/modular_diffusers/pipeline_block.md b/docs/source/zh/modular_diffusers/pipeline_block.md
new file mode 100644
index 000000000000..b3ed807b232b
--- /dev/null
+++ b/docs/source/zh/modular_diffusers/pipeline_block.md
@@ -0,0 +1,114 @@
+
+
+# ModularPipelineBlocks
+
+[`~modular_pipelines.ModularPipelineBlocks`] 是构建 [`ModularPipeline`] 的基本块。它定义了管道中特定步骤应执行的组件、输入/输出和计算。一个 [`~modular_pipelines.ModularPipelineBlocks`] 与其他块连接,使用 [状态](./modular_diffusers_states),以实现工作流的模块化构建。
+
+单独的 [`~modular_pipelines.ModularPipelineBlocks`] 无法执行。它是管道中步骤应执行的操作的蓝图。要实际运行和执行管道,需要将 [`~modular_pipelines.ModularPipelineBlocks`] 转换为 [`ModularPipeline`]。
+
+本指南将向您展示如何创建 [`~modular_pipelines.ModularPipelineBlocks`]。
+
+## 输入和输出
+
+> [!TIP]
+> 如果您不熟悉Modular Diffusers中状态的工作原理,请参考 [States](./modular_diffusers_states) 指南。
+
+一个 [`~modular_pipelines.ModularPipelineBlocks`] 需要 `inputs` 和 `intermediate_outputs`。
+
+- `inputs` 是由用户提供并从 [`~modular_pipelines.PipelineState`] 中检索的值。这很有用,因为某些工作流会调整图像大小,但仍需要原始图像。 [`~modular_pipelines.PipelineState`] 维护原始图像。
+
+ 使用 `InputParam` 定义 `inputs`。
+
+ ```py
+ from diffusers.modular_pipelines import InputParam
+
+ user_inputs = [
+ InputParam(name="image", type_hint="PIL.Image", description="要处理的原始输入图像")
+ ]
+ ```
+
+- `intermediate_inputs` 通常由前一个块创建的值,但如果前面的块没有生成它们,也可以直接提供。与 `inputs` 不同,`intermediate_inputs` 可以被修改。
+
+ 使用 `InputParam` 定义 `intermediate_inputs`。
+
+ ```py
+ user_intermediate_inputs = [
+ InputParam(name="processed_image", type_hint="torch.Tensor", description="image that has been preprocessed and normalized"),
+ ]
+ ```
+
+- `intermediate_outputs` 是由块创建并添加到 [`~modular_pipelines.PipelineState`] 的新值。`intermediate_outputs` 可作为后续块的 `intermediate_inputs` 使用,或作为运行管道的最终输出使用。
+
+ 使用 `OutputParam` 定义 `intermediate_outputs`。
+
+ ```py
+ from diffusers.modular_pipelines import OutputParam
+
+ user_intermediate_outputs = [
+ OutputParam(name="image_latents", description="latents representing the image")
+ ]
+ ```
+
+中间输入和输出共享数据以连接块。它们可以在任何时候访问,允许你跟踪工作流的进度。
+
+## 计算逻辑
+
+一个块执行的计算在`__call__`方法中定义,它遵循特定的结构。
+
+1. 检索[`~modular_pipelines.BlockState`]以获取`inputs`和`intermediate_inputs`的局部视图。
+2. 在`inputs`和`intermediate_inputs`上实现计算逻辑。
+3. 更新[`~modular_pipelines.PipelineState`]以将局部[`~modular_pipelines.BlockState`]的更改推送回全局[`~modular_pipelines.PipelineState`]。
+4. 返回对下一个块可用的组件和状态。
+
+```py
+def __call__(self, components, state):
+ # 获取该块需要的状态变量的局部视图
+ block_state = self.get_block_state(state)
+
+ # 你的计算逻辑在这里
+ # block_state包含你所有的inputs和intermediate_inputs
+ # 像这样访问它们: block_state.image, block_state.processed_image
+
+ # 用你更新的block_states更新管道状态
+ self.set_block_state(state, block_state)
+ return components, state
+```
+
+### 组件和配置
+
+块需要的组件和管道级别的配置在[`ComponentSpec`]和[`~modular_pipelines.ConfigSpec`]中指定。
+
+- [`ComponentSpec`]包含块使用的预期组件。你需要组件的`name`和理想情况下指定组件确切是什么的`type_hint`。
+- [`~modular_pipelines.ConfigSpec`]包含控制所有块行为的管道级别设置。
+
+```py
+from diffusers import ComponentSpec, ConfigSpec
+
+expected_components = [
+ ComponentSpec(name="unet", type_hint=UNet2DConditionModel),
+ ComponentSpec(name="scheduler", type_hint=EulerDiscreteScheduler)
+]
+
+expected_config = [
+ ConfigSpec("force_zeros_for_empty_prompt", True)
+]
+```
+
+当块被转换为管道时,组件作为`__call__`中的第一个参数对块可用。
+
+```py
+def __call__(self, components, state):
+ # 使用点符号访问组件
+ unet = components.unet
+ vae = components.vae
+ scheduler = components.scheduler
+```
diff --git a/docs/source/zh/modular_diffusers/quickstart.md b/docs/source/zh/modular_diffusers/quickstart.md
new file mode 100644
index 000000000000..2c4a6a51afde
--- /dev/null
+++ b/docs/source/zh/modular_diffusers/quickstart.md
@@ -0,0 +1,346 @@
+
+
+# 快速入门
+
+模块化Diffusers是一个快速构建灵活和可定制管道的框架。模块化Diffusers的核心是[`ModularPipelineBlocks`],可以与其他块组合以适应新的工作流程。这些块被转换为[`ModularPipeline`],一个开发者可以使用的友好用户界面。
+
+本文档将向您展示如何使用模块化框架实现[Differential Diffusion](https://differential-diffusion.github.io/)管道。
+
+## ModularPipelineBlocks
+
+[`ModularPipelineBlocks`]是*定义*,指定管道中单个步骤的组件、输入、输出和计算逻辑。有四种类型的块。
+
+- [`ModularPipelineBlocks`]是最基本的单一步骤块。
+- [`SequentialPipelineBlocks`]是一个多块,线性组合其他块。一个块的输出是下一个块的输入。
+- [`LoopSequentialPipelineBlocks`]是一个多块,迭代运行,专为迭代工作流程设计。
+- [`AutoPipelineBlocks`]是一个针对不同工作流程的块集合,它根据输入选择运行哪个块。它旨在方便地将多个工作流程打包到单个管道中。
+
+[Differential Diffusion](https://differential-diffusion.github.io/)是一个图像到图像的工作流程。从`IMAGE2IMAGE_BLOCKS`预设开始,这是一个用于图像到图像生成的`ModularPipelineBlocks`集合。
+
+```py
+from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS
+IMAGE2IMAGE_BLOCKS = InsertableDict([
+ ("text_encoder", StableDiffusionXLTextEncoderStep),
+ ("image_encoder", StableDiffusionXLVaeEncoderStep),
+ ("input", StableDiffusionXLInputStep),
+ ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
+ ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
+ ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
+ ("denoise", StableDiffusionXLDenoiseStep),
+ ("decode", StableDiffusionXLDecodeStep)
+])
+```
+
+## 管道和块状态
+
+模块化Diffusers使用*状态*在块之间通信数据。有两种类型的状态。
+
+- [`PipelineState`]是一个全局状态,可用于跟踪所有块的所有输入和输出。
+- [`BlockState`]是[`PipelineState`]中相关变量的局部视图,用于单个块。
+
+## 自定义块
+
+[Differential Diffusion](https://differential-diffusion.github.io/) 与标准的图像到图像转换在其 `prepare_latents` 和 `denoise` 块上有所不同。所有其他块都可以重用,但你需要修改这两个。
+
+通过复制和修改现有的块,为 `prepare_latents` 和 `denoise` 创建占位符 `ModularPipelineBlocks`。
+
+打印 `denoise` 块,可以看到它由 [`LoopSequentialPipelineBlocks`] 组成,包含三个子块,`before_denoiser`、`denoiser` 和 `after_denoiser`。只需要修改 `before_denoiser` 子块,根据变化图为去噪器准备潜在输入。
+
+```py
+denoise_blocks = IMAGE2IMAGE_BLOCKS["denoise"]()
+print(denoise_blocks)
+```
+
+用新的 `SDXLDiffDiffLoopBeforeDenoiser` 块替换 `StableDiffusionXLLoopBeforeDenoiser` 子块。
+
+```py
+# 复制现有块作为占位符
+class SDXLDiffDiffPrepareLatentsStep(ModularPipelineBlocks):
+ """Copied from StableDiffusionXLImg2ImgPrepareLatentsStep - will modify later"""
+ # ... 与 StableDiffusionXLImg2ImgPrepareLatentsStep 相同的实现
+
+class SDXLDiffDiffDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
+ block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+```
+
+### prepare_latents
+
+`prepare_latents` 块需要进行以下更改。
+
+- 一个处理器来处理变化图
+- 一个新的 `inputs` 来接受用户提供的变化图,`timestep` 用于预计算所有潜在变量和 `num_inference_steps` 来创建更新图像区域的掩码
+- 更新 `__call__` 方法中的计算,用于处理变化图和创建掩码,并将其存储在 [`BlockState`] 中
+
+```diff
+class SDXLDiffDiffPrepareLatentsStep(ModularPipelineBlocks):
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("vae", AutoencoderKL),
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
++ ComponentSpec("mask_processor", VaeImageProcessor, config=FrozenDict({"do_normalize": False, "do_convert_grayscale": True}))
+ ]
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("generator"),
++ InputParam("diffdiff_map", required=True),
+- InputParam("latent_timestep", required=True, type_hint=torch.Tensor),
++ InputParam("timesteps", type_hint=torch.Tensor),
++ InputParam("num_inference_steps", type_hint=int),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
++ OutputParam("original_latents", type_hint=torch.Tensor),
++ OutputParam("diffdiff_masks", type_hint=torch.Tensor),
+ ]
+ def __call__(self, components, state: PipelineState):
+ # ... existing logic ...
++ # Process change map and create masks
++ diffdiff_map = components.mask_processor.preprocess(block_state.diffdiff_map, height=latent_height, width=latent_width)
++ thresholds = torch.arange(block_state.num_inference_steps, dtype=diffdiff_map.dtype) / block_state.num_inference_steps
++ block_state.diffdiff_masks = diffdiff_map > (thresholds + (block_state.denoising_start or 0))
++ block_state.original_latents = block_state.latents
+```
+
+### 去噪
+
+`before_denoiser` 子块需要进行以下更改。
+
+- 新的 `inputs` 以接受 `denoising_start` 参数,`original_latents` 和 `diffdiff_masks` 来自 `prepare_latents` 块
+- 更新 `__call__` 方法中的计算以应用 Differential Diffusion
+
+```diff
+class SDXLDiffDiffLoopBeforeDenoiser(ModularPipelineBlocks):
+ @property
+ def description(self) -> str:
+ return (
+ "Step within the denoising loop for differential diffusion that prepare the latent input for the denoiser"
+ )
+
+ @property
+ def inputs(self) -> List[str]:
+ return [
+ InputParam("latents", required=True, type_hint=torch.Tensor),
++ InputParam("denoising_start"),
++ InputParam("original_latents", type_hint=torch.Tensor),
++ InputParam("diffdiff_masks", type_hint=torch.Tensor),
+ ]
+
+ def __call__(self, components, block_state, i, t):
++ # Apply differential diffusion logic
++ if i == 0 and block_state.denoising_start is None:
++ block_state.latents = block_state.original_latents[:1]
++ else:
++ block_state.mask = block_state.diffdiff_masks[i].unsqueeze(0).unsqueeze(1)
++ block_state.latents = block_state.original_latents[i] * block_state.mask + block_state.latents * (1 - block_state.mask)
+
+ # ... rest of existing logic ...
+```
+
+## 组装块
+
+此时,您应该拥有创建 [`ModularPipeline`] 所需的所有块。
+
+复制现有的 `IMAGE2IMAGE_BLOCKS` 预设,对于 `set_timesteps` 块,使用 `TEXT2IMAGE_BLOCKS` 中的 `set_timesteps`,因为 Differential Diffusion 不需要 `strength` 参数。
+
+将 `prepare_latents` 和 `denoise` 块设置为您刚刚修改的 `SDXLDiffDiffPrepareLatentsStep` 和 `SDXLDiffDiffDenoiseStep` 块。
+
+调用 [`SequentialPipelineBlocks.from_blocks_dict`] 在块上创建一个 `SequentialPipelineBlocks`。
+
+```py
+DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
+DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
+DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
+DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseStep
+
+dd_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_BLOCKS)
+print(dd_blocks)
+```
+
+## ModularPipeline
+
+将 [`SequentialPipelineBlocks`] 转换为 [`ModularPipeline`],使用 [`ModularPipeline.init_pipeline`] 方法。这会初始化从 `modular_model_index.json` 文件加载的预期组件。通过调用 [`ModularPipeline.load_defau
+lt_components`]。
+
+初始化[`ComponentManager`]时传入pipeline是一个好主意,以帮助管理不同的组件。一旦调用[`~ModularPipeline.load_components`],组件就会被注册到[`ComponentManager`]中,并且可以在工作流之间共享。下面的例子使用`collection`参数为组件分配了一个`"diffdiff"`标签,以便更好地组织。
+
+```py
+from diffusers.modular_pipelines import ComponentsManager
+
+components = ComponentManager()
+
+dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", components_manager=components, collection="diffdiff")
+dd_pipeline.load_default_componenets(torch_dtype=torch.float16)
+dd_pipeline.to("cuda")
+```
+
+## 添加工作流
+
+可以向[`ModularPipeline`]添加其他工作流以支持更多功能,而无需从头重写整个pipeline。
+
+本节演示如何添加IP-Adapter或ControlNet。
+
+### IP-Adapter
+
+Stable Diffusion XL已经有一个预设的IP-Adapter块,你可以使用,并且不需要对现有的Differential Diffusion pipeline进行任何更改。
+
+```py
+from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLAutoIPAdapterStep
+
+ip_adapter_block = StableDiffusionXLAutoIPAdapterStep()
+```
+
+使用[`sub_blocks.insert`]方法将其插入到[`ModularPipeline`]中。下面的例子在位置`0`插入了`ip_adapter_block`。打印pipeline可以看到`ip_adapter_block`被添加了,并且它需要一个`ip_adapter_image`。这也向pipeline添加了两个组件,`image_encoder`和`feature_extractor`。
+
+```py
+dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0)
+```
+
+调用[`~ModularPipeline.init_pipeline`]来初始化一个[`ModularPipeline`],并使用[`~ModularPipeline.load_components`]加载模型组件。加载并设置IP-Adapter以运行pipeline。
+
+```py
+dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
+dd_pipeline.load_components(torch_dtype=torch.float16)
+dd_pipeline.loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
+dd_pipeline.loader.set_ip_adapter_scale(0.6)
+dd_pipeline = dd_pipeline.to(device)
+
+ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_orange.jpeg")
+image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
+mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
+
+prompt = "a green pear"
+negative_prompt = "blurry"
+generator = torch.Generator(device=device).manual_seed(42)
+
+image = dd_pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_inference_steps=25,
+ generator=generator,
+ ip_adapter_image=ip_adapter_image,
+ diffdiff_map=mask,
+ image=image,
+
+output="images"
+)[0]
+```
+
+### ControlNet
+
+Stable Diffusion XL 已经预设了一个可以立即使用的 ControlNet 块。
+
+```py
+from diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks import StableDiffusionXLAutoControlNetInputStep
+
+control_input_block = StableDiffusionXLAutoControlNetInputStep()
+```
+
+然而,它需要修改 `denoise` 块,因为那是 ControlNet 将控制信息注入到 UNet 的地方。
+
+通过将 `StableDiffusionXLLoopDenoiser` 子块替换为 `StableDiffusionXLControlNetLoopDenoiser` 来修改 `denoise` 块。
+
+```py
+class SDXLDiffDiffControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
+ block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+controlnet_denoise_block = SDXLDiffDiffControlNetDenoiseStep()
+```
+
+插入 `controlnet_input` 块并用新的 `controlnet_denoise_block` 替换 `denoise` 块。初始化一个 [`ModularPipeline`] 并将 [`~ModularPipeline.load_components`] 加载到其中。
+
+```py
+dd_blocks.sub_blocks.insert("controlnet_input", control_input_block, 7)
+dd_blocks.sub_blocks["denoise"] = controlnet_denoise_block
+
+dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
+dd_pipeline.load_components(torch_dtype=torch.float16)
+dd_pipeline = dd_pipeline.to(device)
+
+control_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_tomato_canny.jpeg")
+image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
+mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
+
+prompt = "a green pear"
+negative_prompt = "blurry"
+generator = torch.Generator(device=device).manual_seed(42)
+
+image = dd_pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_inference_steps=25,
+ generator=generator,
+ control_image=control_image,
+ controlnet_conditioning_scale=0.5,
+ diffdiff_map=mask,
+ image=image,
+ output="images"
+)[0]
+```
+
+### AutoPipelineBlocks
+
+差分扩散、IP-Adapter 和 ControlNet 工作流可以通过使用 [`AutoPipelineBlocks`] 捆绑到一个单一的 [`ModularPipeline`] 中。这允许根据输入如 `control_image` 或 `ip_adapter_image` 自动选择要运行的子块。如果没有传递这些输入,则默认为差分扩散。
+
+使用 `block_trigger_inputs` 仅在提供 `control_image` 输入时运行 `SDXLDiffDiffControlNetDenoiseStep` 块。否则,使用 `SDXLDiffDiffDenoiseStep`。
+
+```py
+class SDXLDiffDiffAutoDenoiseStep(AutoPipelineBlocks):
+ block_classes = [SDXLDiffDiffControlNetDenoiseStep, SDXLDiffDiffDenoiseStep]
+ block_names = ["contr
+olnet_denoise", "denoise"]
+block_trigger_inputs = ["controlnet_cond", None]
+```
+
+添加 `ip_adapter` 和 `controlnet_input` 块。
+
+```py
+DIFFDIFF_AUTO_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
+DIFFDIFF_AUTO_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
+DIFFDIFF_AUTO_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
+DIFFDIFF_AUTO_BLOCKS["denoise"] = SDXLDiffDiffAutoDenoiseStep
+DIFFDIFF_AUTO_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0)
+DIFFDIFF_AUTO_BLOCKS.insert("controlnet_input",StableDiffusionXLControlNetAutoInput, 7)
+```
+
+调用 [`SequentialPipelineBlocks.from_blocks_dict`] 来创建一个 [`SequentialPipelineBlocks`] 并创建一个 [`ModularPipeline`] 并加载模型组件以运行。
+
+```py
+dd_auto_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_AUTO_BLOCKS)
+dd_pipeline = dd_auto_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
+dd_pipeline.load_components(torch_dtype=torch.float16)
+```
+
+## 分享
+
+使用 [`~ModularPipeline.save_pretrained`] 将您的 [`ModularPipeline`] 添加到 Hub,并将 `push_to_hub` 参数设置为 `True`。
+
+```py
+dd_pipeline.save_pretrained("YiYiXu/test_modular_doc", push_to_hub=True)
+```
+
+其他用户可以使用 [`~ModularPipeline.from_pretrained`] 加载 [`ModularPipeline`]。
+
+```py
+import torch
+from diffusers.modular_pipelines import ModularPipeline, ComponentsManager
+
+components = ComponentsManager()
+
+diffdiff_pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-diffdiff-0704", trust_remote_code=True, components_manager=components, collection="diffdiff")
+diffdiff_pipeline.load_components(torch_dtype=torch.float16)
+```
diff --git a/docs/source/zh/modular_diffusers/sequential_pipeline_blocks.md b/docs/source/zh/modular_diffusers/sequential_pipeline_blocks.md
new file mode 100644
index 000000000000..befb81f85ddf
--- /dev/null
+++ b/docs/source/zh/modular_diffusers/sequential_pipeline_blocks.md
@@ -0,0 +1,112 @@
+
+
+# 顺序管道块
+
+[`~modular_pipelines.SequentialPipelineBlocks`] 是一种多块类型,它将其他 [`~modular_pipelines.ModularPipelineBlocks`] 按顺序组合在一起。数据通过 `intermediate_inputs` 和 `intermediate_outputs` 线性地从一个块流向下一个块。[`~modular_pipelines.SequentialPipelineBlocks`] 中的每个块通常代表管道中的一个步骤,通过组合它们,您逐步构建一个管道。
+
+本指南向您展示如何将两个块连接成一个 [`~modular_pipelines.SequentialPipelineBlocks`]。
+
+创建两个 [`~modular_pipelines.ModularPipelineBlocks`]。第一个块 `InputBlock` 输出一个 `batch_size` 值,第二个块 `ImageEncoderBlock` 使用 `batch_size` 作为 `intermediate_inputs`。
+
+
+
+
+```py
+from diffusers.modular_pipelines import ModularPipelineBlocks, InputParam, OutputParam
+
+class InputBlock(ModularPipelineBlocks):
+
+ @property
+ def inputs(self):
+ return [
+ InputParam(name="prompt", type_hint=list, description="list of text prompts"),
+ InputParam(name="num_images_per_prompt", type_hint=int, description="number of images per prompt"),
+ ]
+
+ @property
+ def intermediate_outputs(self):
+ return [
+ OutputParam(name="batch_size", description="calculated batch size"),
+ ]
+
+ @property
+ def description(self):
+ return "A block that determines batch_size based on the number of prompts and num_images_per_prompt argument."
+
+ def __call__(self, components, state):
+ block_state = self.get_block_state(state)
+ batch_size = len(block_state.prompt)
+ block_state.batch_size = batch_size * block_state.num_images_per_prompt
+ self.set_block_state(state, block_state)
+ return components, state
+```
+
+
+
+
+```py
+import torch
+from diffusers.modular_pipelines import ModularPipelineBlocks, InputParam, OutputParam
+
+class ImageEncoderBlock(ModularPipelineBlocks):
+
+ @property
+ def inputs(self):
+ return [
+ InputParam(name="image", type_hint="PIL.Image", description="raw input image to process"),
+ InputParam(name="batch_size", type_hint=int),
+ ]
+
+ @property
+ def intermediate_outputs(self):
+ return [
+ OutputParam(name="image_latents", description="latents representing the image"
+ ]
+
+ @property
+ def description(self):
+ return "Encode raw image into its latent presentation"
+
+ def __call__(self, components, state):
+ block_state = self.get_block_state(state)
+ # 模拟处理图像
+ # 这将改变所有块的图像状态,从PIL图像变为张量
+ block_state.image = torch.randn(1, 3, 512, 512)
+ block_state.batch_size = block_state.batch_size * 2
+ block_state.image_latents = torch.randn(1, 4, 64, 64)
+ self.set_block_state(state, block_state)
+ return components, state
+```
+
+
+
+
+通过定义一个[`InsertableDict`]来连接两个块,将块名称映射到块实例。块按照它们在`blocks_dict`中注册的顺序执行。
+
+使用[`~modular_pipelines.SequentialPipelineBlocks.from_blocks_dict`]来创建一个[`~modular_pipelines.SequentialPipelineBlocks`]。
+
+```py
+from diffusers.modular_pipelines import SequentialPipelineBlocks, InsertableDict
+
+blocks_dict = InsertableDict()
+blocks_dict["input"] = input_block
+blocks_dict["image_encoder"] = image_encoder_block
+
+blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict)
+```
+
+通过调用`blocks`来检查[`~modular_pipelines.SequentialPipelineBlocks`]中的子块,要获取更多关于输入和输出的详细信息,可以访问`docs`属性。
+
+```py
+print(blocks)
+print(blocks.doc)
+```
diff --git a/docs/source/zh/optimization/cache.md b/docs/source/zh/optimization/cache.md
new file mode 100644
index 000000000000..f7a94de4f11f
--- /dev/null
+++ b/docs/source/zh/optimization/cache.md
@@ -0,0 +1,67 @@
+
+
+# 缓存
+
+缓存通过存储和重用不同层的中间输出(如注意力层和前馈层)来加速推理,而不是在每个推理步骤执行整个计算。它显著提高了生成速度,但以更多内存为代价,并且不需要额外的训练。
+
+本指南向您展示如何在 Diffusers 中使用支持的缓存方法。
+
+## 金字塔注意力广播
+
+[金字塔注意力广播 (PAB)](https://huggingface.co/papers/2408.12588) 基于这样一种观察:在生成过程的连续时间步之间,注意力输出差异不大。注意力差异在交叉注意力层中最小,并且通常在一个较长的时间步范围内被缓存。其次是时间注意力和空间注意力层。
+
+> [!TIP]
+> 并非所有视频模型都有三种类型的注意力(交叉、时间和空间)!
+
+PAB 可以与其他技术(如序列并行性和无分类器引导并行性(数据并行性))结合,实现近乎实时的视频生成。
+
+设置并传递一个 [`PyramidAttentionBroadcastConfig`] 到管道的变换器以启用它。`spatial_attention_block_skip_range` 控制跳过空间注意力块中注意力计算的频率,`spatial_attention_timestep_skip_range` 是要跳过的时间步范围。注意选择一个合适的范围,因为较小的间隔可能导致推理速度变慢,而较大的间隔可能导致生成质量降低。
+
+```python
+import torch
+from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
+
+pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+pipeline.to("cuda")
+
+config = PyramidAttentionBroadcastConfig(
+ spatial_attention_block_skip_range=2,
+ spatial_attention_timestep_skip_range=(100, 800),
+ current_timestep_callback=lambda: pipe.current_timestep,
+)
+pipeline.transformer.enable_cache(config)
+```
+
+## FasterCache
+
+[FasterCache](https://huggingface.co/papers/2410.19355) 缓存并重用注意力特征,类似于 [PAB](#pyramid-attention-broadcast),因为每个连续时间步的输出差异很小。
+
+此方法在使用无分类器引导进行采样时(在大多数基础模型中常见),也可能选择跳过无条件分支预测,并且
+如果连续时间步之间的预测潜在输出存在显著冗余,则从条件分支预测中估计它。
+
+设置并将 [`FasterCacheConfig`] 传递给管道的 transformer 以启用它。
+
+```python
+import torch
+from diffusers import CogVideoXPipeline, FasterCacheConfig
+
+pipe line= CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+pipeline.to("cuda")
+
+config = FasterCacheConfig(
+ spatial_attention_block_skip_range=2,
+ spatial_attention_timestep_skip_range=(-1, 681),
+ current_timestep_callback=lambda: pipe.current_timestep,
+ attention_weight_callback=lambda _: 0.3,
+ unconditional_batch_skip_range=5,
+ unconditional_batch_timestep_skip_range=(-1, 781),
+ tensor_format="BFCHW",
+)
+pipeline.transformer.enable_cache(config)
+```
\ No newline at end of file
diff --git a/docs/source/zh/optimization/coreml.md b/docs/source/zh/optimization/coreml.md
new file mode 100644
index 000000000000..3926a5ddb029
--- /dev/null
+++ b/docs/source/zh/optimization/coreml.md
@@ -0,0 +1,160 @@
+
+
+# 如何使用 Core ML 运行 Stable Diffusion
+
+[Core ML](https://developer.apple.com/documentation/coreml) 是 Apple 框架支持的模型格式和机器学习库。如果您有兴趣在 macOS 或 iOS/iPadOS 应用中运行 Stable Diffusion 模型,本指南将展示如何将现有的 PyTorch 检查点转换为 Core ML 格式,并使用 Python 或 Swift 进行推理。
+
+Core ML 模型可以利用 Apple 设备中所有可用的计算引擎:CPU、GPU 和 Apple Neural Engine(或 ANE,一种在 Apple Silicon Mac 和现代 iPhone/iPad 中可用的张量优化加速器)。根据模型及其运行的设备,Core ML 还可以混合和匹配计算引擎,例如,模型的某些部分可能在 CPU 上运行,而其他部分在 GPU 上运行。
+
+> [!TIP]
+> 您还可以使用 PyTorch 内置的 `mps` 加速器在 Apple Silicon Mac 上运行 `diffusers` Python 代码库。这种方法在 [mps 指南](mps) 中有详细解释,但它与原生应用不兼容。
+
+## Stable Diffusion Core ML 检查点
+
+Stable Diffusion 权重(或检查点)以 PyTorch 格式存储,因此在使用它们之前,需要将它们转换为 Core ML 格式。
+
+幸运的是,Apple 工程师基于 `diffusers` 开发了 [一个转换工具](https://github.com/apple/ml-stable-diffusion#-converting-models-to-core-ml),用于将 PyTorch 检查点转换为 Core ML。
+
+但在转换模型之前,花点时间探索 Hugging Face Hub——很可能您感兴趣的模型已经以 Core ML 格式提供:
+
+- [Apple](https://huggingface.co/apple) 组织包括 Stable Diffusion 版本 1.4、1.5、2.0 基础和 2.1 基础
+- [coreml community](https://huggingface.co/coreml-community) 包括自定义微调模型
+- 使用此 [过滤器](https://huggingface.co/models?pipeline_tag=text-to-image&library=coreml&p=2&sort=likes) 返回所有可用的 Core ML 检查点
+
+如果您找不到感兴趣的模型,我们建议您遵循 Apple 的 [Converting Models to Core ML](https://github.com/apple/ml-stable-diffusion#-converting-models-to-core-ml) 说明。
+
+## 选择要使用的 Core ML 变体
+
+Stable Diffusion 模型可以转换为不同的 Core ML 变体,用于不同目的:
+
+- 注意力类型
+使用了n个块。注意力操作用于“关注”图像表示中不同区域之间的关系,并理解图像和文本表示如何相关。注意力的计算和内存消耗很大,因此存在不同的实现方式,以适应不同设备的硬件特性。对于Core ML Stable Diffusion模型,有两种注意力变体:
+* `split_einsum`([由Apple引入](https://machinelearning.apple.com/research/neural-engine-transformers))针对ANE设备进行了优化,这些设备在现代iPhone、iPad和M系列计算机中可用。
+* “原始”注意力(在`diffusers`中使用的基础实现)仅与CPU/GPU兼容,不与ANE兼容。在CPU + GPU上使用`original`注意力运行模型可能比ANE*更快*。请参阅[此性能基准](https://huggingface.co/blog/fast-mac-diffusers#performance-benchmarks)以及社区提供的[一些额外测量](https://github.com/huggingface/swift-coreml-diffusers/issues/31)以获取更多细节。
+
+- 支持的推理框架。
+* `packages`适用于Python推理。这可用于在尝试将转换后的Core ML模型集成到原生应用程序之前进行测试,或者如果您想探索Core ML性能但不需要支持原生应用程序。例如,具有Web UI的应用程序完全可以使用Python Core ML后端。
+* `compiled`模型是Swift代码所必需的。Hub中的`compiled`模型将大型UNet模型权重分成多个文件,以兼容iOS和iPadOS设备。这对应于[`--chunk-unet`转换选项](https://github.com/apple/ml-stable-diffusion#-converting-models-to-core-ml)。如果您想支持原生应用程序,则需要选择`compiled`变体。
+
+官方的Core ML Stable Diffusion[模型](https://huggingface.co/apple/coreml-stable-diffusion-v1-4/tree/main)包括这些变体,但社区的可能有所不同:
+
+```
+coreml-stable-diffusion-v1-4
+├── README.md
+├── original
+│ ├── compiled
+│ └── packages
+└── split_einsum
+ ├── compiled
+ └── packages
+```
+
+您可以下载并使用所需的变体,如下所示。
+
+## Python中的Core ML推理
+
+安装以下库以在Python中运行Core ML推理:
+
+```bash
+pip install huggingface_hub
+pip install git+https://github.com/apple/ml-stable-diffusion
+```
+
+### 下载模型检查点
+
+要在Python中运行推理,请使用存储在`packages`文件夹中的版本之一,因为`compiled`版本仅与Swift兼容。您可以选择使用`original`或`split_einsum`注意力。
+
+这是您如何从Hub下载`original`注意力变体到一个名为`models`的目录:
+
+```Python
+from huggingface_hub import snapshot_download
+from pathlib import Path
+
+repo_id = "apple/coreml-stable-diffusion-v1-4"
+variant = "original/packages"
+
+mo
+del_path = Path("./models") / (repo_id.split("/")[-1] + "_" + variant.replace("/", "_"))
+snapshot_download(repo_id, allow_patterns=f"{variant}/*", local_dir=model_path, local_dir_use_symlinks=False)
+print(f"Model downloaded at {model_path}")
+```
+
+### 推理[[python-inference]]
+
+下载模型快照后,您可以使用 Apple 的 Python 脚本来测试它。
+
+```shell
+python -m python_coreml_stable_diffusion.pipeline --prompt "a photo of an astronaut riding a horse on mars" -i ./models/coreml-stable-diffusion-v1-4_original_packages/original/packages -o --compute-unit CPU_AND_GPU --seed 93
+```
+
+使用 `-i` 标志将下载的检查点路径传递给脚本。`--compute-unit` 表示您希望允许用于推理的硬件。它必须是以下选项之一:`ALL`、`CPU_AND_GPU`、`CPU_ONLY`、`CPU_AND_NE`。您也可以提供可选的输出路径和用于可重现性的种子。
+
+推理脚本假设您使用的是 Stable Diffusion 模型的原始版本,`CompVis/stable-diffusion-v1-4`。如果您使用另一个模型,您*必须*在推理命令行中使用 `--model-version` 选项指定其 Hub ID。这适用于已支持的模型以及您自己训练或微调的自定义模型。
+
+例如,如果您想使用 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5):
+
+```shell
+python -m python_coreml_stable_diffusion.pipeline --prompt "a photo of an astronaut riding a horse on mars" --compute-unit ALL -o output --seed 93 -i models/coreml-stable-diffusion-v1-5_original_packages --model-version stable-diffusion-v1-5/stable-diffusion-v1-5
+```
+
+## Core ML 在 Swift 中的推理
+
+在 Swift 中运行推理比在 Python 中稍快,因为模型已经以 `mlmodelc` 格式编译。这在应用启动时加载模型时很明显,但如果在之后运行多次生成,则不应明显。
+
+### 下载
+
+要在您的 Mac 上运行 Swift 推理,您需要一个 `compiled` 检查点版本。我们建议您使用类似于先前示例的 Python 代码在本地下载它们,但使用 `compiled` 变体之一:
+
+```Python
+from huggingface_hub import snapshot_download
+from pathlib import Path
+
+repo_id = "apple/coreml-stable-diffusion-v1-4"
+variant = "original/compiled"
+
+model_path = Path("./models") / (repo_id.split("/")[-1] + "_" + variant.replace("/", "_"))
+snapshot_download(repo_id, allow_patterns=f"{variant}/*", local_dir=model_path, local_dir_use_symlinks=False)
+print(f"Model downloaded at {model_path}")
+```
+
+### 推理[[swift-inference]]
+
+要运行推理,请克隆 Apple 的仓库:
+
+```bash
+git clone https://github.com/apple/ml-stable-diffusion
+cd ml-stable-diffusion
+```
+
+然后使用 Apple 的命令行工具,[Swift Package Manager](https://www.swift.org/package-manager/#):
+
+```bash
+swift run StableDiffusionSample --resource-path models/coreml-stable-diffusion-v1-4_original_compiled --compute-units all "a photo of an astronaut riding a horse on mars"
+```
+
+您必须在 `--resource-path` 中指定上一步下载的检查点之一,请确保它包含扩展名为 `.mlmodelc` 的已编译 Core ML 包。`--compute-units` 必须是以下值之一:`all`、`cpuOnly`、`cpuAndGPU`、`cpuAndNeuralEngine`。
+
+有关更多详细信息,请参考 [Apple 仓库中的说明](https://github.com/apple/ml-stable-diffusion)。
+
+## 支持的 Diffusers 功能
+
+Core ML 模型和推理代码不支持 🧨 Diffusers 的许多功能、选项和灵活性。以下是一些需要注意的限制:
+
+- Core ML 模型仅适用于推理。它们不能用于训练或微调。
+- 只有两个调度器已移植到 Swift:Stable Diffusion 使用的默认调度器和我们从 `diffusers` 实现移植到 Swift 的 `DPMSolverMultistepScheduler`。我们推荐您使用 `DPMSolverMultistepScheduler`,因为它在约一半的步骤中产生相同的质量。
+- 负面提示、无分类器引导尺度和图像到图像任务在推理代码中可用。高级功能如深度引导、ControlNet 和潜在上采样器尚不可用。
+
+Apple 的 [转换和推理仓库](https://github.com/apple/ml-stable-diffusion) 和我们自己的 [swift-coreml-diffusers](https://github.com/huggingface/swift-coreml-diffusers) 仓库旨在作为技术演示,以帮助其他开发者在此基础上构建。
+
+如果您对任何缺失功能有强烈需求,请随时提交功能请求或更好的是,贡献一个 PR 🙂。
+
+## 原生 Diffusers Swift 应用
+
+一个简单的方法来在您自己的 Apple 硬件上运行 Stable Diffusion 是使用 [我们的开源 Swift 仓库](https://github.com/huggingface/swift-coreml-diffusers),它基于 `diffusers` 和 Apple 的转换和推理仓库。您可以研究代码,使用 [Xcode](https://developer.apple.com/xcode/) 编译它,并根据您的需求进行适配。为了方便,[App Store 中还有一个独立 Mac 应用](https://apps.apple.com/app/diffusers/id1666309574),因此您无需处理代码或 IDE 即可使用它。如果您是开发者,并已确定 Core ML 是构建您的 Stable Diffusion 应用的最佳解决方案,那么您可以使用本指南的其余部分来开始您的项目。我们迫不及待想看看您会构建什么 🙂。
\ No newline at end of file
diff --git a/docs/source/zh/optimization/deepcache.md b/docs/source/zh/optimization/deepcache.md
new file mode 100644
index 000000000000..4f19d4a36528
--- /dev/null
+++ b/docs/source/zh/optimization/deepcache.md
@@ -0,0 +1,59 @@
+
+
+# DeepCache
+[DeepCache](https://huggingface.co/papers/2312.00858) 通过策略性地缓存和重用高级特征,同时利用 U-Net 架构高效更新低级特征,来加速 [`StableDiffusionPipeline`] 和 [`StableDiffusionXLPipeline`]。
+
+首先安装 [DeepCache](https://github.com/horseee/DeepCache):
+```bash
+pip install DeepCache
+```
+
+然后加载并启用 [`DeepCacheSDHelper`](https://github.com/horseee/DeepCache#usage):
+
+```diff
+ import torch
+ from diffusers import StableDiffusionPipeline
+ pipe = StableDiffusionPipeline.from_pretrained('stable-diffusion-v1-5/stable-diffusion-v1-5', torch_dtype=torch.float16).to("cuda")
+
++ from DeepCache import DeepCacheSDHelper
++ helper = DeepCacheSDHelper(pipe=pipe)
++ helper.set_params(
++ cache_interval=3,
++ cache_branch_id=0,
++ )
++ helper.enable()
+
+ image = pipe("a photo of an astronaut on a moon").images[0]
+```
+
+`set_params` 方法接受两个参数:`cache_interval` 和 `cache_branch_id`。`cache_interval` 表示特征缓存的频率,指定为每次缓存操作之间的步数。`cache_branch_id` 标识网络的哪个分支(从最浅层到最深层排序)负责执行缓存过程。
+选择较低的 `cache_branch_id` 或较大的 `cache_interval` 可以加快推理速度,但会降低图像质量(这些超参数的消融实验可以在[论文](https://huggingface.co/papers/2312.00858)中找到)。一旦设置了这些参数,使用 `enable` 或 `disable` 方法来激活或停用 `DeepCacheSDHelper`。
+
+
+
+
+
+您可以在 [WandB 报告](https://wandb.ai/horseee/DeepCache/runs/jwlsqqgt?workspace=user-horseee) 中找到更多生成的样本(原始管道 vs DeepCache)和相应的推理延迟。提示是从 [MS-COCO 2017](https://cocodataset.org/#home) 数据集中随机选择的。
+
+## 基准测试
+
+我们在 NVIDIA RTX A5000 上测试了 DeepCache 使用 50 个推理步骤加速 [Stable Diffusion v2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1) 的速度,使用不同的配置,包括分辨率、批处理大小、缓存间隔(I)和缓存分支(B)。
+
+| **分辨率** | **批次大小** | **原始** | **DeepCache(I=3, B=0)** | **DeepCache(I=5, B=0)** | **DeepCache(I=5, B=1)** |
+|----------------|----------------|--------------|-------------------------|-------------------------|-------------------------|
+| 512| 8| 15.96| 6.88(2.32倍)| 5.03(3.18倍)| 7.27(2.20x)|
+| | 4| 8.39| 3.60(2.33倍)| 2.62(3.21倍)| 3.75(2.24x)|
+| | 1| 2.61| 1.12(2.33倍)| 0.81(3.24倍)| 1.11(2.35x)|
+| 768| 8| 43.58| 18.99(2.29倍)| 13.96(3.12倍)| 21.27(2.05x)|
+| | 4| 22.24| 9.67(2.30倍)| 7.10(3.13倍)| 10.74(2.07x)|
+| | 1| 6.33| 2.72(2.33倍)| 1.97(3.21倍)| 2.98(2.12x)|
+| 1024| 8| 101.95| 45.57(2.24倍)| 33.72(3.02倍)| 53.00(1.92x)|
+| | 4| 49.25| 21.86(2.25倍)| 16.19(3.04倍)| 25.78(1.91x)|
+| | 1| 13.83| 6.07(2.28倍)| 4.43(3.12倍)| 7.15(1.93x)|
\ No newline at end of file
diff --git a/docs/source/zh/optimization/fp16.md b/docs/source/zh/optimization/fp16.md
new file mode 100644
index 000000000000..e1c4c7e57ae7
--- /dev/null
+++ b/docs/source/zh/optimization/fp16.md
@@ -0,0 +1,304 @@
+
+
+# 加速推理
+
+Diffusion模型在推理时速度较慢,因为生成是一个迭代过程,需要经过一定数量的"步数"逐步将噪声细化为图像或视频。要加速这一过程,您可以尝试使用不同的[调度器](../api/schedulers/overview)、降低模型权重的精度以加快计算、使用更高效的内存注意力机制等方法。
+
+将这些技术组合使用,可以比单独使用任何一种技术获得更快的推理速度。
+
+本指南将介绍如何加速推理。
+
+## 模型数据类型
+
+模型权重的精度和数据类型会影响推理速度,因为更高的精度需要更多内存来加载,也需要更多时间进行计算。PyTorch默认以float32或全精度加载模型权重,因此更改数据类型是快速获得更快推理速度的简单方法。
+
+
+
+
+bfloat16与float16类似,但对数值误差更稳健。硬件对bfloat16的支持各不相同,但大多数现代GPU都能支持bfloat16。
+
+```py
+import torch
+from diffusers import StableDiffusionXLPipeline
+
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
+).to("cuda")
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+pipeline(prompt, num_inference_steps=30).images[0]
+```
+
+
+
+
+float16与bfloat16类似,但可能更容易出现数值误差。
+
+```py
+import torch
+from diffusers import StableDiffusionXLPipeline
+
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+).to("cuda")
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+pipeline(prompt, num_inference_steps=30).images[0]
+```
+
+
+
+
+[TensorFloat-32 (tf32)](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/)模式在NVIDIA Ampere GPU上受支持,它以tf32计算卷积和矩阵乘法运算。存储和其他操作保持在float32。与bfloat16或float16结合使用时,可以显著加快计算速度。
+
+PyTorch默认仅对卷积启用tf32模式,您需要显式启用矩阵乘法的tf32模式。
+
+```py
+import torch
+from diffusers import StableDiffusionXLPipeline
+
+torch.backends.cuda.matmul.allow_tf32 = True
+
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
+).to("cuda")
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+pipeline(prompt, num_inference_steps=30).images[0]
+```
+
+更多详情请参阅[混合精度训练](https://huggingface.co/docs/transformers/en/perf_train_gpu_one#mixed-precision)文档。
+
+
+
+
+## 缩放点积注意力
+
+> [!TIP]
+> 内存高效注意力优化了推理速度*和*[内存使用](./memory#memory-efficient-attention)!
+
+[缩放点积注意力(SDPA)](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)实现了多种注意力后端,包括[FlashAttention](https://github.com/Dao-AILab/flash-attention)、[xFormers](https://github.com/facebookresearch/xformers)和原生C++实现。它会根据您的硬件自动选择最优的后端。
+
+如果您使用的是PyTorch >= 2.0,SDPA默认启用,无需对代码进行任何额外更改。不过,您也可以尝试使用其他注意力后端来自行选择。下面的示例使用[torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html)上下文管理器来启用高效注意力。
+
+```py
+from torch.nn.attention import SDPBackend, sdpa_kernel
+import torch
+from diffusers import StableDiffusionXLPipeline
+
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
+).to("cuda")
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+
+with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
+ image = pipeline(prompt, num_inference_steps=30).images[0]
+```
+
+## torch.compile
+
+[torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html)通过将PyTorch代码和操作编译为优化的内核来加速推理。Diffusers通常会编译计算密集型的模型,如UNet、transformer或VAE。
+
+启用以下编译器设置以获得最大速度(更多选项请参阅[完整列表](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py))。
+
+```py
+import torch
+from diffusers import StableDiffusionXLPipeline
+
+torch._inductor.config.conv_1x1_as_mm = True
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.epilogue_fusion = False
+torch._inductor.config.coordinate_descent_check_all_directions = True
+```
+
+加载并编译UNet和VAE。有几种不同的模式可供选择,但`"max-autotune"`通过编译为CUDA图来优化速度。CUDA图通过单个CPU操作启动多个GPU操作,有效减少了开销。
+
+> [!TIP]
+> 在PyTorch 2.3.1中,您可以控制torch.compile的缓存行为。这对于像`"max-autotune"`这样的编译模式特别有用,它会通过网格搜索多个编译标志来找到最优配置。更多详情请参阅[torch.compile中的编译时间缓存](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html)教程。
+
+将内存布局更改为[channels_last](./memory#torchchannels_last)也可以优化内存和推理速度。
+
+```py
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+).to("cuda")
+pipeline.unet.to(memory_format=torch.channels_last)
+pipeline.vae.to(memory_format=torch.channels_last)
+pipeline.unet = torch.compile(
+ pipeline.unet, mode="max-autotune", fullgraph=True
+)
+pipeline.vae.decode = torch.compile(
+ pipeline.vae.decode,
+ mode="max-autotune",
+ fullgraph=True
+)
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+pipeline(prompt, num_inference_steps=30).images[0]
+```
+
+第一次编译时速度较慢,但一旦编译完成,速度会显著提升。尽量只在相同类型的推理操作上使用编译后的管道。在不同尺寸的图像上调用编译后的管道会重新触发编译,这会很慢且效率低下。
+
+### 动态形状编译
+
+> [!TIP]
+> 确保始终使用PyTorch的nightly版本以获得更好的支持。
+
+`torch.compile`会跟踪输入形状和条件,如果这些不同,它会重新编译模型。例如,如果模型是在1024x1024分辨率的图像上编译的,而在不同分辨率的图像上使用,就会触发重新编译。
+
+为避免重新编译,添加`dynamic=True`以尝试生成更动态的内核,避免条件变化时重新编译。
+
+```diff
++ torch.fx.experimental._config.use_duck_shape = False
++ pipeline.unet = torch.compile(
+ pipeline.unet, fullgraph=True, dynamic=True
+)
+```
+
+指定`use_duck_shape=False`会指示编译器是否应使用相同的符号变量来表示相同大小的输入。更多详情请参阅此[评论](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790)。
+
+并非所有模型都能开箱即用地从动态编译中受益,可能需要更改。参考此[PR](https://github.com/huggingface/diffusers/pull/11297/),它改进了[`AuraFlowPipeline`]的实现以受益于动态编译。
+
+如果动态编译对Diffusers模型的效果不如预期,请随时提出问题。
+
+### 区域编译
+
+[区域编译](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html)通过仅编译模型中*小而频繁重复的块*(通常是transformer层)来减少冷启动延迟,并为每个后续出现的块重用编译后的工件。对于许多diffusion架构,这提供了与全图编译相同的运行时加速,并将编译时间减少了8-10倍。
+
+使用[`~ModelMixin.compile_repeated_blocks`]方法(一个包装`torch.compile`的辅助函数)在任何组件(如transformer模型)上,如下所示。
+
+```py
+# pip install -U diffusers
+import torch
+from diffusers import StableDiffusionXLPipeline
+
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16,
+).to("cuda")
+
+# 仅编译UNet中重复的transformer层
+pipeline.unet.compile_repeated_blocks(fullgraph=True)
+```
+
+要为新模型启用区域编译,请在模型类中添加一个`_repeated_blocks`属性,包含您想要编译的块的类名(作为字符串)。
+
+```py
+class MyUNet(ModelMixin):
+ _repeated_blocks = ("Transformer2DModel",) # ← 默认编译
+```
+
+> [!TIP]
+> 更多区域编译示例,请参阅参考[PR](https://github.com/huggingface/diffusers/pull/11705)。
+
+[Accelerate](https://huggingface.co/docs/accelerate/index)中还有一个[compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78)方法,可以自动选择模型中的候选块进行编译。其余图会单独编译。这对于快速实验很有用,因为您不需要设置哪些块要编译或调整编译标志。
+
+```py
+# pip install -U accelerate
+import torch
+from diffusers import StableDiffusionXLPipeline
+from accelerate.utils import compile regions
+
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+).to("cuda")
+pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
+```
+
+[`~ModelMixin.compile_repeated_blocks`]是故意显式的。在`_repeated_blocks`中列出要重复的块,辅助函数仅编译这些块。它提供了可预测的行为,并且只需一行代码即可轻松推理缓存重用。
+
+### 图中断
+
+在torch.compile中指定`fullgraph=True`非常重要,以确保底层模型中没有图中断。这使您可以充分利用torch.compile而不会降低性能。对于UNet和VAE,这会改变您访问返回变量的方式。
+
+```diff
+- latents = unet(
+- latents, timestep=timestep, encoder_hidden_states=prompt_embeds
+-).sample
+
++ latents = unet(
++ latents, timestep=timestep, encoder_hidden_states=prompt_embeds, return_dict=False
++)[0]
+```
+
+### GPU同步
+
+每次去噪器做出预测后,调度器的`step()`函数会被[调用](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L1228),并且`sigmas`变量会被[索引](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/schedulers/scheduling_euler_discrete.py#L476)。当放在GPU上时,这会引入延迟,因为CPU和GPU之间需要进行通信同步。当去噪器已经编译时,这一点会更加明显。
+
+一般来说,`sigmas`应该[保持在CPU上](https://github.com/huggingface/diffusers/blob/35a969d297cba69110d175ee79c59312b9f49e1e/src/diffusers/schedulers/scheduling_euler_discrete.py#L240),以避免通信同步和延迟。
+
+> [!TIP]
+> 参阅[torch.compile和Diffusers:峰值性能实践指南](https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/)博客文章,了解如何为扩散模型最大化`torch.compile`的性能。
+
+### 基准测试
+
+参阅[diffusers/benchmarks](https://huggingface.co/datasets/diffusers/benchmarks)数据集,查看编译管道的推理延迟和内存使用数据。
+
+[diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao#benchmarking-results)仓库还包含Flux和CogVideoX编译版本的基准测试结果。
+
+## 动态量化
+
+[动态量化](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html)通过降低精度以加快数学运算来提高推理速度。这种特定类型的量化在运行时根据数据确定如何缩放激活,而不是使用固定的缩放因子。因此,缩放因子与数据更准确地匹配。
+
+以下示例使用[torchao](../quantization/torchao)库对UNet和VAE应用[动态int8量化](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html)。
+
+> [!TIP]
+> 参阅我们的[torchao](../quantization/torchao)文档,了解更多关于如何使用Diffusers torchao集成的信息。
+
+配置编译器标志以获得最大速度。
+
+```py
+import torch
+from torchao import apply_dynamic_quant
+from diffusers import StableDiffusionXLPipeline
+
+torch._inductor.config.conv_1x1_as_mm = True
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.epilogue_fusion = False
+torch._inductor.config.coordinate_descent_check_all_directions = True
+torch._inductor.config.force_fuse_int_mm_with_mul = True
+torch._inductor.config.use_mixed_mm = True
+```
+
+使用[dynamic_quant_filter_fn](https://github.com/huggingface/diffusion-fast/blob/0f169640b1db106fe6a479f78c1ed3bfaeba3386/utils/pipeline_utils.py#L16)过滤掉UNet和VAE中一些不会从动态量化中受益的线性层。
+
+```py
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
+).to("cuda")
+
+apply_dynamic_quant(pipeline.unet, dynamic_quant_filter_fn)
+apply_dynamic_quant(pipeline.vae, dynamic_quant_filter_fn)
+
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+pipeline(prompt, num_inference_steps=30).images[0]
+```
+
+## 融合投影矩阵
+
+> [!WARNING]
+> [fuse_qkv_projections](https://github.com/huggingface/diffusers/blob/58431f102cf39c3c8a569f32d71b2ea8caa461e1/src/diffusers/pipelines/pipeline_utils.py#L2034)方法是实验性的,目前主要支持Stable Diffusion管道。参阅此[PR](https://github.com/huggingface/diffusers/pull/6179)了解如何为其他管道启用它。
+
+在注意力块中,输入被投影到三个子空间,分别由投影矩阵Q、K和V表示。这些投影通常单独计算,但您可以水平组合这些矩阵为一个矩阵,并在单步中执行投影。这会增加输入投影的矩阵乘法大小,并提高量化的效果。
+
+```py
+pipeline.fuse_qkv_projections()
+```
+
+## 资源
+
+- 阅读[Presenting Flux Fast: Making Flux go brrr on H100s](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/)博客文章,了解如何结合所有这些优化与[TorchInductor](https://docs.pytorch.org/docs/stable/torch.compiler.html)和[AOTInductor](https://docs.pytorch.org/docs/stable/torch.compiler_aot_inductor.html),使用[flux-fast](https://github.com/huggingface/flux-fast)的配方获得约2.5倍的加速。
+
+ 这些配方支持AMD硬件和[Flux.1 Kontext Dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev)。
+- 阅读[torch.compile和Diffusers:峰值性能实践指南](https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/)博客文章,了解如何在使用`torch.compile`时最大化性能。
diff --git a/docs/source/zh/optimization/habana.md b/docs/source/zh/optimization/habana.md
new file mode 100644
index 000000000000..9b15847d63f4
--- /dev/null
+++ b/docs/source/zh/optimization/habana.md
@@ -0,0 +1,28 @@
+
+
+# Intel Gaudi
+
+Intel Gaudi AI 加速器系列包括 [Intel Gaudi 1](https://habana.ai/products/gaudi/)、[Intel Gaudi 2](https://habana.ai/products/gaudi2/) 和 [Intel Gaudi 3](https://habana.ai/products/gaudi3/)。每台服务器配备 8 个设备,称为 Habana 处理单元 (HPU),在 Gaudi 3 上提供 128GB 内存,在 Gaudi 2 上提供 96GB 内存,在第一代 Gaudi 上提供 32GB 内存。有关底层硬件架构的更多详细信息,请查看 [Gaudi 架构](https://docs.habana.ai/en/latest/Gaudi_Overview/Gaudi_Architecture.html) 概述。
+
+Diffusers 管道可以利用 HPU 加速,即使管道尚未添加到 [Optimum for Intel Gaudi](https://huggingface.co/docs/optimum/main/en/habana/index),也可以通过 [GPU 迁移工具包](https://docs.habana.ai/en/latest/PyTorch/PyTorch_Model_Porting/GPU_Migration_Toolkit/GPU_Migration_Toolkit.html) 实现。
+
+在您的管道上调用 `.to("hpu")` 以将其移动到 HPU 设备,如下所示为 Flux 示例:
+```py
+import torch
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
+pipeline.to("hpu")
+
+image = pipeline("一张松鼠在毕加索风格中的图像").images[0]
+```
+
+> [!TIP]
+> 对于 Gaudi 优化的扩散管道实现,我们推荐使用 [Optimum for Intel Gaudi](https://huggingface.co/docs/optimum/main/en/habana/index)。
\ No newline at end of file
diff --git a/docs/source/zh/optimization/memory.md b/docs/source/zh/optimization/memory.md
new file mode 100644
index 000000000000..662dcaf4bcf2
--- /dev/null
+++ b/docs/source/zh/optimization/memory.md
@@ -0,0 +1,581 @@
+
+
+# 减少内存使用
+
+现代diffusion models,如 [Flux](../api/pipelines/flux) 和 [Wan](../api/pipelines/wan),拥有数十亿参数,在您的硬件上进行推理时会占用大量内存。这是一个挑战,因为常见的 GPU 通常没有足够的内存。为了克服内存限制,您可以使用多个 GPU(如果可用)、将一些管道组件卸载到 CPU 等。
+
+本指南将展示如何减少内存使用。
+
+> [!TIP]
+> 请记住,这些技术可能需要根据模型进行调整。例如,基于 transformer 的扩散模型可能不会像基于 UNet 的模型那样从这些内存优化中同等受益。
+
+## 多个 GPU
+
+如果您有多个 GPU 的访问权限,有几种选项可以高效地在硬件上加载和分发大型模型。这些功能由 [Accelerate](https://huggingface.co/docs/accelerate/index) 库支持,因此请确保先安装它。
+
+```bash
+pip install -U accelerate
+```
+
+### 分片检查点
+
+将大型检查点加载到多个分片中很有用,因为分片会逐个加载。这保持了低内存使用,只需要足够的内存来容纳模型大小和最大分片大小。我们建议当 fp32 检查点大于 5GB 时进行分片。默认分片大小为 5GB。
+
+在 [`~DiffusionPipeline.save_pretrained`] 中使用 `max_shard_size` 参数对检查点进行分片。
+
+```py
+from diffusers import AutoModel
+
+unet = AutoModel.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
+)
+unet.save_pretrained("sdxl-unet-sharded", max_shard_size="5GB")
+```
+
+现在您可以使用分片检查点,而不是常规检查点,以节省内存。
+
+```py
+import torch
+from diffusers import AutoModel, StableDiffusionXLPipeline
+
+unet = AutoModel.from_pretrained(
+ "username/sdxl-unet-sharded", torch_dtype=torch.float16
+)
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ unet=unet,
+ torch_dtype=torch.float16
+).to("cuda")
+```
+
+### 设备放置
+
+> [!WARNING]
+> 设备放置是一个实验性功能,API 可能会更改。目前仅支持 `balanced` 策略。我们计划在未来支持额外的映射策略。
+
+`device_map` 参数控制管道或模型中的组件如何
+单个模型中的层分布在多个设备上。
+
+
+
+
+`balanced` 设备放置策略将管道均匀分割到所有可用设备上。
+
+```py
+import torch
+from diffusers import AutoModel, StableDiffusionXLPipeline
+
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16,
+ device_map="balanced"
+)
+```
+
+您可以使用 `hf_device_map` 检查管道的设备映射。
+
+```py
+print(pipeline.hf_device_map)
+{'unet': 1, 'vae': 1, 'safety_checker': 0, 'text_encoder': 0}
+```
+
+
+
+
+`device_map` 对于加载大型模型非常有用,例如具有 125 亿参数的 Flux diffusion transformer。将其设置为 `"auto"` 可以自动将模型首先分布到最快的设备上,然后再移动到较慢的设备。有关更多详细信息,请参阅 [模型分片](../training/distributed_inference#model-sharding) 文档。
+
+```py
+import torch
+from diffusers import AutoModel
+
+transformer = AutoModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="transformer",
+ device_map="auto",
+ torch_dtype=torch.bfloat16
+)
+```
+
+您可以使用 `hf_device_map` 检查模型的设备映射。
+
+```py
+print(transformer.hf_device_map)
+```
+
+
+
+
+当设计您自己的 `device_map` 时,它应该是一个字典,包含模型的特定模块名称或层以及设备标识符(整数表示 GPU,`cpu` 表示 CPU,`disk` 表示磁盘)。
+
+在模型上调用 `hf_device_map` 以查看模型层如何分布,然后设计您自己的映射。
+
+```py
+print(transformer.hf_device_map)
+{'pos_embed': 0, 'time_text_embed': 0, 'context_embedder': 0, 'x_embedder': 0, 'transformer_blocks': 0, 'single_transformer_blocks.0': 0, 'single_transformer_blocks.1': 0, 'single_transformer_blocks.2': 0, 'single_transformer_blocks.3': 0, 'single_transformer_blocks.4': 0, 'single_transformer_blocks.5': 0, 'single_transformer_blocks.6': 0, 'single_transformer_blocks.7': 0, 'single_transformer_blocks.8': 0, 'single_transformer_blocks.9': 0, 'single_transformer_blocks.10': 'cpu', 'single_transformer_blocks.11': 'cpu', 'single_transformer_blocks.12': 'cpu', 'single_transformer_blocks.13': 'cpu', 'single_transformer_blocks.14': 'cpu', 'single_transformer_blocks.15': 'cpu', 'single_transformer_blocks.16': 'cpu', 'single_transformer_blocks.17': 'cpu', 'single_transformer_blocks.18': 'cpu', 'single_transformer_blocks.19': 'cpu', 'single_transformer_blocks.20': 'cpu', 'single_transformer_blocks.21': 'cpu', 'single_transformer_blocks.22': 'cpu', 'single_transformer_blocks.23': 'cpu', 'single_transformer_blocks.24': 'cpu', 'single_transformer_blocks.25': 'cpu', 'single_transformer_blocks.26': 'cpu', 'single_transformer_blocks.27': 'cpu', 'single_transformer_blocks.28': 'cpu', 'single_transformer_blocks.29': 'cpu', 'single_transformer_blocks.30': 'cpu', 'single_transformer_blocks.31': 'cpu', 'single_transformer_blocks.32': 'cpu', 'single_transformer_blocks.33': 'cpu', 'single_transformer_blocks.34': 'cpu', 'single_transformer_blocks.35': 'cpu', 'single_transformer_blocks.36': 'cpu', 'single_transformer_blocks.37': 'cpu', 'norm_out': 'cpu', 'proj_out': 'cpu'}
+```
+
+例如,下面的 `device_map` 将 `single_transformer_blocks.10` 到 `single_transformer_blocks.20` 放置在第二个 GPU(`1`)上。
+
+```py
+import torch
+from diffusers import AutoModel
+
+device_map = {
+ 'pos_embed': 0, 'time_text_embed': 0, 'context_embedder': 0, 'x_embedder': 0, 'transformer_blocks': 0, 'single_transformer_blocks.0': 0, 'single_transformer_blocks.1': 0, 'single_transformer_blocks.2': 0, 'single_transformer_blocks.3': 0, 'single_transformer_blocks.4': 0, 'single_transformer_blocks.5': 0, 'single_transformer_blocks.6': 0, 'single_transformer_blocks.7': 0, 'single_transformer_blocks.8': 0, 'single_transformer_blocks.9': 0, 'single_transformer_blocks.10': 1, 'single_transformer_blocks.11': 1, 'single_transformer_blocks.12': 1, 'single_transformer_blocks.13': 1, 'single_transformer_blocks.14': 1, 'single_transformer_blocks.15': 1, 'single_transformer_blocks.16': 1, 'single_transformer_blocks.17': 1, 'single_transformer_blocks.18': 1, 'single_transformer_blocks.19': 1, 'single_transformer_blocks.20': 1, 'single_transformer_blocks.21': 'cpu', 'single_transformer_blocks.22': 'cpu', 'single_transformer_blocks.23': 'cpu', 'single_transformer_blocks.24': 'cpu', 'single_transformer_blocks.25': 'cpu', 'single_transformer_blocks.26': 'cpu', 'single_transformer_blocks.27': 'cpu', 'single_transformer_blocks.28': 'cpu', 'single_transformer_blocks.29': 'cpu', 'single_transformer_blocks.30': 'cpu', 'single_transformer_blocks.31': 'cpu', 'single_transformer_blocks.32': 'cpu', 'single_transformer_blocks.33': 'cpu', 'single_transformer_blocks.34': 'cpu', 'single_transformer_blocks.35': 'cpu', 'single_transformer_blocks.36': 'cpu', 'single_transformer_blocks.37': 'cpu', 'norm_out': 'cpu', 'proj_out': 'cpu'
+}
+
+transformer = AutoModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="transformer",
+ device_map=device_map,
+ torch_dtype=torch.bfloat16
+)
+```
+
+传递一个字典,将最大内存使用量映射到每个设备以强制执行限制。如果设备不在 `max_memory` 中,它将被忽略,管道组件不会分发到该设备。
+
+```py
+import torch
+from diffusers import AutoModel, StableDiffusionXLPipeline
+
+max_memory = {0:"1GB", 1:"1GB"}
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16,
+ device_map="balanced",
+ max_memory=max_memory
+)
+```
+
+Diffusers 默认使用所有设备的最大内存,但如果它们无法适应 GPU,则需要使用单个 GPU 并通过以下方法卸载到 CPU。
+
+- [`~DiffusionPipeline.enable_model_cpu_offload`] 仅适用于单个 GPU,但非常大的模型可能无法适应它
+- 使用 [`~DiffusionPipeline.enable_sequential_cpu_offload`] 可能有效,但它极其缓慢,并且仅限于单个 GPU。
+
+使用 [`~DiffusionPipeline.reset_device_map`] 方法来重置 `device_map`。如果您想在已进行设备映射的管道上使用方法如 `.to()`、[`~DiffusionPipeline.enable_sequential_cpu_offload`] 和 [`~DiffusionPipeline.enable_model_cpu_offload`],这是必要的。
+
+```py
+pipeline.reset_device_map()
+```
+
+## VAE 切片
+
+VAE 切片通过将大批次输入拆分为单个数据批次并分别处理它们来节省内存。这种方法在同时生成多个图像时效果最佳。
+
+例如,如果您同时生成 4 个图像,解码会将峰值激活内存增加 4 倍。VAE 切片通过一次只解码 1 个图像而不是所有 4 个图像来减少这种情况。
+
+调用 [`~StableDiffusionPipeline.enable_vae_slicing`] 来启用切片 VAE。您可以预期在解码多图像批次时性能会有小幅提升,而在单图像批次时没有性能影响。
+
+```py
+import torch
+from diffusers import AutoModel, StableDiffusionXLPipeline
+
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16,
+).to("cuda")
+pipeline.enable_vae_slicing()
+pipeline(["An astronaut riding a horse on Mars"]*32).images[0]
+print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
+```
+
+> [!WARNING]
+> [`AutoencoderKLWan`] 和 [`AsymmetricAutoencoderKL`] 类不支持切片。
+
+## VAE 平铺
+
+VAE 平铺通过将图像划分为较小的重叠图块而不是一次性处理整个图像来节省内存。这也减少了峰值内存使用量,因为 GPU 一次只处理一个图块。
+
+调用 [`~StableDiffusionPipeline.enable_vae_tiling`] 来启用 VAE 平铺。生成的图像可能因图块到图块的色调变化而有所不同,因为它们被单独解码,但图块之间不应有明显的接缝。对于低于预设(但可配置)限制的分辨率,平铺被禁用。例如,对于 [`StableDiffusionPipeline`] 中的 VAE,此限制为 512x512。
+
+```py
+import torch
+from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import load_image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+).to("cuda")
+pipeline.enable_vae_tiling()
+
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl-init.png")
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+pipeline(prompt, image=init_image, strength=0.5).images[0]
+print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
+```
+
+> [!WARNING]
+> [`AutoencoderKLWan`] 和 [`AsymmetricAutoencoderKL`] 不支持平铺。
+
+## 卸载
+
+卸载策略将非当前活动层移动
+将模型移动到 CPU 以避免增加 GPU 内存。这些策略可以与量化和 torch.compile 结合使用,以平衡推理速度和内存使用。
+
+有关更多详细信息,请参考 [编译和卸载量化模型](./speed-memory-optims) 指南。
+
+### CPU 卸载
+
+CPU 卸载选择性地将权重从 GPU 移动到 CPU。当需要某个组件时,它被传输到 GPU;当不需要时,它被移动到 CPU。此方法作用于子模块而非整个模型。它通过避免将整个模型存储在 GPU 上来节省内存。
+
+CPU 卸载显著减少内存使用,但由于子模块在设备之间多次来回传递,它也非常慢。由于速度极慢,它通常不实用。
+
+> [!WARNING]
+> 在调用 [`~DiffusionPipeline.enable_sequential_cpu_offload`] 之前,不要将管道移动到 CUDA,否则节省的内存非常有限(更多细节请参考此 [issue](https://github.com/huggingface/diffusers/issues/1934))。这是一个状态操作,会在模型上安装钩子。
+
+调用 [`~DiffusionPipeline.enable_sequential_cpu_offload`] 以在管道上启用它。
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
+)
+pipeline.enable_sequential_cpu_offload()
+
+pipeline(
+ prompt="An astronaut riding a horse on Mars",
+ guidance_scale=0.,
+ height=768,
+ width=1360,
+ num_inference_steps=4,
+ max_sequence_length=256,
+).images[0]
+print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
+```
+
+### 模型卸载
+
+模型卸载将整个模型移动到 GPU,而不是选择性地移动某些层或模型组件。一个主要管道模型,通常是文本编码器、UNet 和 VAE,被放置在 GPU 上,而其他组件保持在 CPU 上。像 UNet 这样运行多次的组件会一直留在 GPU 上,直到完全完成且不再需要。这消除了 [CPU 卸载](#cpu-offloading) 的通信开销,使模型卸载成为一个更快的替代方案。权衡是内存节省不会那么大。
+
+> [!WARNING]
+> 请注意,如果在安装钩子后模型在管道外部被重用(更多细节请参考 [移除钩子](https://huggingface.co/docs/accelerate/en/package_reference/big_modeling#accelerate.hooks.remove_hook_from_module)),您需要按预期顺序运行整个管道和模型以正确卸载它们。这是一个状态操作,会在模型上安装钩子。
+
+调用 [`~DiffusionPipeline.enable_model_cpu_offload`] 以在管道上启用它。
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
+)
+pipeline.enable_model_cpu_offload()
+
+pipeline(
+ prompt="An astronaut riding a horse on Mars",
+ guidance_scale=0.,
+ height=768,
+ width=1360,
+ num_inference_steps=4,
+ max_sequence_length=256,
+).images[0]
+print(f"最大内存保留: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
+```
+
+[`~DiffusionPipeline.enable_model_cpu_offload`] 在您单独使用 [`~StableDiffusionXLPipeline.encode_prompt`] 方法生成文本编码器隐藏状态时也有帮助。
+
+### 组卸载
+
+组卸载将内部层组([torch.nn.ModuleList](https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html) 或 [torch.nn.Sequential](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html))移动到 CPU。它比[模型卸载](#model-offloading)使用更少的内存,并且比[CPU 卸载](#cpu-offloading)更快,因为它减少了通信开销。
+
+> [!WARNING]
+> 如果前向实现包含权重相关的输入设备转换,组卸载可能不适用于所有模型,因为它可能与组卸载的设备转换机制冲突。
+
+调用 [`~ModelMixin.enable_group_offload`] 为继承自 [`ModelMixin`] 的标准 Diffusers 模型组件启用它。对于不继承自 [`ModelMixin`] 的其他模型组件,例如通用 [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html),使用 [`~hooks.apply_group_offloading`] 代替。
+
+`offload_type` 参数可以设置为 `block_level` 或 `leaf_level`。
+
+- `block_level` 基于 `num_blocks_per_group` 参数卸载层组。例如,如果 `num_blocks_per_group=2` 在一个有 40 层的模型上,每次加载和卸载 2 层(总共 20 次加载/卸载)。这大大减少了内存需求。
+- `leaf_level` 在最低级别卸载单个层,等同于[CPU 卸载](#cpu-offloading)。但如果您使用流而不放弃推理速度,它可以更快。
+
+```py
+import torch
+from diffusers import CogVideoXPipeline
+from diffusers.hooks import apply_group_offloading
+from diffusers.utils import export_to_video
+
+onload_device = torch.device("cuda")
+offload_device = torch.device("cpu")
+pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+
+# 对 Diffusers 模型实现使用 enable_group_offload 方法
+pipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level")
+pipeline.vae.enable_group_offload(onload_device=onload_device, offload_type="leaf_level")
+
+# 对其他模型组件使用 apply_group_offloading 方法
+apply_group_offloading(pipeline.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2)
+
+prompt = (
+"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
+ "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
+ "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
+ "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
+ "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
+ "atmosphere of this unique musical performance."
+)
+video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
+print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
+export_to_video(video, "output.mp4", fps=8)
+```
+
+#### CUDA 流
+`use_stream` 参数可以激活支持异步数据传输流的 CUDA 设备,以减少整体执行时间,与 [CPU 卸载](#cpu-offloading) 相比。它通过使用层预取重叠数据传输和计算。下一个要执行的层在当前层仍在执行时加载到 GPU 上。这会显著增加 CPU 内存,因此请确保您有模型大小的 2 倍内存。
+
+设置 `record_stream=True` 以获得更多速度提升,代价是内存使用量略有增加。请参阅 [torch.Tensor.record_stream](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) 文档了解更多信息。
+
+> [!TIP]
+> 当 `use_stream=True` 在启用平铺的 VAEs 上时,确保在推理前进行虚拟前向传递(可以使用虚拟输入),以避免设备不匹配错误。这可能不适用于所有实现,因此如果遇到任何问题,请随时提出问题。
+
+如果您在使用启用 `use_stream` 的 `block_level` 组卸载,`num_blocks_per_group` 参数应设置为 `1`,否则会引发警告。
+
+```py
+pipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True, record_stream=True)
+```
+
+`low_cpu_mem_usage` 参数可以设置为 `True`,以在使用流进行组卸载时减少 CPU 内存使用。它最适合 `leaf_level` 卸载和 CPU 内存瓶颈的情况。通过动态创建固定张量而不是预先固定它们来节省内存。然而,这可能会增加整体执行时间。
+
+#### 卸载到磁盘
+组卸载可能会消耗大量系统内存,具体取决于模型大小。在内存有限的系统上,尝试将组卸载到磁盘作为辅助内存。
+
+在 [`~ModelMixin.enable_group_offload`] 或 [`~hooks.apply_group_offloading`] 中设置 `offload_to_disk_path` 参数,将模型卸载到磁盘。
+
+```py
+pipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", offload_to_disk_path="path/to/disk")
+
+apply_group_offloading(pipeline.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2, offload_to_disk_path="path/to/disk")
+```
+
+参考这些[两个](https://github.com/huggingface/diffusers/pull/11682#issue-3129365363)[表格](https://github.com/huggingface/diffusers/pull/11682#issuecomment-2955715126)来比较速度和内存的权衡。
+
+## 分层类型转换
+
+> [!TIP]
+> 将分层类型转换与[组卸载](#group-offloading)结合使用,以获得更多内存节省。
+
+分层类型转换将权重存储在较小的数据格式中(例如 `torch.float8_e4m3fn` 和 `torch.float8_e5m2`),以使用更少的内存,并在计算时将那些权重上转换为更高精度如 `torch.float16` 或 `torch.bfloat16`。某些层(归一化和调制相关权重)被跳过,因为将它们存储在 fp8 中可能会降低生成质量。
+
+> [!WARNING]
+> 如果前向实现包含权重的内部类型转换,分层类型转换可能不适用于所有模型。当前的分层类型转换实现假设前向传递独立于权重精度,并且输入数据类型始终在 `compute_dtype` 中指定(请参见[这里](https://github.com/huggingface/transformers/blob/7f5077e53682ca855afc826162b204ebf809f1f9/src/transformers/models/t5/modeling_t5.py#L294-L299)以获取不兼容的实现)。
+>
+> 分层类型转换也可能在使用[PEFT](https://huggingface.co/docs/peft/index)层的自定义建模实现上失败。有一些检查可用,但它们没有经过广泛测试或保证在所有情况下都能工作。
+
+调用 [`~ModelMixin.enable_layerwise_casting`] 来设置存储和计算数据类型。
+
+```py
+import torch
+from diffusers import CogVideoXPipeline, CogVideoXTransformer3DModel
+from diffusers.utils import export_to_video
+
+transformer = CogVideoXTransformer3DModel.from_pretrained(
+ "THUDM/CogVideoX-5b",
+ subfolder="transformer",
+ torch_dtype=torch.bfloat16
+)
+transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
+
+pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b",
+ transformer=transformer,
+ torch_dtype=torch.bfloat16
+).to("cuda")
+prompt = (
+ "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
+ "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
+ "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
+ "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
+ "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
+ "atmosphere of this unique musical performance."
+)
+video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
+print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
+export_to_video(video, "output.mp4", fps=8)
+```
+
+[`~hooks.apply_layerwise_casting`] 方法也可以在您需要更多控制和灵活性时使用。它可以通过在特定内部模块上调用它来部分应用于模型层。使用 `skip_modules_pattern` 或 `skip_modules_classes` 参数来指定要避免的模块,例如归一化和调制层。
+
+```python
+import torch
+from diffusers import CogVideoXTransformer3DModel
+from diffusers.hooks import apply_layerwise_casting
+
+transformer = CogVideoXTransformer3DModel.from_pretrained(
+ "THUDM/CogVideoX-5b",
+ subfolder="transformer",
+ torch_dtype=torch.bfloat16
+)
+
+# 跳过归一化层
+apply_layerwise_casting(
+ transformer,
+ storage_dtype=torch.float8_e4m3fn,
+ compute_dtype=torch.bfloat16,
+ skip_modules_classes=["norm"],
+ non_blocking=True,
+)
+```
+
+## torch.channels_last
+
+[torch.channels_last](https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html) 将张量的存储方式从 `(批次大小, 通道数, 高度, 宽度)` 翻转为 `(批次大小, 高度, 宽度, 通道数)`。这使张量与硬件如何顺序访问存储在内存中的张量对齐,并避免了在内存中跳转以访问像素值。
+
+并非所有运算符当前都支持通道最后格式,并且可能导致性能更差,但仍然值得尝试。
+
+```py
+print(pipeline.unet.conv_out.state_dict()["weight"].stride()) # (2880, 9, 3, 1)
+pipeline.unet.to(memory_format=torch.channels_last) # 原地操作
+print(
+ pipeline.unet.conv_out.state_dict()["weight"].stride()
+) # (2880, 1, 960, 320) 第二个维度的跨度为1证明它有效
+```
+
+## torch.jit.trace
+
+[torch.jit.trace](https://pytorch.org/docs/stable/generated/torch.jit.trace.html) 记录模型在样本输入上执行的操作,并根据记录的执行路径创建一个新的、优化的模型表示。在跟踪过程中,模型被优化以减少来自Python和动态控制流的开销,并且操作被融合在一起以提高效率。返回的可执行文件或 [ScriptFunction](https://pytorch.org/docs/stable/generated/torch.jit.ScriptFunction.html) 可以被编译。
+
+```py
+import time
+import torch
+from diffusers import StableDiffusionPipeline
+import functools
+
+# torch 禁用梯度
+torch.set_grad_enabled(False)
+
+# 设置变量
+n_experiments = 2
+unet_runs_per_experiment = 50
+
+# 加载样本输入
+def generate_inputs():
+ sample = torch.randn((2, 4, 64, 64), device="cuda", dtype=torch.float16)
+ timestep = torch.rand(1, device="cuda", dtype=torch.float16) * 999
+ encoder_hidden_states = torch.randn((2, 77, 768), device="cuda", dtype=torch.float16)
+ return sample, timestep, encoder_hidden_states
+
+
+pipeline = StableDiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+).to("cuda")
+unet = pipeline.unet
+unet.eval()
+unet.to(memory
+_format=torch.channels_last) # 使用 channels_last 内存格式
+unet.forward = functools.partial(unet.forward, return_dict=False) # 设置 return_dict=False 为默认
+
+# 预热
+for _ in range(3):
+ with torch.inference_mode():
+ inputs = generate_inputs()
+ orig_output = unet(*inputs)
+
+# 追踪
+print("tracing..")
+unet_traced = torch.jit.trace(unet, inputs)
+unet_traced.eval()
+print("done tracing")
+
+# 预热和优化图
+for _ in range(5):
+ with torch.inference_mode():
+ inputs = generate_inputs()
+ orig_output = unet_traced(*inputs)
+
+# 基准测试
+with torch.inference_mode():
+ for _ in range(n_experiments):
+ torch.cuda.synchronize()
+ start_time = time.time()
+ for _ in range(unet_runs_per_experiment):
+ orig_output = unet_traced(*inputs)
+ torch.cuda.synchronize()
+ print(f"unet traced inference took {time.time() - start_time:.2f} seconds")
+ for _ in range(n_experiments):
+ torch.cuda.synchronize()
+ start_time = time.time()
+ for _ in range(unet_runs_per_experiment):
+ orig_output = unet(*inputs)
+ torch.cuda.synchronize()
+ print(f"unet inference took {time.time() - start_time:.2f} seconds")
+
+# 保存模型
+unet_traced.save("unet_traced.pt")
+```
+
+替换管道的 UNet 为追踪版本。
+
+```py
+import torch
+from diffusers import StableDiffusionPipeline
+from dataclasses import dataclass
+
+@dataclass
+class UNet2DConditionOutput:
+ sample: torch.Tensor
+
+pipeline = StableDiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+).to("cuda")
+
+# 使用 jitted unet
+unet_traced = torch.jit.load("unet_traced.pt")
+
+# del pipeline.unet
+class TracedUNet(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.in_channels = pipe.unet.config.in_channels
+ self.device = pipe.unet.device
+
+ def forward(self, latent_model_input, t, encoder_hidden_states):
+ sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
+ return UNet2DConditionOutput(sample=sample)
+
+pipeline.unet = TracedUNet()
+
+with torch.inference_mode():
+ image = pipe([prompt] * 1, num_inference_steps=50).images[0]
+```
+
+## 内存高效注意力
+
+> [!TIP]
+> 内存高效注意力优化内存使用 *和* [推理速度](./fp16#scaled-dot-product-attention)!
+
+Transformers 注意力机制是内存密集型的,尤其对于长序列,因此您可以尝试使用不同且更内存高效的注意力类型。
+
+默认情况下,如果安装了 PyTorch >= 2.0,则使用 [scaled dot-product attention (SDPA)](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)。您无需对代码进行任何额外更改。
+
+SDPA 还支持 [FlashAttention](https://github.com/Dao-AILab/flash-attention) 和 [xFormers](https://github.com/facebookresearch/xformers),以及 a
+这是一个原生的 C++ PyTorch 实现。它会根据您的输入自动选择最优的实现。
+
+您可以使用 [`~ModelMixin.enable_xformers_memory_efficient_attention`] 方法显式地使用 xFormers。
+
+```py
+# pip install xformers
+import torch
+from diffusers import StableDiffusionXLPipeline
+
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16,
+).to("cuda")
+pipeline.enable_xformers_memory_efficient_attention()
+```
+
+调用 [`~ModelMixin.disable_xformers_memory_efficient_attention`] 来禁用它。
+
+```py
+pipeline.disable_xformers_memory_efficient_attention()
+```
\ No newline at end of file
diff --git a/docs/source/zh/optimization/mps.md b/docs/source/zh/optimization/mps.md
new file mode 100644
index 000000000000..48b08c5a12df
--- /dev/null
+++ b/docs/source/zh/optimization/mps.md
@@ -0,0 +1,79 @@
+
+
+# Metal Performance Shaders (MPS)
+
+> [!TIP]
+> 带有 徽章的管道表示模型可以利用 Apple silicon 设备上的 MPS 后端进行更快的推理。欢迎提交 [Pull Request](https://github.com/huggingface/diffusers/compare) 来为缺少此徽章的管道添加它。
+
+🤗 Diffusers 与 Apple silicon(M1/M2 芯片)兼容,使用 PyTorch 的 [`mps`](https://pytorch.org/docs/stable/notes/mps.html) 设备,该设备利用 Metal 框架来发挥 MacOS 设备上 GPU 的性能。您需要具备:
+
+- 配备 Apple silicon(M1/M2)硬件的 macOS 计算机
+- macOS 12.6 或更高版本(推荐 13.0 或更高)
+- arm64 版本的 Python
+- [PyTorch 2.0](https://pytorch.org/get-started/locally/)(推荐)或 1.13(支持 `mps` 的最低版本)
+
+`mps` 后端使用 PyTorch 的 `.to()` 接口将 Stable Diffusion 管道移动到您的 M1 或 M2 设备上:
+
+```python
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
+pipe = pipe.to("mps")
+
+# 如果您的计算机内存小于 64 GB,推荐使用
+pipe.enable_attention_slicing()
+
+prompt = "a photo of an astronaut riding a horse on mars"
+image = pipe(prompt).images[0]
+image
+```
+
+> [!WARNING]
+> PyTorch [mps](https://pytorch.org/docs/stable/notes/mps.html) 后端不支持大小超过 `2**32` 的 NDArray。如果您遇到此问题,请提交 [Issue](https://github.com/huggingface/diffusers/issues/new/choose) 以便我们调查。
+
+如果您使用 **PyTorch 1.13**,您需要通过管道进行一次额外的"预热"传递。这是一个临时解决方法,用于解决首次推理传递产生的结果与后续传递略有不同的问题。您只需要执行此传递一次,并且在仅进行一次推理步骤后可以丢弃结果。
+
+```diff
+ from diffusers import DiffusionPipeline
+
+ pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5").to("mps")
+ pipe.enable_attention_slicing()
+
+ prompt = "a photo of an astronaut riding a horse on mars"
+ # 如果 PyTorch 版本是 1.13,进行首次"预热"传递
++ _ = pipe(prompt, num_inference_steps=1)
+
+ # 预热传递后,结果与 CPU 设备上的结果匹配。
+ image = pipe(prompt).images[0]
+```
+
+## 故障排除
+
+本节列出了使用 `mps` 后端时的一些常见问题及其解决方法。
+
+### 注意力切片
+
+M1/M2 性能对内存压力非常敏感。当发生这种情况时,系统会自动交换内存,这会显著降低性能。
+
+为了防止这种情况发生,我们建议使用*注意力切片*来减少推理过程中的内存压力并防止交换。这在您的计算机系统内存少于 64GB 或生成非标准分辨率(大于 512×512 像素)的图像时尤其相关。在您的管道上调用 [`~DiffusionPipeline.enable_attention_slicing`] 函数:
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True).to("mps")
+pipeline.enable_attention_slicing()
+```
+
+注意力切片将昂贵的注意力操作分多个步骤执行,而不是一次性完成。在没有统一内存的计算机中,它通常能提高约 20% 的性能,但我们观察到在大多数 Apple 芯片计算机中,除非您有 64GB 或更多 RAM,否则性能会*更好*。
+
+### 批量推理
+
+批量生成多个提示可能会导致崩溃或无法可靠工作。如果是这种情况,请尝试迭代而不是批量处理。
\ No newline at end of file
diff --git a/docs/source/zh/optimization/neuron.md b/docs/source/zh/optimization/neuron.md
new file mode 100644
index 000000000000..99d807a88c0d
--- /dev/null
+++ b/docs/source/zh/optimization/neuron.md
@@ -0,0 +1,56 @@
+
+
+# AWS Neuron
+
+Diffusers 功能可在 [AWS Inf2 实例](https://aws.amazon.com/ec2/instance-types/inf2/)上使用,这些是由 [Neuron 机器学习加速器](https://aws.amazon.com/machine-learning/inferentia/)驱动的 EC2 实例。这些实例旨在提供更好的计算性能(更高的吞吐量、更低的延迟)和良好的成本效益,使其成为 AWS 用户将扩散模型部署到生产环境的良好选择。
+
+[Optimum Neuron](https://huggingface.co/docs/optimum-neuron/en/index) 是 Hugging Face 库与 AWS 加速器之间的接口,包括 AWS [Trainium](https://aws.amazon.com/machine-learning/trainium/) 和 AWS [Inferentia](https://aws.amazon.com/machine-learning/inferentia/)。它支持 Diffusers 中的许多功能,并具有类似的 API,因此如果您已经熟悉 Diffusers,学习起来更容易。一旦您创建了 AWS Inf2 实例,请安装 Optimum Neuron。
+
+```bash
+python -m pip install --upgrade-strategy eager optimum[neuronx]
+```
+
+> [!TIP]
+> 我们提供预构建的 [Hugging Face Neuron 深度学习 AMI](https://aws.amazon.com/marketplace/pp/prodview-gr3e6yiscria2)(DLAMI)和用于 Amazon SageMaker 的 Optimum Neuron 容器。建议正确设置您的环境。
+
+下面的示例演示了如何在 inf2.8xlarge 实例上使用 Stable Diffusion XL 模型生成图像(一旦模型编译完成,您可以切换到更便宜的 inf2.xlarge 实例)。要生成一些图像,请使用 [`~optimum.neuron.NeuronStableDiffusionXLPipeline`] 类,该类类似于 Diffusers 中的 [`StableDiffusionXLPipeline`] 类。
+
+与 Diffusers 不同,您需要将管道中的模型编译为 Neuron 格式,即 `.neuron`。运行以下命令将模型导出为 `.neuron` 格式。
+
+```bash
+optimum-cli export neuron --model stabilityai/stable-diffusion-xl-base-1.0 \
+ --batch_size 1 \
+ --height 1024 `# 生成图像的高度(像素),例如 768, 1024` \
+ --width 1024 `# 生成图像的宽度(像素),例如 768, 1024` \
+ --num_images_per_prompt 1 `# 每个提示生成的图像数量,默认为 1` \
+ --auto_cast matmul `# 仅转换矩阵乘法操作` \
+ --auto_cast_type bf16 `# 将操作从 FP32 转换为 BF16` \
+ sd_neuron_xl/
+```
+
+现在使用预编译的 SDXL 模型生成一些图像。
+
+```python
+>>> from optimum.neuron import Neu
+ronStableDiffusionXLPipeline
+
+>>> stable_diffusion_xl = NeuronStableDiffusionXLPipeline.from_pretrained("sd_neuron_xl/")
+>>> prompt = "a pig with wings flying in floating US dollar banknotes in the air, skyscrapers behind, warm color palette, muted colors, detailed, 8k"
+>>> image = stable_diffusion_xl(prompt).images[0]
+```
+
+
+
+欢迎查看Optimum Neuron [文档](https://huggingface.co/docs/optimum-neuron/en/inference_tutorials/stable_diffusion#generate-images-with-stable-diffusion-models-on-aws-inferentia)中更多不同用例的指南和示例!
\ No newline at end of file
diff --git a/docs/source/zh/optimization/onnx.md b/docs/source/zh/optimization/onnx.md
new file mode 100644
index 000000000000..b70510d51b75
--- /dev/null
+++ b/docs/source/zh/optimization/onnx.md
@@ -0,0 +1,79 @@
+
+
+# ONNX Runtime
+
+🤗 [Optimum](https://github.com/huggingface/optimum) 提供了兼容 ONNX Runtime 的 Stable Diffusion 流水线。您需要运行以下命令安装支持 ONNX Runtime 的 🤗 Optimum:
+
+```bash
+pip install -q optimum["onnxruntime"]
+```
+
+本指南将展示如何使用 ONNX Runtime 运行 Stable Diffusion 和 Stable Diffusion XL (SDXL) 流水线。
+
+## Stable Diffusion
+
+要加载并运行推理,请使用 [`~optimum.onnxruntime.ORTStableDiffusionPipeline`]。若需加载 PyTorch 模型并实时转换为 ONNX 格式,请设置 `export=True`:
+
+```python
+from optimum.onnxruntime import ORTStableDiffusionPipeline
+
+model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+pipeline = ORTStableDiffusionPipeline.from_pretrained(model_id, export=True)
+prompt = "sailing ship in storm by Leonardo da Vinci"
+image = pipeline(prompt).images[0]
+pipeline.save_pretrained("./onnx-stable-diffusion-v1-5")
+```
+
+> [!WARNING]
+> 当前批量生成多个提示可能会占用过高内存。在问题修复前,建议采用迭代方式而非批量处理。
+
+如需离线导出 ONNX 格式流水线供后续推理使用,请使用 [`optimum-cli export`](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#exporting-a-model-to-onnx-using-the-cli) 命令:
+
+```bash
+optimum-cli export onnx --model stable-diffusion-v1-5/stable-diffusion-v1-5 sd_v15_onnx/
+```
+
+随后进行推理时(无需再次指定 `export=True`):
+
+```python
+from optimum.onnxruntime import ORTStableDiffusionPipeline
+
+model_id = "sd_v15_onnx"
+pipeline = ORTStableDiffusionPipeline.from_pretrained(model_id)
+prompt = "sailing ship in storm by Leonardo da Vinci"
+image = pipeline(prompt).images[0]
+```
+
+
+
+
+
+您可以在 🤗 Optimum [文档](https://huggingface.co/docs/optimum/) 中找到更多示例,Stable Diffusion 支持文生图、图生图和图像修复任务。
+
+## Stable Diffusion XL
+
+要加载并运行 SDXL 推理,请使用 [`~optimum.onnxruntime.ORTStableDiffusionXLPipeline`]:
+
+```python
+from optimum.onnxruntime import ORTStableDiffusionXLPipeline
+
+model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+pipeline = ORTStableDiffusionXLPipeline.from_pretrained(model_id)
+prompt = "sailing ship in storm by Leonardo da Vinci"
+image = pipeline(prompt).images[0]
+```
+
+如需导出 ONNX 格式流水线供后续推理使用,请运行:
+
+```bash
+optimum-cli export onnx --model stabilityai/stable-diffusion-xl-base-1.0 --task stable-diffusion-xl sd_xl_onnx/
+```
+
+SDXL 的 ONNX 格式目前支持文生图和图生图任务。
diff --git a/docs/source/zh/optimization/open_vino.md b/docs/source/zh/optimization/open_vino.md
new file mode 100644
index 000000000000..8229c5a9448a
--- /dev/null
+++ b/docs/source/zh/optimization/open_vino.md
@@ -0,0 +1,77 @@
+
+
+# OpenVINO
+
+🤗 [Optimum](https://github.com/huggingface/optimum-intel) 提供与 OpenVINO 兼容的 Stable Diffusion 管道,可在各种 Intel 处理器上执行推理(请参阅支持的设备[完整列表](https://docs.openvino.ai/latest/openvino_docs_OV_UG_supported_plugins_Supported_Devices.html))。
+
+您需要安装 🤗 Optimum Intel,并使用 `--upgrade-strategy eager` 选项以确保 [`optimum-intel`](https://github.com/huggingface/optimum-intel) 使用最新版本:
+
+```bash
+pip install --upgrade-strategy eager optimum["openvino"]
+```
+
+本指南将展示如何使用 Stable Diffusion 和 Stable Diffusion XL (SDXL) 管道与 OpenVINO。
+
+## Stable Diffusion
+
+要加载并运行推理,请使用 [`~optimum.intel.OVStableDiffusionPipeline`]。如果您想加载 PyTorch 模型并即时转换为 OpenVINO 格式,请设置 `export=True`:
+
+```python
+from optimum.intel import OVStableDiffusionPipeline
+
+model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+pipeline = OVStableDiffusionPipeline.from_pretrained(model_id, export=True)
+prompt = "sailing ship in storm by Rembrandt"
+image = pipeline(prompt).images[0]
+
+# 别忘了保存导出的模型
+pipeline.save_pretrained("openvino-sd-v1-5")
+```
+
+为了进一步加速推理,静态重塑模型。如果您更改任何参数,例如输出高度或宽度,您需要再次静态重塑模型。
+
+```python
+# 定义与输入和期望输出相关的形状
+batch_size, num_images, height, width = 1, 1, 512, 512
+
+# 静态重塑模型
+pipeline.reshape(batch_size, height, width, num_images)
+# 在推理前编译模型
+pipeline.compile()
+
+image = pipeline(
+ prompt,
+ height=height,
+ width=width,
+ num_images_per_prompt=num_images,
+).images[0]
+```
+
+
+
+
+您可以在 🤗 Optimum [文档](https://huggingface.co/docs/optimum/intel/inference#stable-diffusion) 中找到更多示例,Stable Diffusion 支持文本到图像、图像到图像和修复。
+
+## Stable Diffusion XL
+
+要加载并运行 SDXL 推理,请使用 [`~optimum.intel.OVStableDiffusionXLPipeline`]:
+
+```python
+from optimum.intel import OVStableDiffusionXLPipeline
+
+model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+pipeline = OVStableDiffusionXLPipeline.from_pretrained(model_id)
+prompt = "sailing ship in storm by Rembrandt"
+image = pipeline(prompt).images[0]
+```
+
+为了进一步加速推理,可以如Stable Diffusion部分所示[静态重塑](#stable-diffusion)模型。
+
+您可以在🤗 Optimum[文档](https://huggingface.co/docs/optimum/intel/inference#stable-diffusion-xl)中找到更多示例,并且在OpenVINO中运行SDXL支持文本到图像和图像到图像。
\ No newline at end of file
diff --git a/docs/source/zh/optimization/para_attn.md b/docs/source/zh/optimization/para_attn.md
new file mode 100644
index 000000000000..106a8818c643
--- /dev/null
+++ b/docs/source/zh/optimization/para_attn.md
@@ -0,0 +1,497 @@
+# ParaAttention
+
+
+
+
+
+
+
+
+大型图像和视频生成模型,如 [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) 和 [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo),由于其规模,可能对实时应用和部署构成推理挑战。
+
+[ParaAttention](https://github.com/chengzeyi/ParaAttention) 是一个实现了**上下文并行**和**第一块缓存**的库,可以与其他技术(如 torch.compile、fp8 动态量化)结合使用,以加速推理。
+
+本指南将展示如何在 NVIDIA L20 GPU 上对 FLUX.1-dev 和 HunyuanVideo 应用 ParaAttention。
+在我们的基线基准测试中,除了 HunyuanVideo 为避免内存不足错误外,未应用任何优化。
+
+我们的基线基准测试显示,FLUX.1-dev 能够在 28 步中生成 1024x1024 分辨率图像,耗时 26.36 秒;HunyuanVideo 能够在 30 步中生成 129 帧 720p 分辨率视频,耗时 3675.71 秒。
+
+> [!TIP]
+> 对于更快的上下文并行推理,请尝试使用支持 NVLink 的 NVIDIA A100 或 H100 GPU(如果可用),尤其是在 GPU 数量较多时。
+
+## 第一块缓存
+
+缓存模型中 transformer 块的输出并在后续推理步骤中重用它们,可以降低计算成本并加速推理。
+
+然而,很难决定何时重用缓存以确保生成图像或视频的质量。ParaAttention 直接使用**第一个 transformer 块输出的残差差异**来近似模型输出之间的差异。当差异足够小时,重用先前推理步骤的残差差异。换句话说,跳过去噪步骤。
+
+这在 FLUX.1-dev 和 HunyuanVideo 推理上实现了 2 倍加速,且质量非常好。
+
+
+
+ AdaCache 的工作原理,第一块缓存是其变体
+
+
+
+
+
+要在 FLUX.1-dev 上应用第一块缓存,请调用 `apply_cache_on_pipe`,如下所示。0.08 是 FLUX 模型的默认残差差异值。
+
+```python
+import time
+import torch
+from diffusers import FluxPipeline
+
+pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ torch_dtype=torch.bfloat16,
+).to("cuda")
+
+from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
+
+apply_cache_on_pipe(pipe, residual_diff_thre
+shold=0.08)
+
+# 启用内存节省
+# pipe.enable_model_cpu_offload()
+# pipe.enable_sequential_cpu_offload()
+
+begin = time.time()
+image = pipe(
+ "A cat holding a sign that says hello world",
+ num_inference_steps=28,
+).images[0]
+end = time.time()
+print(f"Time: {end - begin:.2f}s")
+
+print("Saving image to flux.png")
+image.save("flux.png")
+```
+
+| 优化 | 原始 | FBCache rdt=0.06 | FBCache rdt=0.08 | FBCache rdt=0.10 | FBCache rdt=0.12 |
+| - | - | - | - | - | - |
+| 预览 |  |  |  |  |  |
+| 墙时间 (s) | 26.36 | 21.83 | 17.01 | 16.00 | 13.78 |
+
+First Block Cache 将推理速度降低到 17.01 秒,与基线相比,或快 1.55 倍,同时保持几乎零质量损失。
+
+
+
+
+要在 HunyuanVideo 上应用 First Block Cache,请使用 `apply_cache_on_pipe`,如下所示。0.06 是 HunyuanVideo 模型的默认残差差值。
+
+```python
+import time
+import torch
+from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
+from diffusers.utils import export_to_video
+
+model_id = "tencent/HunyuanVideo"
+transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+ model_id,
+ subfolder="transformer",
+ torch_dtype=torch.bfloat16,
+ revision="refs/pr/18",
+)
+pipe = HunyuanVideoPipeline.from_pretrained(
+ model_id,
+ transformer=transformer,
+ torch_dtype=torch.float16,
+ revision="refs/pr/18",
+).to("cuda")
+
+from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
+
+apply_cache_on_pipe(pipe, residual_diff_threshold=0.6)
+
+pipe.vae.enable_tiling()
+
+begin = time.time()
+output = pipe(
+ prompt="A cat walks on the grass, realistic",
+ height=720,
+ width=1280,
+ num_frames=129,
+ num_inference_steps=30,
+).frames[0]
+end = time.time()
+print(f"Time: {end - begin:.2f}s")
+
+print("Saving video to hunyuan_video.mp4")
+export_to_video(output, "hunyuan_video.mp4", fps=15)
+```
+
+
+
+ 您的浏览器不支持视频标签。
+
+
+ HunyuanVideo 无 FBCache
+
+
+
+ Your browser does not support the video tag.
+
+
+ HunyuanVideo 与 FBCache
+
+First Block Cache 将推理速度降低至 2271.06 秒,相比基线快了 1.62 倍,同时保持了几乎为零的质量损失。
+
+
+
+
+## fp8 量化
+
+fp8 动态量化进一步加速推理并减少内存使用。为了使用 8 位 [NVIDIA Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/),必须对激活和权重进行量化。
+
+使用 `float8_weight_only` 和 `float8_dynamic_activation_float8_weight` 来量化文本编码器和变换器模型。
+
+默认量化方法是逐张量量化,但如果您的 GPU 支持逐行量化,您也可以尝试它以获得更好的准确性。
+
+使用以下命令安装 [torchao](https://github.com/pytorch/ao/tree/main)。
+
+```bash
+pip3 install -U torch torchao
+```
+
+[torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) 使用 `mode="max-autotune-no-cudagraphs"` 或 `mode="max-autotune"` 选择最佳内核以获得性能。如果是第一次调用模型,编译可能会花费很长时间,但一旦模型编译完成,这是值得的。
+
+此示例仅量化变换器模型,但您也可以量化文本编码器以进一步减少内存使用。
+
+> [!TIP]
+> 动态量化可能会显著改变模型输出的分布,因此您需要将 `residual_diff_threshold` 设置为更大的值以使其生效。
+
+
+
+
+```python
+import time
+import torch
+from diffusers import FluxPipeline
+
+pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ torch_dtype=torch.bfloat16,
+).to("cuda")
+
+from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
+
+apply_cache_on_pipe(
+ pipe,
+ residual_diff_threshold=0.12, # 使用更大的值以使缓存生效
+)
+
+from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only
+
+quantize_(pipe.text_encoder, float8_weight_only())
+quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())
+pipe.transformer = torch.compile(
+ pipe.transformer, mode="max-autotune-no-cudagraphs",
+)
+
+# 启用内存节省
+# pipe.enable_model_cpu_offload()
+# pipe.enable_sequential_cpu_offload()
+
+for i in range(2):
+ begin = time.time()
+ image = pipe(
+ "A cat holding a sign that says hello world",
+ num_inference_steps=28,
+ ).images[0]
+ end = time.time()
+ if i == 0:
+ print(f"预热时间: {end - begin:.2f}s")
+ else:
+ print(f"时间: {end - begin:.2f}s")
+
+print("保存图像到 flux.png")
+image.save("flux.png")
+```
+
+fp8 动态量化和 torch.compile 将推理速度降低至 7.56 秒,相比基线快了 3.48 倍。
+
+
+
+```python
+import time
+import torch
+from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
+from diffusers.utils import export_to_video
+
+model_id = "tencent/HunyuanVideo"
+transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+ model_id,
+ subfolder="transformer",
+ torch_dtype=torch.bfloat16,
+ revision="refs/pr/18",
+)
+pipe = HunyuanVideoPipeline.from_pretrained(
+ model_id,
+ transformer=transformer,
+ torch_dtype=torch.float16,
+ revision="refs/pr/18",
+).to("cuda")
+
+from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
+
+apply_cache_on_pipe(pipe)
+
+from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only
+
+quantize_(pipe.text_encoder, float8_weight_only())
+quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())
+pipe.transformer = torch.compile(
+ pipe.transformer, mode="max-autotune-no-cudagraphs",
+)
+
+# Enable memory savings
+pipe.vae.enable_tiling()
+# pipe.enable_model_cpu_offload()
+# pipe.enable_sequential_cpu_offload()
+
+for i in range(2):
+ begin = time.time()
+ output = pipe(
+ prompt="A cat walks on the grass, realistic",
+ height=720,
+ width=1280,
+ num_frames=129,
+ num_inference_steps=1 if i == 0 else 30,
+ ).frames[0]
+ end = time.time()
+ if i == 0:
+ print(f"Warm up time: {end - begin:.2f}s")
+ else:
+ print(f"Time: {end - begin:.2f}s")
+
+print("Saving video to hunyuan_video.mp4")
+export_to_video(output, "hunyuan_video.mp4", fps=15)
+```
+
+NVIDIA L20 GPU 仅有 48GB 内存,在编译后且如果未调用 `enable_model_cpu_offload` 时,可能会遇到内存不足(OOM)错误,因为 HunyuanVideo 在高分辨率和大量帧数运行时具有非常大的激活张量。对于内存少于 80GB 的 GPU,可以尝试降低分辨率和帧数来避免 OOM 错误。
+
+大型视频生成模型通常受注意力计算而非全连接层的瓶颈限制。这些模型不会从量化和 torch.compile 中显著受益。
+
+
+
+
+## 上下文并行性
+
+上下文并行性并行化推理并随多个 GPU 扩展。ParaAttention 组合设计允许您将上下文并行性与第一块缓存和动态量化结合使用。
+
+> [!TIP]
+> 请参考 [ParaAttention](https://github.com/chengzeyi/ParaAttention/tree/main) 仓库获取详细说明和如何使用多个 GPU 扩展推理的示例。
+
+如果推理过程需要持久化和可服务,建议使用 [torch.multiprocessing](https://pytorch.org/docs/stable/multiprocessing.html) 编写您自己的推理处理器。这可以消除启动进程以及加载和重新编译模型的开销。
+
+
+
+
+以下代码示例结合了第一块缓存、fp8动态量化、torch.compile和上下文并行,以实现最快的推理速度。
+
+```python
+import time
+import torch
+import torch.distributed as dist
+from diffusers import FluxPipeline
+
+dist.init_process_group()
+
+torch.cuda.set_device(dist.get_rank())
+
+pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ torch_dtype=torch.bfloat16,
+).to("cuda")
+
+from para_attn.context_parallel import init_context_parallel_mesh
+from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
+from para_attn.parallel_vae.diffusers_adapters import parallelize_vae
+
+mesh = init_context_parallel_mesh(
+ pipe.device.type,
+ max_ring_dim_size=2,
+)
+parallelize_pipe(
+ pipe,
+ mesh=mesh,
+)
+parallelize_vae(pipe.vae, mesh=mesh._flatten())
+
+from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
+
+apply_cache_on_pipe(
+ pipe,
+ residual_diff_threshold=0.12, # 使用较大的值以使缓存生效
+)
+
+from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only
+
+quantize_(pipe.text_encoder, float8_weight_only())
+quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())
+torch._inductor.config.reorder_for_compute_comm_overlap = True
+pipe.transformer = torch.compile(
+ pipe.transformer, mode="max-autotune-no-cudagraphs",
+)
+
+# 启用内存节省
+# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank())
+# pipe.enable_sequential_cpu_offload(gpu_id=dist.get_rank())
+
+for i in range(2):
+ begin = time.time()
+ image = pipe(
+ "A cat holding a sign that says hello world",
+ num_inference_steps=28,
+ output_type="pil" if dist.get_rank() == 0 else "pt",
+ ).images[0]
+ end = time.time()
+ if dist.get_rank() == 0:
+ if i == 0:
+ print(f"预热时间: {end - begin:.2f}s")
+ else:
+ print(f"时间: {end - begin:.2f}s")
+
+if dist.get_rank() == 0:
+ print("将图像保存到flux.png")
+ image.save("flux.png")
+
+dist.destroy_process_group()
+```
+
+保存到`run_flux.py`并使用[torchrun](https://pytorch.org/docs/stable/elastic/run.html)启动。
+
+```bash
+# 使用--nproc_per_node指定GPU数量
+torchrun --nproc_per_node=2 run_flux.py
+```
+
+推理速度降至8.20秒,相比基线快了3.21倍,使用2个NVIDIA L20 GPU。在4个L20上,推理速度为3.90秒,快了6.75倍。
+
+
+
+
+以下代码示例结合了第一块缓存和上下文并行,以实现最快的推理速度。
+
+```python
+import time
+import torch
+import torch.distributed as dist
+from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
+from diffusers.utils import export_to_video
+
+dist.init_process_group()
+
+torch.cuda.set_device(dist.get_rank())
+
+model_id = "tencent/HunyuanVideo"
+transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+ model_id,
+ subfolder="transformer",
+ torch_dtype=torch.bfloat16,
+ revision="refs/pr/18",
+)
+pipe = HunyuanVideoPipeline.from_pretrained(
+ model_id,
+ transformer=transformer,
+ torch_dtype=torch.float16,
+ revision="refs/pr/18",
+).to("cuda")
+
+from para_attn.context_parallel import init_context_parallel_mesh
+from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
+from para_attn.parallel_vae.diffusers_adapters import parallelize_vae
+
+mesh = init_context_parallel_mesh(
+ pipe.device.type,
+)
+parallelize_pipe(
+ pipe,
+ mesh=mesh,
+)
+parallelize_vae(pipe.vae, mesh=mesh._flatten())
+
+from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
+
+apply_cache_on_pipe(pipe)
+
+# from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only
+#
+# torch._inductor.config.reorder_for_compute_comm_overlap = True
+#
+# quantize_(pipe.text_encoder, float8_weight_only())
+# quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())
+# pipe.transformer = torch.compile(
+# pipe.transformer, mode="max-autotune-no-cudagraphs",
+# )
+
+# 启用内存节省
+pipe.vae.enable_tiling()
+# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank())
+# pipe.enable_sequential_cpu_offload(gpu_id=dist.get_rank())
+
+for i in range(2):
+ begin = time.time()
+ output = pipe(
+ prompt="A cat walks on the grass, realistic",
+ height=720,
+ width=1280,
+ num_frames=129,
+ num_inference_steps=1 if i == 0 else 30,
+ output_type="pil" if dist.get_rank() == 0 else "pt",
+ ).frames[0]
+ end = time.time()
+ if dist.get_rank() == 0:
+ if i == 0:
+ print(f"预热时间: {end - begin:.2f}s")
+ else:
+ print(f"时间: {end - begin:.2f}s")
+
+if dist.get_rank() == 0:
+ print("保存视频到 hunyuan_video.mp4")
+ export_to_video(output, "hunyuan_video.mp4", fps=15)
+
+dist.destroy_process_group()
+```
+
+保存到 `run_hunyuan_video.py` 并使用 [torchrun](https://pytorch.org/docs/stable/elastic/run.html) 启动。
+
+```bash
+# 使用 --nproc_per_node 指定 GPU 数量
+torchrun --nproc_per_node=8 run_hunyuan_video.py
+```
+
+推理速度降低到 649.23 秒,相比基线快 5.66 倍,使用 8 个 NVIDIA L20 GPU。
+
+
+
+
+## 基准测试
+
+
+
+
+| GPU 类型 | GPU 数量 | 优化 | 墙钟时间 (s) | 加速比 |
+| - | - | - | - | - |
+| NVIDIA L20 | 1 | 基线 | 26.36 | 1.00x |
+| NVIDIA L20 | 1 | FBCache (rdt=0.08) | 17.01 | 1.55x |
+| NVIDIA L20 | 1 | FP8 DQ | 13.40 | 1.96x |
+| NVIDIA L20 | 1 | FBCache (rdt=0.12) + FP8 DQ | 7.56 | 3.48x |
+| NVIDIA L20 | 2 | FBCache (rdt=0.12) + FP8 DQ + CP | 4.92 | 5.35x |
+| NVIDIA L20 | 4 | FBCache (rdt=0.12) + FP8 DQ + CP | 3.90 | 6.75x |
+
+
+
+
+| GPU 类型 | GPU 数量 | 优化 | 墙钟时间 (s) | 加速比 |
+| - | - | - | - | - |
+| NVIDIA L20 | 1 | 基线 | 3675.71 | 1.00x |
+| NVIDIA
+L20 | 1 | FBCache | 2271.06 | 1.62x |
+| NVIDIA L20 | 2 | FBCache + CP | 1132.90 | 3.24x |
+| NVIDIA L20 | 4 | FBCache + CP | 718.15 | 5.12x |
+| NVIDIA L20 | 8 | FBCache + CP | 649.23 | 5.66x |
+
+
+
\ No newline at end of file
diff --git a/docs/source/zh/optimization/pruna.md b/docs/source/zh/optimization/pruna.md
new file mode 100644
index 000000000000..31cc3d52fa25
--- /dev/null
+++ b/docs/source/zh/optimization/pruna.md
@@ -0,0 +1,184 @@
+# Pruna
+
+[Pruna](https://github.com/PrunaAI/pruna) 是一个模型优化框架,提供多种优化方法——量化、剪枝、缓存、编译——以加速推理并减少内存使用。以下是优化方法的概览。
+
+| 技术 | 描述 | 速度 | 内存 | 质量 |
+|------------|---------------------------------------------------------------------------------------|:----:|:----:|:----:|
+| `batcher` | 将多个输入分组在一起同时处理,提高计算效率并减少处理时间。 | ✅ | ❌ | ➖ |
+| `cacher` | 存储计算的中间结果以加速后续操作。 | ✅ | ➖ | ➖ |
+| `compiler` | 为特定硬件优化模型指令。 | ✅ | ➖ | ➖ |
+| `distiller`| 训练一个更小、更简单的模型来模仿一个更大、更复杂的模型。 | ✅ | ✅ | ❌ |
+| `quantizer`| 降低权重和激活的精度,减少内存需求。 | ✅ | ✅ | ❌ |
+| `pruner` | 移除不重要或冗余的连接和神经元,产生一个更稀疏、更高效的网络。 | ✅ | ✅ | ❌ |
+| `recoverer`| 在压缩后恢复模型的性能。 | ➖ | ➖ | ✅ |
+| `factorizer`| 将多个小矩阵乘法批处理为一个大型融合操作。 | ✅ | ➖ | ➖ |
+| `enhancer` | 通过应用后处理算法(如去噪或上采样)来增强模型输出。 | ❌ | - | ✅ |
+
+✅ (改进), ➖ (大致相同), ❌ (恶化)
+
+在 [Pruna 文档](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html#configure-algorithms) 中探索所有优化方法。
+
+## 安装
+
+使用以下命令安装 Pruna。
+
+```bash
+pip install pruna
+```
+
+## 优化 Diffusers 模型
+
+Diffusers 模型支持广泛的优化算法,如下所示。
+
+
+
+
+
+下面的示例使用 factorizer、compiler 和 cacher 算法的组合优化 [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)。这种组合将推理速度加速高达 4.2 倍,并将峰值 GPU 内存使用从 34.7GB 减少到 28.0GB,同时几乎保持相同的输出质量。
+
+> [!TIP]
+> 参考 [Pruna 优化](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html) 文档以了解更多关于该操作的信息。
+本示例中使用的优化技术。
+
+
+
+
+
+首先定义一个包含要使用的优化算法的`SmashConfig`。要优化模型,将管道和`SmashConfig`用`smash`包装,然后像往常一样使用管道进行推理。
+
+```python
+import torch
+from diffusers import FluxPipeline
+
+from pruna import PrunaModel, SmashConfig, smash
+
+# 加载模型
+# 使用小GPU内存尝试segmind/Segmind-Vega或black-forest-labs/FLUX.1-schnell
+pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ torch_dtype=torch.bfloat16
+).to("cuda")
+
+# 定义配置
+smash_config = SmashConfig()
+smash_config["factorizer"] = "qkv_diffusers"
+smash_config["compiler"] = "torch_compile"
+smash_config["torch_compile_target"] = "module_list"
+smash_config["cacher"] = "fora"
+smash_config["fora_interval"] = 2
+
+# 为了获得最佳速度结果,可以添加这些配置
+# 但它们会将预热时间从1.5分钟增加到10分钟
+# smash_config["torch_compile_mode"] = "max-autotune-no-cudagraphs"
+# smash_config["quantizer"] = "torchao"
+# smash_config["torchao_quant_type"] = "fp8dq"
+# smash_config["torchao_excluded_modules"] = "norm+embedding"
+
+# 优化模型
+smashed_pipe = smash(pipe, smash_config)
+
+# 运行模型
+smashed_pipe("a knitted purple prune").images[0]
+```
+
+
+
+
+
+优化后,我们可以使用Hugging Face Hub共享和加载优化后的模型。
+
+```python
+# 保存模型
+smashed_pipe.save_to_hub("/FLUX.1-dev-smashed")
+
+# 加载模型
+smashed_pipe = PrunaModel.from_hub("/FLUX.1-dev-smashed")
+```
+
+## 评估和基准测试Diffusers模型
+
+Pruna提供了[EvaluationAgent](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/evaluate.html)来评估优化后模型的质量。
+
+我们可以定义我们关心的指标,如总时间和吞吐量,以及要评估的数据集。我们可以定义一个模型并将其传递给`EvaluationAgent`。
+
+
+
+
+我们可以通过使用`EvaluationAgent`加载和评估优化后的模型,并将其传递给`Task`。
+
+```python
+import torch
+from diffusers import FluxPipeline
+
+from pruna import PrunaModel
+from pruna.data.pruna_datamodule import PrunaDataModule
+from pruna.evaluation.evaluation_agent import EvaluationAgent
+from pruna.evaluation.metrics import (
+ ThroughputMetric,
+ TorchMetricWrapper,
+ TotalTimeMetric,
+)
+from pruna.evaluation.task import Task
+
+# define the device
+device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+
+# 加载模型
+# 使用小GPU内存尝试 PrunaAI/Segmind-Vega-smashed 或 PrunaAI/FLUX.1-dev-smashed
+smashed_pipe = PrunaModel.from_hub("PrunaAI/FLUX.1-dev-smashed")
+
+# 定义指标
+metrics = [
+ TotalTimeMetric(n_iterations=20, n_warmup_iterations=5),
+ ThroughputMetric(n_iterations=20, n_warmup_iterations=5),
+ TorchMetricWrapper("clip"),
+]
+
+# 定义数据模块
+datamodule = PrunaDataModule.from_string("LAION256")
+datamodule.limit_datasets(10)
+
+# 定义任务和评估代理
+task = Task(metrics, datamodule=datamodule, device=device)
+eval_agent = EvaluationAgent(task)
+
+# 评估优化模型并卸载到CPU
+smashed_pipe.move_to_device(device)
+smashed_pipe_results = eval_agent.evaluate(smashed_pipe)
+smashed_pipe.move_to_device("cpu")
+```
+
+
+
+
+除了比较优化模型与基础模型,您还可以评估独立的 `diffusers` 模型。这在您想评估模型性能而不考虑优化时非常有用。我们可以通过使用 `PrunaModel` 包装器并运行 `EvaluationAgent` 来实现。
+
+```python
+import torch
+from diffusers import FluxPipeline
+
+from pruna import PrunaModel
+
+# 加载模型
+# 使用小GPU内存尝试 PrunaAI/Segmind-Vega-smashed 或 PrunaAI/FLUX.1-dev-smashed
+pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ torch_dtype=torch.bfloat16
+).to("cpu")
+wrapped_pipe = PrunaModel(model=pipe)
+```
+
+
+
+
+现在您已经了解了如何优化和评估您的模型,可以开始使用 Pruna 来优化您自己的模型了。幸运的是,我们有许多示例来帮助您入门。
+
+> [!TIP]
+> 有关基准测试 Flux 的更多详细信息,请查看 [宣布 FLUX-Juiced:最快的图像生成端点(快 2.6 倍)!](https://huggingface.co/blog/PrunaAI/flux-fastest-image-generation-endpoint) 博客文章和 [InferBench](https://huggingface.co/spaces/PrunaAI/InferBench) 空间。
+
+## 参考
+
+- [Pruna](https://github.com/pruna-ai/pruna)
+- [Pruna 优化](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html#configure-algorithms)
+- [Pruna 评估](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/evaluate.html)
+- [Pruna 教程](https://docs.pruna.ai/en/stable/docs_pruna/tutorials/index.html)
\ No newline at end of file
diff --git a/docs/source/zh/optimization/speed-memory-optims.md b/docs/source/zh/optimization/speed-memory-optims.md
new file mode 100644
index 000000000000..48f1483d3e94
--- /dev/null
+++ b/docs/source/zh/optimization/speed-memory-optims.md
@@ -0,0 +1,200 @@
+
+
+# 编译和卸载量化模型
+
+优化模型通常涉及[推理速度](./fp16)和[内存使用](./memory)之间的权衡。例如,虽然[缓存](./cache)可以提高推理速度,但它也会增加内存消耗,因为它需要存储中间注意力层的输出。一种更平衡的优化策略结合了量化模型、[torch.compile](./fp16#torchcompile) 和各种[卸载方法](./memory#offloading)。
+
+> [!TIP]
+> 查看 [torch.compile](./fp16#torchcompile) 指南以了解更多关于编译以及如何在此处应用的信息。例如,区域编译可以显著减少编译时间,而不会放弃任何加速。
+
+对于图像生成,结合量化和[模型卸载](./memory#model-offloading)通常可以在质量、速度和内存之间提供最佳权衡。组卸载对于图像生成效果不佳,因为如果计算内核更快完成,通常不可能*完全*重叠数据传输。这会导致 CPU 和 GPU 之间的一些通信开销。
+
+对于视频生成,结合量化和[组卸载](./memory#group-offloading)往往更好,因为视频模型更受计算限制。
+
+下表提供了优化策略组合及其对 Flux 延迟和内存使用的影响的比较。
+
+| 组合 | 延迟 (s) | 内存使用 (GB) |
+|---|---|---|
+| 量化 | 32.602 | 14.9453 |
+| 量化, torch.compile | 25.847 | 14.9448 |
+| 量化, torch.compile, 模型 CPU 卸载 | 32.312 | 12.2369 |
+这些结果是在 Flux 上使用 RTX 4090 进行基准测试的。transformer 和 text_encoder 组件已量化。如果您有兴趣评估自己的模型,请参考[基准测试脚本](https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d)。
+
+本指南将向您展示如何使用 [bitsandbytes](../quantization/bitsandbytes#torchcompile) 编译和卸载量化模型。确保您正在使用 [PyTorch nightly](https://pytorch.org/get-started/locally/) 和最新版本的 bitsandbytes。
+
+```bash
+pip install -U bitsandbytes
+```
+
+## 量化和 torch.compile
+
+首先通过[量化](../quantization/overview)模型来减少存储所需的内存,并[编译](./fp16#torchcompile)它以加速推理。
+
+配置 [Dynamo](https://docs.pytorch.org/docs/stable/torch.compiler_dynamo_overview.html) `capture_dynamic_output_shape_ops = True` 以在编译 bitsandbytes 模型时处理动态输出。
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+from diffusers.quantizers import PipelineQuantizationConfig
+
+torch._dynamo.config.capture_dynamic_output_shape_ops = True
+
+# 量化
+pipeline_quant_config = PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_4bit",
+ quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
+ components_to_quantize=["transformer", "text_encoder_2"],
+)
+pipeline = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ quantization_config=pipeline_quant_config,
+ torch_dtype=torch.bfloat16,
+).to("cuda")
+
+# 编译
+pipeline.transformer.to(memory_format=torch.channels_last)
+pipeline.transformer.compile(mode="max-autotune", fullgraph=True)
+pipeline("""
+ cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+ highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+).images[0]
+```
+
+## 量化、torch.compile 和卸载
+
+除了量化和 torch.compile,如果您需要进一步减少内存使用,可以尝试卸载。卸载根据需要将各种层或模型组件从 CPU 移动到 GPU 进行计算。
+
+在卸载期间配置 [Dynamo](https://docs.pytorch.org/docs/stable/torch.compiler_dynamo_overview.html) `cache_size_limit` 以避免过多的重新编译,并设置 `capture_dynamic_output_shape_ops = True` 以在编译 bitsandbytes 模型时处理动态输出。
+
+
+
+
+[模型 CPU 卸载](./memory#model-offloading) 将单个管道组件(如 transformer 模型)在需要计算时移动到 GPU。否则,它会被卸载到 CPU。
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+from diffusers.quantizers import PipelineQuantizationConfig
+
+torch._dynamo.config.cache_size_limit = 1000
+torch._dynamo.config.capture_dynamic_output_shape_ops = True
+
+# 量化
+pipeline_quant_config = PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_4bit",
+ quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
+ components_to_quantize=["transformer", "text_encoder_2"],
+)
+pipeline = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ quantization_config=pipeline_quant_config,
+ torch_dtype=torch.bfloat16,
+).to("cuda")
+
+# 模型 CPU 卸载
+pipeline.enable_model_cpu_offload()
+
+# 编译
+pipeline.transformer.compile()
+pipeline(
+ "cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain"
+).images[0]
+```
+
+
+
+
+[组卸载](./memory#group-offloading) 将单个管道组件(如变换器模型)的内部层移动到 GPU 进行计算,并在不需要时将其卸载。同时,它使用 [CUDA 流](./memory#cuda-stream) 功能来预取下一层以执行。
+
+通过重叠计算和数据传输,它比模型 CPU 卸载更快,同时还能节省内存。
+
+```py
+# pip install ftfy
+import torch
+from diffusers import AutoModel, DiffusionPipeline
+from diffusers.hooks import apply_group_offloading
+from diffusers.utils import export_to_video
+from diffusers.quantizers import PipelineQuantizationConfig
+from transformers import UMT5EncoderModel
+
+torch._dynamo.config.cache_size_limit = 1000
+torch._dynamo.config.capture_dynamic_output_shape_ops = True
+
+# 量化
+pipeline_quant_config = PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_4bit",
+ quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
+ components_to_quantize=["transformer", "text_encoder"],
+)
+
+text_encoder = UMT5EncoderModel.from_pretrained(
+ "Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="text_encoder", torch_dtype=torch.bfloat16
+)
+pipeline = DiffusionPipeline.from_pretrained(
+ "Wan-AI/Wan2.1-T2V-14B-Diffusers",
+ quantization_config=pipeline_quant_config,
+ torch_dtype=torch.bfloat16,
+).to("cuda")
+
+# 组卸载
+onload_device = torch.device("cuda")
+offload_device = torch.device("cpu")
+
+pipeline.transformer.enable_group_offload(
+ onload_device=onload_device,
+ offload_device=offload_device,
+ offload_type="leaf_level",
+ use_stream=True,
+ non_blocking=True
+)
+pipeline.vae.enable_group_offload(
+ onload_device=onload_device,
+ offload_device=offload_device,
+ offload_type="leaf_level",
+ use_stream=True,
+ non_blocking=True
+)
+apply_group_offloading(
+ pipeline.text_encoder,
+ onload_device=onload_device,
+ offload_type="leaf_level",
+ use_stream=True,
+ non_blocking=True
+)
+
+# 编译
+pipeline.transformer.compile()
+
+prompt = """
+The camera rushes from far to near in a low-angle shot,
+revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
+for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
+Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
+shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
+"""
+negative_prompt = """
+Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
+low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
+misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
+"""
+
+output = pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_frames=81,
+ guidance_scale=5.0,
+).frames[0]
+export_to_video(output, "output.mp4", fps=16)
+```
+
+
+
\ No newline at end of file
diff --git a/docs/source/zh/optimization/tgate.md b/docs/source/zh/optimization/tgate.md
new file mode 100644
index 000000000000..f15b9bde8413
--- /dev/null
+++ b/docs/source/zh/optimization/tgate.md
@@ -0,0 +1,182 @@
+# T-GATE
+
+[T-GATE](https://github.com/HaozheLiu-ST/T-GATE/tree/main) 通过跳过交叉注意力计算一旦收敛,加速了 [Stable Diffusion](../api/pipelines/stable_diffusion/overview)、[PixArt](../api/pipelines/pixart) 和 [Latency Consistency Model](../api/pipelines/latent_consistency_models.md) 管道的推理。此方法不需要任何额外训练,可以将推理速度提高 10-50%。T-GATE 还与 [DeepCache](./deepcache) 等其他优化方法兼容。
+
+开始之前,请确保安装 T-GATE。
+
+```bash
+pip install tgate
+pip install -U torch diffusers transformers accelerate DeepCache
+```
+
+要使用 T-GATE 与管道,您需要使用其对应的加载器。
+
+| 管道 | T-GATE 加载器 |
+|---|---|
+| PixArt | TgatePixArtLoader |
+| Stable Diffusion XL | TgateSDXLLoader |
+| Stable Diffusion XL + DeepCache | TgateSDXLDeepCacheLoader |
+| Stable Diffusion | TgateSDLoader |
+| Stable Diffusion + DeepCache | TgateSDDeepCacheLoader |
+
+接下来,创建一个 `TgateLoader`,包含管道、门限步骤(停止计算交叉注意力的时间步)和推理步骤数。然后在管道上调用 `tgate` 方法,提供提示、门限步骤和推理步骤数。
+
+让我们看看如何为几个不同的管道启用此功能。
+
+
+
+
+使用 T-GATE 加速 `PixArtAlphaPipeline`:
+
+```py
+import torch
+from diffusers import PixArtAlphaPipeline
+from tgate import TgatePixArtLoader
+
+pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
+
+gate_step = 8
+inference_step = 25
+pipe = TgatePixArtLoader(
+ pipe,
+ gate_step=gate_step,
+ num_inference_steps=inference_step,
+).to("cuda")
+
+image = pipe.tgate(
+ "An alpaca made of colorful building blocks, cyberpunk.",
+ gate_step=gate_step,
+ num_inference_steps=inference_step,
+).images[0]
+```
+
+
+
+使用 T-GATE 加速 `StableDiffusionXLPipeline`:
+
+```py
+import torch
+from diffusers import StableDiffusionXLPipeline
+from diffusers import DPMSolverMultistepScheduler
+from tgate import TgateSDXLLoader
+
+pipe = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16,
+ variant="fp16",
+ use_safetensors=True,
+)
+pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+
+gate_step = 10
+inference_step = 25
+pipe = TgateSDXLLoader(
+ pipe,
+ gate_step=gate_step,
+ num_inference_steps=inference_step,
+).to("cuda")
+
+image = pipe.tgate(
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
+ gate_step=gate_step,
+ num_inference_steps=inference_step
+).images[0]
+```
+
+
+
+使用 [DeepCache](https://github.co 加速 `StableDiffusionXLPipeline`
+m/horseee/DeepCache) 和 T-GATE:
+
+```py
+import torch
+from diffusers import StableDiffusionXLPipeline
+from diffusers import DPMSolverMultistepScheduler
+from tgate import TgateSDXLDeepCacheLoader
+
+pipe = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16,
+ variant="fp16",
+ use_safetensors=True,
+)
+pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+
+gate_step = 10
+inference_step = 25
+pipe = TgateSDXLDeepCacheLoader(
+ pipe,
+ cache_interval=3,
+ cache_branch_id=0,
+).to("cuda")
+
+image = pipe.tgate(
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
+ gate_step=gate_step,
+ num_inference_steps=inference_step
+).images[0]
+```
+
+
+
+使用 T-GATE 加速 `latent-consistency/lcm-sdxl`:
+
+```py
+import torch
+from diffusers import StableDiffusionXLPipeline
+from diffusers import UNet2DConditionModel, LCMScheduler
+from diffusers import DPMSolverMultistepScheduler
+from tgate import TgateSDXLLoader
+
+unet = UNet2DConditionModel.from_pretrained(
+ "latent-consistency/lcm-sdxl",
+ torch_dtype=torch.float16,
+ variant="fp16",
+)
+pipe = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ unet=unet,
+ torch_dtype=torch.float16,
+ variant="fp16",
+)
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+gate_step = 1
+inference_step = 4
+pipe = TgateSDXLLoader(
+ pipe,
+ gate_step=gate_step,
+ num_inference_steps=inference_step,
+ lcm=True
+).to("cuda")
+
+image = pipe.tgate(
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
+ gate_step=gate_step,
+ num_inference_steps=inference_step
+).images[0]
+```
+
+
+
+T-GATE 还支持 [`StableDiffusionPipeline`] 和 [PixArt-alpha/PixArt-LCM-XL-2-1024-MS](https://hf.co/PixArt-alpha/PixArt-LCM-XL-2-1024-MS)。
+
+## 基准测试
+| 模型 | MACs | 参数 | 延迟 | 零样本 10K-FID on MS-COCO |
+|-----------------------|----------|-----------|---------|---------------------------|
+| SD-1.5 | 16.938T | 859.520M | 7.032s | 23.927 |
+| SD-1.5 w/ T-GATE | 9.875T | 815.557M | 4.313s | 20.789 |
+| SD-2.1 | 38.041T | 865.785M | 16.121s | 22.609 |
+| SD-2.1 w/ T-GATE | 22.208T | 815.433 M | 9.878s | 19.940 |
+| SD-XL | 149.438T | 2.570B | 53.187s | 24.628 |
+| SD-XL w/ T-GATE | 84.438T | 2.024B | 27.932s | 22.738 |
+| Pixart-Alpha | 107.031T | 611.350M | 61.502s | 38.669 |
+| Pixart-Alpha w/ T-GATE | 65.318T | 462.585M | 37.867s | 35.825 |
+| DeepCache (SD-XL) | 57.888T | - | 19.931s | 23.755 |
+| DeepCache 配合 T-GATE | 43.868T | - | 14.666秒 | 23.999 |
+| LCM (SD-XL) | 11.955T | 2.570B | 3.805秒 | 25.044 |
+| LCM 配合 T-GATE | 11.171T | 2.024B | 3.533秒 | 25.028 |
+| LCM (Pixart-Alpha) | 8.563T | 611.350M | 4.733秒 | 36.086 |
+| LCM 配合 T-GATE | 7.623T | 462.585M | 4.543秒 | 37.048 |
+
+延迟测试基于 NVIDIA 1080TI,MACs 和 Params 使用 [calflops](https://github.com/MrYxJ/calculate-flops.pytorch) 计算,FID 使用 [PytorchFID](https://github.com/mseitzer/pytorch-fid) 计算。
\ No newline at end of file
diff --git a/docs/source/zh/optimization/tome.md b/docs/source/zh/optimization/tome.md
new file mode 100644
index 000000000000..732777c5586c
--- /dev/null
+++ b/docs/source/zh/optimization/tome.md
@@ -0,0 +1,90 @@
+
+
+# 令牌合并
+
+[令牌合并](https://huggingface.co/papers/2303.17604)(ToMe)在基于 Transformer 的网络的前向传递中逐步合并冗余令牌/补丁,这可以加速 [`StableDiffusionPipeline`] 的推理延迟。
+
+从 `pip` 安装 ToMe:
+
+```bash
+pip install tomesd
+```
+
+您可以使用 [`tomesd`](https://github.com/dbolya/tomesd) 库中的 [`apply_patch`](https://github.com/dbolya/tomesd?tab=readme-ov-file#usage) 函数:
+
+```diff
+ from diffusers import StableDiffusionPipeline
+ import torch
+ import tomesd
+
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True,
+ ).to("cuda")
++ tomesd.apply_patch(pipeline, ratio=0.5)
+
+ image = pipeline("a photo of an astronaut riding a horse on mars").images[0]
+```
+
+`apply_patch` 函数公开了多个[参数](https://github.com/dbolya/tomesd#usage),以帮助在管道推理速度和生成令牌的质量之间取得平衡。最重要的参数是 `ratio`,它控制在前向传递期间合并的令牌数量。
+
+如[论文](https://huggingface.co/papers/2303.17604)中所述,ToMe 可以在显著提升推理速度的同时,很大程度上保留生成图像的质量。通过增加 `ratio`,您可以进一步加速推理,但代价是图像质量有所下降。
+
+为了测试生成图像的质量,我们从 [Parti Prompts](https://parti.research.google/) 中采样了一些提示,并使用 [`StableDiffusionPipeline`] 进行了推理,设置如下:
+
+
+
+
+
+我们没有注意到生成样本的质量有任何显著下降,您可以在此 [WandB 报告](https://wandb.ai/sayakpaul/tomesd-results/runs/23j4bj3i?workspace=)中查看生成的样本。如果您有兴趣重现此实验,请使用此[脚本](https://gist.github.com/sayakpaul/8cac98d7f22399085a060992f411ecbd)。
+
+## 基准测试
+
+我们还在启用 [xFormers](https://huggingface.co/docs/diffusers/optimization/xformers) 的情况下,对 [`StableDiffusionPipeline`] 上 `tomesd` 的影响进行了基准测试,涵盖了多个图像分辨率。结果
+结果是从以下开发环境中的A100和V100 GPU获得的:
+
+```bash
+- `diffusers` 版本:0.15.1
+- Python 版本:3.8.16
+- PyTorch 版本(GPU?):1.13.1+cu116 (True)
+- Huggingface_hub 版本:0.13.2
+- Transformers 版本:4.27.2
+- Accelerate 版本:0.18.0
+- xFormers 版本:0.0.16
+- tomesd 版本:0.1.2
+```
+
+要重现此基准测试,请随意使用此[脚本](https://gist.github.com/sayakpaul/27aec6bca7eb7b0e0aa4112205850335)。结果以秒为单位报告,并且在适用的情况下,我们报告了使用ToMe和ToMe + xFormers时相对于原始管道的加速百分比。
+
+| **GPU** | **分辨率** | **批处理大小** | **原始** | **ToMe** | **ToMe + xFormers** |
+|----------|----------------|----------------|-------------|----------------|---------------------|
+| **A100** | 512 | 10 | 6.88 | 5.26 (+23.55%) | 4.69 (+31.83%) |
+| | 768 | 10 | OOM | 14.71 | 11 |
+| | | 8 | OOM | 11.56 | 8.84 |
+| | | 4 | OOM | 5.98 | 4.66 |
+| | | 2 | 4.99 | 3.24 (+35.07%) | 2.1 (+37.88%) |
+| | | 1 | 3.29 | 2.24 (+31.91%) | 2.03 (+38.3%) |
+| | 1024 | 10 | OOM | OOM | OOM |
+| | | 8 | OOM | OOM | OOM |
+| | | 4 | OOM | 12.51 | 9.09 |
+| | | 2 | OOM | 6.52 | 4.96 |
+| | | 1 | 6.4 | 3.61 (+43.59%) | 2.81 (+56.09%) |
+| **V100** | 512 | 10 | OOM | 10.03 | 9.29 |
+| | | 8 | OOM | 8.05 | 7.47 |
+| | | 4 | 5.7 | 4.3 (+24.56%) | 3.98 (+30.18%) |
+| | | 2 | 3.14 | 2.43 (+22.61%) | 2.27 (+27.71%) |
+| | | 1 | 1.88 | 1.57 (+16.49%) | 1.57 (+16.49%) |
+| | 768 | 10 | OOM | OOM | 23.67 |
+| | | 8 | OOM | OOM | 18.81 |
+| | | 4 | OOM | 11.81 | 9.7 |
+| | | 2 | OOM | 6.27 | 5.2 |
+| | | 1 | 5.43 | 3.38 (+37.75%) | 2.82 (+48.07%) |
+| | 1024 | 10 | OOM |
+如上表所示,`tomesd` 带来的加速效果在更大的图像分辨率下变得更加明显。有趣的是,使用 `tomesd` 可以在更高分辨率如 1024x1024 上运行管道。您可能还可以通过 [`torch.compile`](fp16#torchcompile) 进一步加速推理。
\ No newline at end of file
diff --git a/docs/source/zh/optimization/xdit.md b/docs/source/zh/optimization/xdit.md
new file mode 100644
index 000000000000..3308536d06c1
--- /dev/null
+++ b/docs/source/zh/optimization/xdit.md
@@ -0,0 +1,119 @@
+# xDiT
+
+[xDiT](https://github.com/xdit-project/xDiT) 是一个推理引擎,专为大规模并行部署扩散变换器(DiTs)而设计。xDiT 提供了一套用于扩散模型的高效并行方法,以及 GPU 内核加速。
+
+xDiT 支持四种并行方法,包括[统一序列并行](https://huggingface.co/papers/2405.07719)、[PipeFusion](https://huggingface.co/papers/2405.14430)、CFG 并行和数据并行。xDiT 中的这四种并行方法可以以混合方式配置,优化通信模式以最适合底层网络硬件。
+
+与并行化正交的优化侧重于加速单个 GPU 的性能。除了利用知名的注意力优化库外,我们还利用编译加速技术,如 torch.compile 和 onediff。
+
+xDiT 的概述如下所示。
+
+
+
+
+您可以使用以下命令安装 xDiT:
+
+```bash
+pip install xfuser
+```
+
+以下是一个使用 xDiT 加速 Diffusers 模型推理的示例。
+
+```diff
+ import torch
+ from diffusers import StableDiffusion3Pipeline
+
+ from xfuser import xFuserArgs, xDiTParallel
+ from xfuser.config import FlexibleArgumentParser
+ from xfuser.core.distributed import get_world_group
+
+ def main():
++ parser = FlexibleArgumentParser(description="xFuser Arguments")
++ args = xFuserArgs.add_cli_args(parser).parse_args()
++ engine_args = xFuserArgs.from_cli_args(args)
++ engine_config, input_config = engine_args.create_config()
+
+ local_rank = get_world_group().local_rank
+ pipe = StableDiffusion3Pipeline.from_pretrained(
+ pretrained_model_name_or_path=engine_config.model_config.model,
+ torch_dtype=torch.float16,
+ ).to(f"cuda:{local_rank}")
+
+# 在这里对管道进行任何操作
+
++ pipe = xDiTParallel(pipe, engine_config, input_config)
+
+ pipe(
+ height=input_config.height,
+ width=input_config.height,
+ prompt=input_config.prompt,
+ num_inference_steps=input_config.num_inference_steps,
+ output_type=input_config.output_type,
+ generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
+ )
+
++ if input_config.output_type == "pil":
++ pipe.save("results", "stable_diffusion_3")
+
+if __name__ == "__main__":
+ main()
+```
+
+如您所见,我们只需要使用 xDiT 中的 xFuserArgs 来获取配置参数,并将这些参数与来自 Diffusers 库的管道对象一起传递给 xDiTParallel,即可完成对 Diffusers 中特定管道的并行化。
+
+xDiT 运行时参数可以在命令行中使用 `-h` 查看,您可以参考此[使用](https://github.com/xdit-project/xDiT?tab=readme-ov-file#2-usage)示例以获取更多详细信息。
+ils。
+
+xDiT 需要使用 torchrun 启动,以支持其多节点、多 GPU 并行能力。例如,以下命令可用于 8-GPU 并行推理:
+
+```bash
+torchrun --nproc_per_node=8 ./inference.py --model models/FLUX.1-dev --data_parallel_degree 2 --ulysses_degree 2 --ring_degree 2 --prompt "A snowy mountain" "A small dog" --num_inference_steps 50
+```
+
+## 支持的模型
+
+在 xDiT 中支持 Diffusers 模型的一个子集,例如 Flux.1、Stable Diffusion 3 等。最新支持的模型可以在[这里](https://github.com/xdit-project/xDiT?tab=readme-ov-file#-supported-dits)找到。
+
+## 基准测试
+我们在不同机器上测试了各种模型,以下是一些基准数据。
+
+### Flux.1-schnell
+
+
+
+
+
+
+
+
+### Stable Diffusion 3
+
+
+
+
+
+
+
+
+### HunyuanDiT
+
+
+
+
+
+
+
+
+
+
+
+
+更详细的性能指标可以在我们的 [GitHub 页面](https://github.com/xdit-project/xDiT?tab=readme-ov-file#perf) 上找到。
+
+## 参考文献
+
+[xDiT-project](https://github.com/xdit-project/xDiT)
+
+[USP: A Unified Sequence Parallelism Approach for Long Context Generative AI](https://huggingface.co/papers/2405.07719)
+
+[PipeFusion: Displaced Patch Pipeline Parallelism for Inference of Diffusion Transformer Models](https://huggingface.co/papers/2405.14430)
\ No newline at end of file
diff --git a/docs/source/zh/optimization/xformers.md b/docs/source/zh/optimization/xformers.md
new file mode 100644
index 000000000000..2a3a3d8341e0
--- /dev/null
+++ b/docs/source/zh/optimization/xformers.md
@@ -0,0 +1,26 @@
+
+
+# xFormers
+
+我们推荐在推理和训练过程中使用[xFormers](https://github.com/facebookresearch/xformers)。在我们的测试中,其对注意力模块的优化能同时提升运行速度并降低内存消耗。
+
+通过`pip`安装xFormers:
+
+```bash
+pip install xformers
+```
+
+> [!TIP]
+> xFormers的`pip`安装包需要最新版本的PyTorch。如需使用旧版PyTorch,建议[从源码安装xFormers](https://github.com/facebookresearch/xformers#installing-xformers)。
+
+安装完成后,您可调用`enable_xformers_memory_efficient_attention()`来实现更快的推理速度和更低的内存占用,具体用法参见[此章节](memory#memory-efficient-attention)。
+
+> [!WARNING]
+> 根据[此问题](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212)反馈,xFormers `v0.0.16`版本在某些GPU上无法用于训练(微调或DreamBooth)。如遇此问题,请按照该issue评论区指引安装开发版本。
\ No newline at end of file
diff --git a/docs/source/zh/quicktour.md b/docs/source/zh/quicktour.md
index 1cf90e787668..2b8803384f25 100644
--- a/docs/source/zh/quicktour.md
+++ b/docs/source/zh/quicktour.md
@@ -1,4 +1,4 @@
-
-
-# 有效且高效的扩散
-
-[[open-in-colab]]
-
-让 [`DiffusionPipeline`] 生成特定风格或包含你所想要的内容的图像可能会有些棘手。 通常情况下,你需要多次运行 [`DiffusionPipeline`] 才能得到满意的图像。但是从无到有生成图像是一个计算密集的过程,特别是如果你要一遍又一遍地进行推理运算。
-
-这就是为什么从pipeline中获得最高的 *computational* (speed) 和 *memory* (GPU RAM) 非常重要 ,以减少推理周期之间的时间,从而使迭代速度更快。
-
-
-本教程将指导您如何通过 [`DiffusionPipeline`] 更快、更好地生成图像。
-
-
-首先,加载 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) 模型:
-
-```python
-from diffusers import DiffusionPipeline
-
-model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
-pipeline = DiffusionPipeline.from_pretrained(model_id, use_safetensors=True)
-```
-
-本教程将使用的提示词是 [`portrait photo of a old warrior chief`] ,但是你可以随心所欲的想象和构造自己的提示词:
-
-```python
-prompt = "portrait photo of a old warrior chief"
-```
-
-## 速度
-
-
-
-💡 如果你没有 GPU, 你可以从像 [Colab](https://colab.research.google.com/) 这样的 GPU 提供商获取免费的 GPU !
-
-
-
-加速推理的最简单方法之一是将 pipeline 放在 GPU 上 ,就像使用任何 PyTorch 模块一样:
-
-```python
-pipeline = pipeline.to("cuda")
-```
-
-为了确保您可以使用相同的图像并对其进行改进,使用 [`Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) 方法,然后设置一个随机数种子 以确保其 [复现性](./using-diffusers/reusing_seeds):
-
-```python
-import torch
-
-generator = torch.Generator("cuda").manual_seed(0)
-```
-
-现在,你可以生成一个图像:
-
-```python
-image = pipeline(prompt, generator=generator).images[0]
-image
-```
-
-
-
-
-
-在 T4 GPU 上,这个过程大概要30秒(如果你的 GPU 比 T4 好,可能会更快)。在默认情况下,[`DiffusionPipeline`] 使用完整的 `float32` 精度进行 50 步推理。你可以通过降低精度(如 `float16` )或者减少推理步数来加速整个过程
-
-
-让我们把模型的精度降低至 `float16` ,然后生成一张图像:
-
-```python
-import torch
-
-pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, use_safetensors=True)
-pipeline = pipeline.to("cuda")
-generator = torch.Generator("cuda").manual_seed(0)
-image = pipeline(prompt, generator=generator).images[0]
-image
-```
-
-
-
-
-
-这一次,生成图像只花了约 11 秒,比之前快了近 3 倍!
-
-
-
-💡 我们强烈建议把 pipeline 精度降低至 `float16` , 到目前为止, 我们很少看到输出质量有任何下降。
-
-
-
-另一个选择是减少推理步数。 你可以选择一个更高效的调度器 (*scheduler*) 可以减少推理步数同时保证输出质量。您可以在 [DiffusionPipeline] 中通过调用compatibles方法找到与当前模型兼容的调度器 (*scheduler*)。
-
-```python
-pipeline.scheduler.compatibles
-[
- diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
- diffusers.schedulers.scheduling_unipc_multistep.UniPCMultistepScheduler,
- diffusers.schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteScheduler,
- diffusers.schedulers.scheduling_deis_multistep.DEISMultistepScheduler,
- diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
- diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
- diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
- diffusers.schedulers.scheduling_dpmsolver_singlestep.DPMSolverSinglestepScheduler,
- diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteScheduler,
- diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler,
- diffusers.schedulers.scheduling_pndm.PNDMScheduler,
- diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler,
- diffusers.schedulers.scheduling_ddim.DDIMScheduler,
-]
-```
-
-Stable Diffusion 模型默认使用的是 [`PNDMScheduler`] ,通常要大概50步推理, 但是像 [`DPMSolverMultistepScheduler`] 这样更高效的调度器只要大概 20 或 25 步推理. 使用 [`ConfigMixin.from_config`] 方法加载新的调度器:
-
-```python
-from diffusers import DPMSolverMultistepScheduler
-
-pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
-```
-
-现在将 `num_inference_steps` 设置为 20:
-
-```python
-generator = torch.Generator("cuda").manual_seed(0)
-image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
-image
-```
-
-
-
-
-
-太棒了!你成功把推理时间缩短到 4 秒!⚡️
-
-## 内存
-
-改善 pipeline 性能的另一个关键是减少内存的使用量,这间接意味着速度更快,因为你经常试图最大化每秒生成的图像数量。要想知道你一次可以生成多少张图片,最简单的方法是尝试不同的batch size,直到出现`OutOfMemoryError` (OOM)。
-
-创建一个函数,为每一批要生成的图像分配提示词和 `Generators` 。请务必为每个`Generator` 分配一个种子,以便于复现良好的结果。
-
-
-```python
-def get_inputs(batch_size=1):
- generator = [torch.Generator("cuda").manual_seed(i) for i in range(batch_size)]
- prompts = batch_size * [prompt]
- num_inference_steps = 20
-
- return {"prompt": prompts, "generator": generator, "num_inference_steps": num_inference_steps}
-```
-
-设置 `batch_size=4` ,然后看一看我们消耗了多少内存:
-
-```python
-from diffusers.utils import make_image_grid
-
-images = pipeline(**get_inputs(batch_size=4)).images
-make_image_grid(images, 2, 2)
-```
-
-除非你有一个更大内存的GPU, 否则上述代码会返回 `OOM` 错误! 大部分内存被 cross-attention 层使用。按顺序运行可以节省大量内存,而不是在批处理中进行。你可以为 pipeline 配置 [`~DiffusionPipeline.enable_attention_slicing`] 函数:
-
-```python
-pipeline.enable_attention_slicing()
-```
-
-现在尝试把 `batch_size` 增加到 8!
-
-```python
-images = pipeline(**get_inputs(batch_size=8)).images
-make_image_grid(images, rows=2, cols=4)
-```
-
-
-
-
-
-以前你不能一批生成 4 张图片,而现在你可以在一张图片里面生成八张图片而只需要大概3.5秒!这可能是 T4 GPU 在不牺牲质量的情况运行速度最快的一种方法。
-
-## 质量
-
-在最后两节中, 你要学习如何通过 `fp16` 来优化 pipeline 的速度, 通过使用性能更高的调度器来减少推理步数, 使用注意力切片(*enabling attention slicing*)方法来节省内存。现在,你将关注的是如何提高图像的质量。
-
-### 更好的 checkpoints
-
-有个显而易见的方法是使用更好的 checkpoints。 Stable Diffusion 模型是一个很好的起点, 自正式发布以来,还发布了几个改进版本。然而, 使用更新的版本并不意味着你会得到更好的结果。你仍然需要尝试不同的 checkpoints ,并做一些研究 (例如使用 [negative prompts](https://minimaxir.com/2022/11/stable-diffusion-negative-prompt/)) 来获得更好的结果。
-
-随着该领域的发展, 有越来越多经过微调的高质量的 checkpoints 用来生成不一样的风格. 在 [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) 和 [Diffusers Gallery](https://huggingface.co/spaces/huggingface-projects/diffusers-gallery) 寻找你感兴趣的一种!
-
-### 更好的 pipeline 组件
-
-也可以尝试用新版本替换当前 pipeline 组件。让我们加载最新的 [autodecoder](https://huggingface.co/stabilityai/stable-diffusion-2-1/tree/main/vae) 从 Stability AI 加载到 pipeline, 并生成一些图像:
-
-```python
-from diffusers import AutoencoderKL
-
-vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda")
-pipeline.vae = vae
-images = pipeline(**get_inputs(batch_size=8)).images
-make_image_grid(images, rows=2, cols=4)
-```
-
-
-
-
-
-### 更好的提示词工程
-
-用于生成图像的文本非常重要, 因此被称为 *提示词工程*。 在设计提示词工程应注意如下事项:
-
-- 我想生成的图像或类似图像如何存储在互联网上?
-- 我可以提供哪些额外的细节来引导模型朝着我想要的风格生成?
-
-考虑到这一点,让我们改进提示词,以包含颜色和更高质量的细节:
-
-```python
-prompt += ", tribal panther make up, blue on red, side profile, looking away, serious eyes"
-prompt += " 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta"
-```
-
-使用新的提示词生成一批图像:
-
-```python
-images = pipeline(**get_inputs(batch_size=8)).images
-make_image_grid(images, rows=2, cols=4)
-```
-
-
-
-
-
-非常的令人印象深刻! Let's tweak the second image - 把 `Generator` 的种子设置为 `1` - 添加一些关于年龄的主题文本:
-
-```python
-prompts = [
- "portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
- "portrait photo of a old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
- "portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
- "portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
-]
-
-generator = [torch.Generator("cuda").manual_seed(1) for _ in range(len(prompts))]
-images = pipeline(prompt=prompts, generator=generator, num_inference_steps=25).images
-make_image_grid(images, 2, 2)
-```
-
-
-
-
-
-## 最后
-
-在本教程中, 您学习了如何优化[`DiffusionPipeline`]以提高计算和内存效率,以及提高生成输出的质量. 如果你有兴趣让你的 pipeline 更快, 可以看一看以下资源:
-
-- 学习 [PyTorch 2.0](./optimization/torch2.0) 和 [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) 可以让推理速度提高 5 - 300% . 在 A100 GPU 上, 推理速度可以提高 50% !
-- 如果你没法用 PyTorch 2, 我们建议你安装 [xFormers](./optimization/xformers)。它的内存高效注意力机制(*memory-efficient attention mechanism*)与PyTorch 1.13.1配合使用,速度更快,内存消耗更少。
-- 其他的优化技术, 如:模型卸载(*model offloading*), 包含在 [这份指南](./optimization/fp16).
+
+
+# 有效且高效的扩散
+
+[[open-in-colab]]
+
+让 [`DiffusionPipeline`] 生成特定风格或包含你所想要的内容的图像可能会有些棘手。 通常情况下,你需要多次运行 [`DiffusionPipeline`] 才能得到满意的图像。但是从无到有生成图像是一个计算密集的过程,特别是如果你要一遍又一遍地进行推理运算。
+
+这就是为什么从pipeline中获得最高的 *computational* (speed) 和 *memory* (GPU RAM) 非常重要 ,以减少推理周期之间的时间,从而使迭代速度更快。
+
+
+本教程将指导您如何通过 [`DiffusionPipeline`] 更快、更好地生成图像。
+
+
+首先,加载 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) 模型:
+
+```python
+from diffusers import DiffusionPipeline
+
+model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+pipeline = DiffusionPipeline.from_pretrained(model_id, use_safetensors=True)
+```
+
+本教程将使用的提示词是 [`portrait photo of a old warrior chief`] ,但是你可以随心所欲的想象和构造自己的提示词:
+
+```python
+prompt = "portrait photo of a old warrior chief"
+```
+
+## 速度
+
+> [!TIP]
+> 💡 如果你没有 GPU, 你可以从像 [Colab](https://colab.research.google.com/) 这样的 GPU 提供商获取免费的 GPU !
+
+加速推理的最简单方法之一是将 pipeline 放在 GPU 上 ,就像使用任何 PyTorch 模块一样:
+
+```python
+pipeline = pipeline.to("cuda")
+```
+
+为了确保您可以使用相同的图像并对其进行改进,使用 [`Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) 方法,然后设置一个随机数种子 以确保其 [复现性](./using-diffusers/reusing_seeds):
+
+```python
+import torch
+
+generator = torch.Generator("cuda").manual_seed(0)
+```
+
+现在,你可以生成一个图像:
+
+```python
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+
+
+
+
+在 T4 GPU 上,这个过程大概要30秒(如果你的 GPU 比 T4 好,可能会更快)。在默认情况下,[`DiffusionPipeline`] 使用完整的 `float32` 精度进行 50 步推理。你可以通过降低精度(如 `float16` )或者减少推理步数来加速整个过程
+
+
+让我们把模型的精度降低至 `float16` ,然后生成一张图像:
+
+```python
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, use_safetensors=True)
+pipeline = pipeline.to("cuda")
+generator = torch.Generator("cuda").manual_seed(0)
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+
+
+
+
+这一次,生成图像只花了约 11 秒,比之前快了近 3 倍!
+
+> [!TIP]
+> 💡 我们强烈建议把 pipeline 精度降低至 `float16` , 到目前为止, 我们很少看到输出质量有任何下降。
+
+另一个选择是减少推理步数。 你可以选择一个更高效的调度器 (*scheduler*) 可以减少推理步数同时保证输出质量。您可以在 [DiffusionPipeline] 中通过调用compatibles方法找到与当前模型兼容的调度器 (*scheduler*)。
+
+```python
+pipeline.scheduler.compatibles
+[
+ diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
+ diffusers.schedulers.scheduling_unipc_multistep.UniPCMultistepScheduler,
+ diffusers.schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteScheduler,
+ diffusers.schedulers.scheduling_deis_multistep.DEISMultistepScheduler,
+ diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
+ diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
+ diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
+ diffusers.schedulers.scheduling_dpmsolver_singlestep.DPMSolverSinglestepScheduler,
+ diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteScheduler,
+ diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler,
+ diffusers.schedulers.scheduling_pndm.PNDMScheduler,
+ diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler,
+ diffusers.schedulers.scheduling_ddim.DDIMScheduler,
+]
+```
+
+Stable Diffusion 模型默认使用的是 [`PNDMScheduler`] ,通常要大概50步推理, 但是像 [`DPMSolverMultistepScheduler`] 这样更高效的调度器只要大概 20 或 25 步推理. 使用 [`ConfigMixin.from_config`] 方法加载新的调度器:
+
+```python
+from diffusers import DPMSolverMultistepScheduler
+
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+```
+
+现在将 `num_inference_steps` 设置为 20:
+
+```python
+generator = torch.Generator("cuda").manual_seed(0)
+image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
+image
+```
+
+
+
+
+
+太棒了!你成功把推理时间缩短到 4 秒!⚡️
+
+## 内存
+
+改善 pipeline 性能的另一个关键是减少内存的使用量,这间接意味着速度更快,因为你经常试图最大化每秒生成的图像数量。要想知道你一次可以生成多少张图片,最简单的方法是尝试不同的batch size,直到出现`OutOfMemoryError` (OOM)。
+
+创建一个函数,为每一批要生成的图像分配提示词和 `Generators` 。请务必为每个`Generator` 分配一个种子,以便于复现良好的结果。
+
+
+```python
+def get_inputs(batch_size=1):
+ generator = [torch.Generator("cuda").manual_seed(i) for i in range(batch_size)]
+ prompts = batch_size * [prompt]
+ num_inference_steps = 20
+
+ return {"prompt": prompts, "generator": generator, "num_inference_steps": num_inference_steps}
+```
+
+设置 `batch_size=4` ,然后看一看我们消耗了多少内存:
+
+```python
+from diffusers.utils import make_image_grid
+
+images = pipeline(**get_inputs(batch_size=4)).images
+make_image_grid(images, 2, 2)
+```
+
+除非你有一个更大内存的GPU, 否则上述代码会返回 `OOM` 错误! 大部分内存被 cross-attention 层使用。按顺序运行可以节省大量内存,而不是在批处理中进行。你可以为 pipeline 配置 [`~DiffusionPipeline.enable_attention_slicing`] 函数:
+
+```python
+pipeline.enable_attention_slicing()
+```
+
+现在尝试把 `batch_size` 增加到 8!
+
+```python
+images = pipeline(**get_inputs(batch_size=8)).images
+make_image_grid(images, rows=2, cols=4)
+```
+
+
+
+
+
+以前你不能一批生成 4 张图片,而现在你可以在一张图片里面生成八张图片而只需要大概3.5秒!这可能是 T4 GPU 在不牺牲质量的情况运行速度最快的一种方法。
+
+## 质量
+
+在最后两节中, 你要学习如何通过 `fp16` 来优化 pipeline 的速度, 通过使用性能更高的调度器来减少推理步数, 使用注意力切片(*enabling attention slicing*)方法来节省内存。现在,你将关注的是如何提高图像的质量。
+
+### 更好的 checkpoints
+
+有个显而易见的方法是使用更好的 checkpoints。 Stable Diffusion 模型是一个很好的起点, 自正式发布以来,还发布了几个改进版本。然而, 使用更新的版本并不意味着你会得到更好的结果。你仍然需要尝试不同的 checkpoints ,并做一些研究 (例如使用 [negative prompts](https://minimaxir.com/2022/11/stable-diffusion-negative-prompt/)) 来获得更好的结果。
+
+随着该领域的发展, 有越来越多经过微调的高质量的 checkpoints 用来生成不一样的风格. 在 [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) 和 [Diffusers Gallery](https://huggingface.co/spaces/huggingface-projects/diffusers-gallery) 寻找你感兴趣的一种!
+
+### 更好的 pipeline 组件
+
+也可以尝试用新版本替换当前 pipeline 组件。让我们加载最新的 [autodecoder](https://huggingface.co/stabilityai/stable-diffusion-2-1/tree/main/vae) 从 Stability AI 加载到 pipeline, 并生成一些图像:
+
+```python
+from diffusers import AutoencoderKL
+
+vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda")
+pipeline.vae = vae
+images = pipeline(**get_inputs(batch_size=8)).images
+make_image_grid(images, rows=2, cols=4)
+```
+
+
+
+
+
+### 更好的提示词工程
+
+用于生成图像的文本非常重要, 因此被称为 *提示词工程*。 在设计提示词工程应注意如下事项:
+
+- 我想生成的图像或类似图像如何存储在互联网上?
+- 我可以提供哪些额外的细节来引导模型朝着我想要的风格生成?
+
+考虑到这一点,让我们改进提示词,以包含颜色和更高质量的细节:
+
+```python
+prompt += ", tribal panther make up, blue on red, side profile, looking away, serious eyes"
+prompt += " 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta"
+```
+
+使用新的提示词生成一批图像:
+
+```python
+images = pipeline(**get_inputs(batch_size=8)).images
+make_image_grid(images, rows=2, cols=4)
+```
+
+
+
+
+
+非常的令人印象深刻! Let's tweak the second image - 把 `Generator` 的种子设置为 `1` - 添加一些关于年龄的主题文本:
+
+```python
+prompts = [
+ "portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+]
+
+generator = [torch.Generator("cuda").manual_seed(1) for _ in range(len(prompts))]
+images = pipeline(prompt=prompts, generator=generator, num_inference_steps=25).images
+make_image_grid(images, 2, 2)
+```
+
+
+
+
+
+## 最后
+
+在本教程中, 您学习了如何优化[`DiffusionPipeline`]以提高计算和内存效率,以及提高生成输出的质量. 如果你有兴趣让你的 pipeline 更快, 可以看一看以下资源:
+
+- 学习 [PyTorch 2.0](./optimization/torch2.0) 和 [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) 可以让推理速度提高 5 - 300% . 在 A100 GPU 上, 推理速度可以提高 50% !
+- 如果你没法用 PyTorch 2, 我们建议你安装 [xFormers](./optimization/xformers)。它的内存高效注意力机制(*memory-efficient attention mechanism*)与PyTorch 1.13.1配合使用,速度更快,内存消耗更少。
+- 其他的优化技术, 如:模型卸载(*model offloading*), 包含在 [这份指南](./optimization/fp16).
diff --git a/docs/source/zh/training/adapt_a_model.md b/docs/source/zh/training/adapt_a_model.md
new file mode 100644
index 000000000000..7dbf46ec1290
--- /dev/null
+++ b/docs/source/zh/training/adapt_a_model.md
@@ -0,0 +1,47 @@
+# 将模型适配至新任务
+
+许多扩散系统共享相同的组件架构,这使得您能够将针对某一任务预训练的模型调整适配至完全不同的新任务。
+
+本指南将展示如何通过初始化并修改预训练 [`UNet2DConditionModel`] 的架构,将文生图预训练模型改造为图像修复(inpainting)模型。
+
+## 配置 UNet2DConditionModel 参数
+
+默认情况下,[`UNet2DConditionModel`] 的[输入样本](https://huggingface.co/docs/diffusers/v0.16.0/en/api/models#diffusers.UNet2DConditionModel.in_channels)接受4个通道。例如加载 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) 这样的文生图预训练模型,查看其 `in_channels` 参数值:
+
+```python
+from diffusers import StableDiffusionPipeline
+
+pipeline = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", use_safetensors=True)
+pipeline.unet.config["in_channels"]
+4
+```
+
+而图像修复任务需要输入样本具有9个通道。您可以在 [`stable-diffusion-v1-5/stable-diffusion-inpainting`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting) 这样的预训练修复模型中验证此参数:
+
+```python
+from diffusers import StableDiffusionPipeline
+
+pipeline = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-inpainting", use_safetensors=True)
+pipeline.unet.config["in_channels"]
+9
+```
+
+要将文生图模型改造为修复模型,您需要将 `in_channels` 参数从4调整为9。
+
+初始化一个加载了文生图预训练权重的 [`UNet2DConditionModel`],并将 `in_channels` 设为9。由于输入通道数变化导致张量形状改变,需要设置 `ignore_mismatched_sizes=True` 和 `low_cpu_mem_usage=False` 来避免尺寸不匹配错误。
+
+```python
+from diffusers import AutoModel
+
+model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+unet = AutoModel.from_pretrained(
+ model_id,
+ subfolder="unet",
+ in_channels=9,
+ low_cpu_mem_usage=False,
+ ignore_mismatched_sizes=True,
+ use_safetensors=True,
+)
+```
+
+此时文生图模型的其他组件权重仍保持预训练状态,但UNet的输入卷积层权重(`conv_in.weight`)会随机初始化。由于这一关键变化,必须对模型进行修复任务的微调,否则模型将仅会输出噪声。
diff --git a/docs/source/zh/training/controlnet.md b/docs/source/zh/training/controlnet.md
new file mode 100644
index 000000000000..84bc3263a842
--- /dev/null
+++ b/docs/source/zh/training/controlnet.md
@@ -0,0 +1,354 @@
+
+
+# ControlNet
+
+[ControlNet](https://hf.co/papers/2302.05543) 是一种基于预训练模型的适配器架构。它通过额外输入的条件图像(如边缘检测图、深度图、人体姿态图等),实现对生成图像的精细化控制。
+
+在显存有限的GPU上训练时,建议启用训练命令中的 `gradient_checkpointing`(梯度检查点)、`gradient_accumulation_steps`(梯度累积步数)和 `mixed_precision`(混合精度)参数。还可使用 [xFormers](../optimization/xformers) 的内存高效注意力机制进一步降低显存占用。虽然JAX/Flax训练支持在TPU和GPU上高效运行,但不支持梯度检查点和xFormers。若需通过Flax加速训练,建议使用显存大于30GB的GPU。
+
+本指南将解析 [train_controlnet.py](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py) 训练脚本,帮助您理解其逻辑并适配自定义需求。
+
+运行脚本前,请确保从源码安装库:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+然后进入包含训练脚本的示例目录,安装所需依赖:
+
+
+
+```bash
+cd examples/controlnet
+pip install -r requirements.txt
+```
+
+
+
+若可访问TPU设备,Flax训练脚本将运行得更快!以下是在 [Google Cloud TPU VM](https://cloud.google.com/tpu/docs/run-calculation-jax) 上的配置流程。创建单个TPU v4-8虚拟机并连接:
+
+```bash
+ZONE=us-central2-b
+TPU_TYPE=v4-8
+VM_NAME=hg_flax
+
+gcloud alpha compute tpus tpu-vm create $VM_NAME \
+ --zone $ZONE \
+ --accelerator-type $TPU_TYPE \
+ --version tpu-vm-v4-base
+
+gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE -- \
+```
+
+安装JAX 0.4.5:
+
+```bash
+pip install "jax[tpu]==0.4.5" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+```
+
+然后安装Flax脚本的依赖:
+
+```bash
+cd examples/controlnet
+pip install -r requirements_flax.txt
+```
+
+
+
+
+> [!TIP]
+> 🤗 Accelerate 是一个支持多GPU/TPU训练和混合精度的库,它能根据硬件环境自动配置训练方案。参阅 🤗 Accelerate [快速入门](https://huggingface.co/docs/accelerate/quicktour) 了解更多。
+
+初始化🤗 Accelerate环境:
+
+```bash
+accelerate config
+```
+
+若要创建默认配置(不进行交互式选择):
+
+```bash
+accelerate config default
+```
+
+若环境不支持交互式shell(如notebook),可使用:
+
+```py
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+最后,如需训练自定义数据集,请参阅 [创建训练数据集](create_dataset) 指南了解数据准备方法。
+
+> [!TIP]
+> 下文重点解析脚本中的关键模块,但不会覆盖所有实现细节。如需深入了解,建议直接阅读 [脚本源码](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py),如有疑问欢迎反馈。
+
+## 脚本参数
+
+训练脚本提供了丰富的可配置参数,所有参数及其说明详见 [`parse_args()`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/controlnet/train_controlnet.py#L231) 函数。虽然该函数已为每个参数提供默认值(如训练批大小、学习率等),但您可以通过命令行参数覆盖这些默认值。
+
+例如,使用fp16混合精度加速训练, 可使用`--mixed_precision`参数
+
+```bash
+accelerate launch train_controlnet.py \
+ --mixed_precision="fp16"
+```
+
+基础参数说明可参考 [文生图](text2image#script-parameters) 训练指南,此处重点介绍ControlNet相关参数:
+
+- `--max_train_samples`: 训练样本数量,减少该值可加快训练,但对超大数据集需配合 `--streaming` 参数使用
+- `--gradient_accumulation_steps`: 梯度累积步数,通过分步计算实现显存受限情况下的更大批次训练
+
+### Min-SNR加权策略
+
+[Min-SNR](https://huggingface.co/papers/2303.09556) 加权策略通过重新平衡损失函数加速模型收敛。虽然训练脚本支持预测 `epsilon`(噪声)或 `v_prediction`,但Min-SNR对两种预测类型均兼容。该策略仅适用于PyTorch版本,Flax训练脚本暂不支持。
+
+推荐值设为5.0:
+
+```bash
+accelerate launch train_controlnet.py \
+ --snr_gamma=5.0
+```
+
+## 训练脚本
+
+与参数说明类似,训练流程的通用解析可参考 [文生图](text2image#training-script) 指南。此处重点分析ControlNet特有的实现。
+
+脚本中的 [`make_train_dataset`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/controlnet/train_controlnet.py#L582) 函数负责数据预处理,除常规的文本标注分词和图像变换外,还包含条件图像的特效处理:
+
+> [!TIP]
+> 在TPU上流式加载数据集时,🤗 Datasets库可能成为性能瓶颈(因其未针对图像数据优化)。建议考虑 [WebDataset](https://webdataset.github.io/webdataset/)、[TorchData](https://github.com/pytorch/data) 或 [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds) 等高效数据格式。
+
+```py
+conditioning_image_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ ]
+)
+```
+
+在 [`main()`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/controlnet/train_controlnet.py#L713) 函数中,代码会加载分词器、文本编码器、调度器和模型。此处也是ControlNet模型的加载点(支持从现有权重加载或从UNet随机初始化):
+
+```py
+if args.controlnet_model_name_or_path:
+ logger.info("Loading existing controlnet weights")
+ controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
+else:
+ logger.info("Initializing controlnet weights from unet")
+ controlnet = ControlNetModel.from_unet(unet)
+```
+
+[优化器](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/controlnet/train_controlnet.py#L871) 专门针对ControlNet参数进行更新:
+
+```py
+params_to_optimize = controlnet.parameters()
+optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+)
+```
+
+在 [训练循环](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/controlnet/train_controlnet.py#L943) 中,条件文本嵌入和图像被输入到ControlNet的下采样和中层模块:
+
+```py
+encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
+
+down_block_res_samples, mid_block_res_sample = controlnet(
+ noisy_latents,
+ timesteps,
+ encoder_hidden_states=encoder_hidden_states,
+ controlnet_cond=controlnet_image,
+ return_dict=False,
+)
+```
+
+若想深入理解训练循环机制,可参阅 [理解管道、模型与调度器](../using-diffusers/write_own_pipeline) 教程,该教程详细解析了去噪过程的基本原理。
+
+## 启动训练
+
+现在可以启动训练脚本了!🚀
+
+本指南使用 [fusing/fill50k](https://huggingface.co/datasets/fusing/fill50k) 数据集,当然您也可以按照 [创建训练数据集](create_dataset) 指南准备自定义数据。
+
+设置环境变量 `MODEL_NAME` 为Hub模型ID或本地路径,`OUTPUT_DIR` 为模型保存路径。
+
+下载训练用的条件图像:
+
+```bash
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
+wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
+```
+
+根据GPU型号,可能需要启用特定优化。默认配置需要约38GB显存。若使用多GPU训练,请在 `accelerate launch` 命令中添加 `--multi_gpu` 参数。
+
+
+
+
+16GB显卡可使用bitsandbytes 8-bit优化器和梯度检查点:
+
+```py
+pip install bitsandbytes
+```
+
+训练命令添加以下参数:
+
+```bash
+accelerate launch train_controlnet.py \
+ --gradient_checkpointing \
+ --use_8bit_adam \
+```
+
+
+
+
+12GB显卡需组合使用bitsandbytes 8-bit优化器、梯度检查点、xFormers,并将梯度置为None而非0:
+
+```bash
+accelerate launch train_controlnet.py \
+ --use_8bit_adam \
+ --gradient_checkpointing \
+ --enable_xformers_memory_efficient_attention \
+ --set_grads_to_none \
+```
+
+
+
+
+8GB显卡需使用 [DeepSpeed](https://www.deepspeed.ai/) 将张量卸载到CPU或NVME:
+
+运行以下命令配置环境:
+
+```bash
+accelerate config
+```
+
+选择DeepSpeed stage 2,结合fp16混合精度和参数卸载到CPU的方案。注意这会增加约25GB内存占用。配置示例如下:
+
+```bash
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+ gradient_accumulation_steps: 4
+ offload_optimizer_device: cpu
+ offload_param_device: cpu
+ zero3_init_flag: false
+ zero_stage: 2
+distributed_type: DEEPSPEED
+```
+
+建议将优化器替换为DeepSpeed特化版 [`deepspeed.ops.adam.DeepSpeedCPUAdam`](https://deepspeed.readthedocs.io/en/latest/optimizers.html#adam-cpu),注意CUDA工具链版本需与PyTorch匹配。
+
+当前bitsandbytes与DeepSpeed存在兼容性问题。
+
+无需额外添加训练参数。
+
+
+
+
+
+
+
+```bash
+export MODEL_DIR="stable-diffusion-v1-5/stable-diffusion-v1-5"
+export OUTPUT_DIR="path/to/save/model"
+
+accelerate launch train_controlnet.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --learning_rate=1e-5 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --push_to_hub
+```
+
+
+
+
+Flax版本支持通过 `--profile_steps==5` 参数进行性能分析:
+
+```bash
+pip install tensorflow tensorboard-plugin-profile
+tensorboard --logdir runs/fill-circle-100steps-20230411_165612/
+```
+
+在 [http://localhost:6006/#profile](http://localhost:6006/#profile) 查看分析结果。
+
+> [!WARNING]
+> 若遇到插件版本冲突,建议重新安装TensorFlow和Tensorboard。注意性能分析插件仍处实验阶段,部分视图可能不完整。`trace_viewer` 会截断超过1M的事件记录,在编译步骤分析时可能导致设备轨迹丢失。
+
+```bash
+python3 train_controlnet_flax.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=fusing/fill50k \
+ --resolution=512 \
+ --learning_rate=1e-5 \
+ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
+ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
+ --validation_steps=1000 \
+ --train_batch_size=2 \
+ --revision="non-ema" \
+ --from_pt \
+ --report_to="wandb" \
+ --tracker_project_name=$HUB_MODEL_ID \
+ --num_train_epochs=11 \
+ --push_to_hub \
+ --hub_model_id=$HUB_MODEL_ID
+```
+
+
+
+
+训练完成后即可进行推理:
+
+```py
+from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+from diffusers.utils import load_image
+import torch
+
+controlnet = ControlNetModel.from_pretrained("path/to/controlnet", torch_dtype=torch.float16)
+pipeline = StableDiffusionControlNetPipeline.from_pretrained(
+ "path/to/base/model", controlnet=controlnet, torch_dtype=torch.float16
+).to("cuda")
+
+control_image = load_image("./conditioning_image_1.png")
+prompt = "pale golden rod circle with old lace background"
+
+generator = torch.manual_seed(0)
+image = pipeline(prompt, num_inference_steps=20, generator=generator, image=control_image).images[0]
+image.save("./output.png")
+```
+
+## Stable Diffusion XL
+
+Stable Diffusion XL (SDXL) 是新一代文生图模型,通过添加第二文本编码器支持生成更高分辨率图像。使用 [`train_controlnet_sdxl.py`](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet_sdxl.py) 脚本可为SDXL训练ControlNet适配器。
+
+SDXL训练脚本的详细解析请参阅 [SDXL训练](sdxl) 指南。
+
+## 后续步骤
+
+恭喜完成ControlNet训练!如需进一步了解模型应用,以下指南可能有所帮助:
+
+- 学习如何 [使用ControlNet](../using-diffusers/controlnet) 进行多样化任务的推理
diff --git a/docs/source/zh/training/distributed_inference.md b/docs/source/zh/training/distributed_inference.md
new file mode 100644
index 000000000000..60297371d6be
--- /dev/null
+++ b/docs/source/zh/training/distributed_inference.md
@@ -0,0 +1,236 @@
+
+
+# 分布式推理
+
+在分布式设置中,您可以使用 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) 或 [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html) 在多个 GPU 上运行推理,这对于并行生成多个提示非常有用。
+
+本指南将向您展示如何使用 🤗 Accelerate 和 PyTorch Distributed 进行分布式推理。
+
+## 🤗 Accelerate
+
+🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) 是一个旨在简化在分布式设置中训练或运行推理的库。它简化了设置分布式环境的过程,让您可以专注于您的 PyTorch 代码。
+
+首先,创建一个 Python 文件并初始化一个 [`accelerate.PartialState`] 来创建分布式环境;您的设置会自动检测,因此您无需明确定义 `rank` 或 `world_size`。将 [`DiffusionPipeline`] 移动到 `distributed_state.device` 以为每个进程分配一个 GPU。
+
+现在使用 [`~accelerate.PartialState.split_between_processes`] 实用程序作为上下文管理器,自动在进程数之间分发提示。
+
+```py
+import torch
+from accelerate import PartialState
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
+)
+distributed_state = PartialState()
+pipeline.to(distributed_state.device)
+
+with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
+ result = pipeline(prompt).images[0]
+ result.save(f"result_{distributed_state.process_index}.png")
+```
+
+使用 `--num_processes` 参数指定要使用的 GPU 数量,并调用 `accelerate launch` 来运行脚本:
+
+```bash
+accelerate launch run_distributed.py --num_processes=2
+```
+
+> [!TIP]
+> 参考这个最小示例 [脚本](https://gist.github.com/sayakpaul/cfaebd221820d7b43fae638b4dfa01ba) 以在多个 GPU 上运行推理。要了解更多信息,请查看 [使用 🤗 Accelerate 进行分布式推理](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) 指南。
+
+## PyTorch Distributed
+
+PyTorch 支持 [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html),它启用了数据
+并行性。
+
+首先,创建一个 Python 文件并导入 `torch.distributed` 和 `torch.multiprocessing` 来设置分布式进程组,并为每个 GPU 上的推理生成进程。您还应该初始化一个 [`DiffusionPipeline`]:
+
+```py
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+from diffusers import DiffusionPipeline
+
+sd = DiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
+)
+```
+
+您需要创建一个函数来运行推理;[`init_process_group`](https://pytorch.org/docs/stable/distributed.html?highlight=init_process_group#torch.distributed.init_process_group) 处理创建一个分布式环境,指定要使用的后端类型、当前进程的 `rank` 以及参与进程的数量 `world_size`。如果您在 2 个 GPU 上并行运行推理,那么 `world_size` 就是 2。
+
+将 [`DiffusionPipeline`] 移动到 `rank`,并使用 `get_rank` 为每个进程分配一个 GPU,其中每个进程处理不同的提示:
+
+```py
+def run_inference(rank, world_size):
+ dist.init_process_group("nccl", rank=rank, world_size=world_size)
+
+ sd.to(rank)
+
+ if torch.distributed.get_rank() == 0:
+ prompt = "a dog"
+ elif torch.distributed.get_rank() == 1:
+ prompt = "a cat"
+
+ image = sd(prompt).images[0]
+ image.save(f"./{'_'.join(prompt)}.png")
+```
+
+要运行分布式推理,调用 [`mp.spawn`](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn) 在 `world_size` 定义的 GPU 数量上运行 `run_inference` 函数:
+
+```py
+def main():
+ world_size = 2
+ mp.spawn(run_inference, args=(world_size,), nprocs=world_size, join=True)
+
+
+if __name__ == "__main__":
+ main()
+```
+
+完成推理脚本后,使用 `--nproc_per_node` 参数指定要使用的 GPU 数量,并调用 `torchrun` 来运行脚本:
+
+```bash
+torchrun run_distributed.py --nproc_per_node=2
+```
+
+> [!TIP]
+> 您可以在 [`DiffusionPipeline`] 中使用 `device_map` 将其模型级组件分布在多个设备上。请参考 [设备放置](../tutorials/inference_with_big_models#device-placement) 指南了解更多信息。
+
+## 模型分片
+
+现代扩散系统,如 [Flux](../api/pipelines/flux),非常大且包含多个模型。例如,[Flux.1-Dev](https://hf.co/black-forest-labs/FLUX.1-dev) 由两个文本编码器 - [T5-XXL](https://hf.co/google/t5-v1_1-xxl) 和 [CLIP-L](https://hf.co/openai/clip-vit-large-patch14) - 一个 [扩散变换器](../api/models/flux_transformer),以及一个 [VAE](../api/models/autoencoderkl) 组成。对于如此大的模型,在消费级 GPU 上运行推理可能具有挑战性。
+
+模型分片是一种技术,当模型无法容纳在单个 GPU 上时,将模型分布在多个 GPU 上。下面的示例假设有两个 16GB GPU 可用于推理。
+
+开始使用文本编码器计算文本嵌入。通过设置 `device_map="balanced"` 将文本编码器保持在两个GPU上。`balanced` 策略将模型均匀分布在所有可用GPU上。使用 `max_memory` 参数为每个GPU上的每个文本编码器分配最大内存量。
+
+> [!TIP]
+> **仅** 在此步骤加载文本编码器!扩散变换器和VAE在后续步骤中加载以节省内存。
+
+```py
+from diffusers import FluxPipeline
+import torch
+
+prompt = "a photo of a dog with cat-like look"
+
+pipeline = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ transformer=None,
+ vae=None,
+ device_map="balanced",
+ max_memory={0: "16GB", 1: "16GB"},
+ torch_dtype=torch.bfloat16
+)
+with torch.no_grad():
+ print("Encoding prompts.")
+ prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+ prompt=prompt, prompt_2=None, max_sequence_length=512
+ )
+```
+
+一旦文本嵌入计算完成,从GPU中移除它们以为扩散变换器腾出空间。
+
+```py
+import gc
+
+def flush():
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.reset_max_memory_allocated()
+ torch.cuda.reset_peak_memory_stats()
+
+del pipeline.text_encoder
+del pipeline.text_encoder_2
+del pipeline.tokenizer
+del pipeline.tokenizer_2
+del pipeline
+
+flush()
+```
+
+接下来加载扩散变换器,它有125亿参数。这次,设置 `device_map="auto"` 以自动将模型分布在两个16GB GPU上。`auto` 策略由 [Accelerate](https://hf.co/docs/accelerate/index) 支持,并作为 [大模型推理](https://hf.co/docs/accelerate/concept_guides/big_model_inference) 功能的一部分可用。它首先将模型分布在最快的设备(GPU)上,然后在需要时移动到较慢的设备如CPU和硬盘。将模型参数存储在较慢设备上的权衡是推理延迟较慢。
+
+```py
+from diffusers import AutoModel
+import torch
+
+transformer = AutoModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="transformer",
+ device_map="auto",
+ torch_dtype=torch.bfloat16
+)
+```
+
+> [!TIP]
+> 在任何时候,您可以尝试 `print(pipeline.hf_device_map)` 来查看各种模型如何在设备上分布。这对于跟踪模型的设备放置很有用。您也可以尝试 `print(transformer.hf_device_map)` 来查看变换器模型如何在设备上分片。
+
+将变换器模型添加到管道中以进行去噪,但将其他模型级组件如文本编码器和VAE设置为 `None`,因为您还不需要它们。
+
+```py
+pipeline = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ text_encoder=None,
+ text_encoder_2=None,
+ tokenizer=None,
+ tokenizer_2=None,
+ vae=None,
+ transformer=transformer,
+ torch_dtype=torch.bfloat16
+)
+
+print("Running denoising.")
+height, width = 768, 1360
+latents = pipeline(
+
+
+prompt_embeds=prompt_embeds,
+pooled_prompt_embeds=pooled_prompt_embeds,
+num_inference_steps=50,
+guidance_scale=3.5,
+height=height,
+width=width,
+output_type="latent",
+).images
+```
+
+从内存中移除管道和变换器,因为它们不再需要。
+
+```py
+del pipeline.transformer
+del pipeline
+
+flush()
+```
+
+最后,使用变分自编码器(VAE)将潜在表示解码为图像。VAE通常足够小,可以在单个GPU上加载。
+
+```py
+from diffusers import AutoencoderKL
+from diffusers.image_processor import VaeImageProcessor
+import torch
+
+vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=torch.bfloat16).to("cuda")
+vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
+image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
+
+with torch.no_grad():
+ print("运行解码中。")
+ latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
+ latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
+
+ image = vae.decode(latents, return_dict=False)[0]
+ image = image_processor.postprocess(image, output_type="pil")
+ image[0].save("split_transformer.png")
+```
+
+通过选择性加载和卸载在特定阶段所需的模型,并将最大模型分片到多个GPU上,可以在消费级GPU上运行大型模型的推理。
\ No newline at end of file
diff --git a/docs/source/zh/training/dreambooth.md b/docs/source/zh/training/dreambooth.md
new file mode 100644
index 000000000000..cae5e30be011
--- /dev/null
+++ b/docs/source/zh/training/dreambooth.md
@@ -0,0 +1,631 @@
+
+
+# DreamBooth
+
+[DreamBooth](https://huggingface.co/papers/2208.12242) 是一种训练技术,通过仅训练少数主题或风格的图像来更新整个扩散模型。它通过在提示中关联一个特殊词与示例图像来工作。
+
+如果您在 vRAM 有限的 GPU 上训练,应尝试在训练命令中启用 `gradient_checkpointing` 和 `mixed_precision` 参数。您还可以通过使用 [xFormers](../optimization/xformers) 的内存高效注意力来减少内存占用。JAX/Flax 训练也支持在 TPU 和 GPU 上进行高效训练,但不支持梯度检查点或 xFormers。如果您想使用 Flax 更快地训练,应拥有内存 >30GB 的 GPU。
+
+本指南将探索 [train_dreambooth.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py) 脚本,帮助您更熟悉它,以及如何根据您的用例进行适配。
+
+在运行脚本之前,请确保从源代码安装库:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+导航到包含训练脚本的示例文件夹,并安装脚本所需的依赖项:
+
+
+
+
+```bash
+cd examples/dreambooth
+pip install -r requirements.txt
+```
+
+
+
+
+```bash
+cd examples/dreambooth
+pip install -r requirements_flax.txt
+```
+
+
+
+
+> [!TIP]
+> 🤗 Accelerate 是一个库,用于帮助您在多个 GPU/TPU 上或使用混合精度进行训练。它会根据您的硬件和环境自动配置训练设置。查看 🤗 Accelerate [快速入门](https://huggingface.co/docs/accelerate/quicktour) 以了解更多信息。
+
+初始化 🤗 Accelerate 环境:
+
+```bash
+accelerate config
+```
+
+要设置默认的 🤗 Accelerate 环境而不选择任何配置:
+
+```bash
+accelerate config default
+```
+
+或者,如果您的环境不支持交互式 shell,例如笔记本,您可以使用:
+
+```py
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+最后,如果您想在自己的数据集上训练模型,请查看 [创建用于训练的数据集](create_dataset) 指南,了解如何创建与
+训练脚本。
+
+> [!TIP]
+> 以下部分重点介绍了训练脚本中对于理解如何修改它很重要的部分,但并未详细涵盖脚本的每个方面。如果您有兴趣了解更多,请随时阅读[脚本](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py),并告诉我们如果您有任何问题或疑虑。
+
+## 脚本参数
+
+> [!WARNING]
+> DreamBooth 对训练超参数非常敏感,容易过拟合。阅读 [使用 🧨 Diffusers 训练 Stable Diffusion 与 Dreambooth](https://huggingface.co/blog/dreambooth) 博客文章,了解针对不同主题的推荐设置,以帮助您选择合适的超参数。
+
+训练脚本提供了许多参数来自定义您的训练运行。所有参数及其描述都可以在 [`parse_args()`](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L228) 函数中找到。参数设置了默认值,这些默认值应该开箱即用效果不错,但如果您愿意,也可以在训练命令中设置自己的值。
+
+例如,要以 bf16 格式进行训练:
+
+```bash
+accelerate launch train_dreambooth.py \
+ --mixed_precision="bf16"
+```
+
+一些基本且重要的参数需要了解和指定:
+
+- `--pretrained_model_name_or_path`: Hub 上的模型名称或预训练模型的本地路径
+- `--instance_data_dir`: 包含训练数据集(示例图像)的文件夹路径
+- `--instance_prompt`: 包含示例图像特殊单词的文本提示
+- `--train_text_encoder`: 是否也训练文本编码器
+- `--output_dir`: 保存训练后模型的位置
+- `--push_to_hub`: 是否将训练后的模型推送到 Hub
+- `--checkpointing_steps`: 模型训练时保存检查点的频率;这在训练因某种原因中断时很有用,您可以通过在训练命令中添加 `--resume_from_checkpoint` 来从该检查点继续训练
+
+### Min-SNR 加权
+
+[Min-SNR](https://huggingface.co/papers/2303.09556) 加权策略可以通过重新平衡损失来帮助训练,以实现更快的收敛。训练脚本支持预测 `epsilon`(噪声)或 `v_prediction`,但 Min-SNR 与两种预测类型都兼容。此加权策略仅由 PyTorch 支持,在 Flax 训练脚本中不可用。
+
+添加 `--snr_gamma` 参数并将其设置为推荐值 5.0:
+
+```bash
+accelerate launch train_dreambooth.py \
+ --snr_gamma=5.0
+```
+
+### 先验保持损失
+
+先验保持损失是一种使用模型自身生成的样本来帮助它学习如何生成更多样化图像的方法。因为这些生成的样本图像属于您提供的图像相同的类别,它们帮助模型 r
+etain 它已经学到的关于类别的知识,以及它如何利用已经了解的类别信息来创建新的组合。
+
+- `--with_prior_preservation`: 是否使用先验保留损失
+- `--prior_loss_weight`: 控制先验保留损失对模型的影响程度
+- `--class_data_dir`: 包含生成的类别样本图像的文件夹路径
+- `--class_prompt`: 描述生成的样本图像类别的文本提示
+
+```bash
+accelerate launch train_dreambooth.py \
+ --with_prior_preservation \
+ --prior_loss_weight=1.0 \
+ --class_data_dir="path/to/class/images" \
+ --class_prompt="text prompt describing class"
+```
+
+### 训练文本编码器
+
+为了提高生成输出的质量,除了 UNet 之外,您还可以训练文本编码器。这需要额外的内存,并且您需要一个至少有 24GB 显存的 GPU。如果您拥有必要的硬件,那么训练文本编码器会产生更好的结果,尤其是在生成面部图像时。通过以下方式启用此选项:
+
+```bash
+accelerate launch train_dreambooth.py \
+ --train_text_encoder
+```
+
+## 训练脚本
+
+DreamBooth 附带了自己的数据集类:
+
+- [`DreamBoothDataset`](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L604): 预处理图像和类别图像,并对提示进行分词以用于训练
+- [`PromptDataset`](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L738): 生成提示嵌入以生成类别图像
+
+如果您启用了[先验保留损失](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L842),类别图像在此处生成:
+
+```py
+sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+sample_dataloader = accelerator.prepare(sample_dataloader)
+pipeline.to(accelerator.device)
+
+for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+):
+ images = pipeline(example["prompt"]).images
+```
+
+接下来是 [`main()`](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L799) 函数,它处理设置训练数据集和训练循环本身。脚本加载 [tokenizer](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L898)、[scheduler 和 models](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L912C1-L912C1):
+
+```py
+# Load the tokenizer
+if args.tokenizer_name:
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
+elif args.pretrained_model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+# 加载调度器和模型
+noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+text_encoder = text_encoder_cls.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+)
+
+if model_has_vae(args):
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+ )
+else:
+ vae = None
+
+unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+)
+```
+
+然后,是时候[创建训练数据集](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L1073)和从`DreamBoothDataset`创建DataLoader:
+
+```py
+train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_prompt=args.class_prompt,
+ class_num=args.num_class_images,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ encoder_hidden_states=pre_computed_encoder_hidden_states,
+ class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
+ tokenizer_max_length=args.tokenizer_max_length,
+)
+
+train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+)
+```
+
+最后,[训练循环](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L1151)处理剩余步骤,例如将图像转换为潜在空间、向输入添加噪声、预测噪声残差和计算损失。
+
+如果您想了解更多关于训练循环的工作原理,请查看[理解管道、模型和调度器](../using-diffusers/write_own_pipeline)教程,该教程分解了去噪过程的基本模式。
+
+## 启动脚本
+
+您现在准备好启动训练脚本了!🚀
+
+对于本指南,您将下载一些[狗的图片](https://huggingface.co/datasets/diffusers/dog-example)的图像并将它们存储在一个目录中。但请记住,您可以根据需要创建和使用自己的数据集(请参阅[创建用于训练的数据集](create_dataset)指南)。
+
+```py
+from huggingface_hub import snapshot_download
+
+local_dir = "./dog"
+snapshot_download(
+ "diffusers/dog-example",
+ local_dir=local_dir,
+ repo_type="dataset",
+ ignore_patterns=".gitattributes",
+)
+```
+
+设置环境变量 `MODEL_NAME` 为 Hub 上的模型 ID 或本地模型路径,`INSTANCE_DIR` 为您刚刚下载狗图像的路径,`OUTPUT_DIR` 为您想保存模型的位置。您将使用 `sks` 作为特殊词来绑定训练。
+
+如果您有兴趣跟随训练过程,可以定期保存生成的图像作为训练进度。将以下参数添加到训练命令中:
+
+```bash
+--validation_prompt="a photo of a sks dog"
+--num_validation_images=4
+--validation_steps=100
+```
+
+在启动脚本之前,还有一件事!根据您拥有的 GPU,您可能需要启用某些优化来训练 DreamBooth。
+
+
+
+
+在 16GB GPU 上,您可以使用 bitsandbytes 8 位优化器和梯度检查点来帮助训练 DreamBooth 模型。安装 bitsandbytes:
+
+```py
+pip install bitsandbytes
+```
+
+然后,将以下参数添加到您的训练命令中:
+
+```bash
+accelerate launch train_dreambooth.py \
+ --gradient_checkpointing \
+ --use_8bit_adam \
+```
+
+
+
+
+在 12GB GPU 上,您需要 bitsandbytes 8 位优化器、梯度检查点、xFormers,并将梯度设置为 `None` 而不是零以减少内存使用。
+
+```bash
+accelerate launch train_dreambooth.py \
+ --use_8bit_adam \
+ --gradient_checkpointing \
+ --enable_xformers_memory_efficient_attention \
+ --set_grads_to_none \
+```
+
+
+
+
+在 8GB GPU 上,您需要 [DeepSpeed](https://www.deepspeed.ai/) 将一些张量从 vRAM 卸载到 CPU 或 NVME,以便在更少的 GPU 内存下进行训练。
+
+运行以下命令来配置您的 🤗 Accelerate 环境:
+
+```bash
+accelerate config
+```
+
+在配置过程中,确认您想使用 DeepSpeed。现在,通过结合 DeepSpeed 阶段 2、fp16 混合精度以及将模型参数和优化器状态卸载到 CPU,应该可以在低于 8GB vRAM 的情况下进行训练。缺点是这需要更多的系统 RAM(约 25 GB)。有关更多配置选项,请参阅 [DeepSpeed 文档](https://huggingface.co/docs/accelerate/usage_guides/deepspeed)。
+
+您还应将默认的 Adam 优化器更改为 DeepSpeed 的优化版本 [`deepspeed.ops.adam.DeepSpeedCPUAdam`](https://deepspeed.readthedocs.io/en/latest/optimizers.html#adam-cpu) 以获得显著的速度提升。启用 `DeepSpeedCPUAdam` 要求您的系统 CUDA 工具链版本与 PyTorch 安装的版本相同。
+
+目前,bitsandbytes 8 位优化器似乎与 DeepSpeed 不兼容。
+
+就是这样!您不需要向训练命令添加任何额外参数。
+
+
+
+
+
+
+
+```bash
+export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
+export INSTANCE_DIR="./dog"
+export OUTPUT_DIR="path_to_
+saved_model"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=400 \
+ --push_to_hub
+```
+
+
+
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export INSTANCE_DIR="./dog"
+export OUTPUT_DIR="path-to-save-model"
+
+python train_dreambooth_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --learning_rate=5e-6 \
+ --max_train_steps=400 \
+ --push_to_hub
+```
+
+
+
+
+训练完成后,您可以使用新训练的模型进行推理!
+
+> [!TIP]
+> 等不及在训练完成前就尝试您的模型进行推理?🤭 请确保安装了最新版本的 🤗 Accelerate。
+>
+> ```py
+> from diffusers import DiffusionPipeline, UNet2DConditionModel
+> from transformers import CLIPTextModel
+> import torch
+>
+> unet = UNet2DConditionModel.from_pretrained("path/to/model/checkpoint-100/unet")
+>
+> # 如果您使用了 `--args.train_text_encoder` 进行训练,请确保也加载文本编码器
+> text_encoder = CLIPTextModel.from_pretrained("path/to/model/checkpoint-100/checkpoint-100/text_encoder")
+>
+> pipeline = DiffusionPipeline.from_pretrained(
+> "stable-diffusion-v1-5/stable-diffusion-v1-5", unet=unet, text_encoder=text_encoder, dtype=torch.float16,
+> ).to("cuda")
+>
+> image = pipeline("A photo of sks dog in a bucket", num_inference_steps=50, guidance_scale=7.5).images[0]
+> image.save("dog-bucket.png")
+> ```
+
+
+
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained("path_to_saved_model", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
+image = pipeline("A photo of sks dog in a bucket", num_inference_steps=50, guidance_scale=7.5).images[0]
+image.save("dog-bucket.png")
+```
+
+
+
+
+```py
+import jax
+import numpy as np
+from flax.jax_utils import replicate
+from flax.training.common_utils import shard
+from diffusers import FlaxStableDiffusionPipeline
+
+pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("path-to-your-trained-model", dtype=jax.numpy.bfloat16)
+
+prompt = "A photo of sks dog in a bucket"
+prng_seed = jax.random.PRNGKey(0)
+num_inference_steps = 50
+
+num_samples = jax.device_count()
+prompt = num_samples * [prompt]
+prompt_ids = pipeline.prepare_inputs(prompt)
+
+# 分片输入和随机数生成器
+params = replicate(params)
+prng_seed = jax.random.split(prng_seed, jax.device_count())
+prompt_ids = shard(prompt_ids)
+
+images = pipeline(prompt_ids, params, prng_seed, num_inference_
+steps, jit=True).images
+images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
+image.save("dog-bucket.png")
+```
+
+
+
+
+## LoRA
+
+LoRA 是一种训练技术,可显著减少可训练参数的数量。因此,训练速度更快,并且更容易存储生成的权重,因为它们小得多(约 100MB)。使用 [train_dreambooth_lora.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py) 脚本通过 LoRA 进行训练。
+
+LoRA 训练脚本在 [LoRA 训练](lora) 指南中有更详细的讨论。
+
+## Stable Diffusion XL
+
+Stable Diffusion XL (SDXL) 是一个强大的文本到图像模型,可生成高分辨率图像,并在其架构中添加了第二个文本编码器。使用 [train_dreambooth_lora_sdxl.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_sdxl.py) 脚本通过 LoRA 训练 SDXL 模型。
+
+SDXL 训练脚本在 [SDXL 训练](sdxl) 指南中有更详细的讨论。
+
+## DeepFloyd IF
+
+DeepFloyd IF 是一个级联像素扩散模型,包含三个阶段。第一阶段生成基础图像,第二和第三阶段逐步将基础图像放大为高分辨率 1024x1024 图像。使用 [train_dreambooth_lora.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py) 或 [train_dreambooth.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py) 脚本通过 LoRA 或完整模型训练 DeepFloyd IF 模型。
+
+DeepFloyd IF 使用预测方差,但 Diffusers 训练脚本使用预测误差,因此训练的 DeepFloyd IF 模型被切换到固定方差调度。训练脚本将为您更新完全训练模型的调度器配置。但是,当您加载保存的 LoRA 权重时,还必须更新管道的调度器配置。
+
+```py
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", use_safetensors=True)
+
+pipe.load_lora_weights("")
+
+# 更新调度器配置为固定方差调度
+pipe.scheduler = pipe.scheduler.__class__.from_config(pipe.scheduler.config, variance_type="fixed_small")
+```
+
+第二阶段模型需要额外的验证图像进行放大。您可以下载并使用训练图像的缩小版本。
+
+```py
+from huggingface_hub import snapshot_download
+
+local_dir = "./dog_downsized"
+snapshot_download(
+ "diffusers/dog-example-downsized",
+ local_dir=local_dir,
+ repo_type="dataset",
+ ignore_patterns=".gitattributes",
+)
+```
+
+以下代码示例简要概述了如何结合 DreamBooth 和 LoRA 训练 DeepFloyd IF 模型。一些需要注意的重要参数包括:
+
+* `--resolution=64`,需要更小的分辨率,因为 DeepFloyd IF 是
+一个像素扩散模型,用于处理未压缩的像素,输入图像必须更小
+* `--pre_compute_text_embeddings`,提前计算文本嵌入以节省内存,因为 [`~transformers.T5Model`] 可能占用大量内存
+* `--tokenizer_max_length=77`,您可以使用更长的默认文本长度与 T5 作为文本编码器,但默认模型编码过程使用较短的文本长度
+* `--text_encoder_use_attention_mask`,将注意力掩码传递给文本编码器
+
+
+
+
+使用 LoRA 和 DreamBooth 训练 DeepFloyd IF 的第 1 阶段需要约 28GB 内存。
+
+```bash
+export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="dreambooth_dog_lora"
+
+accelerate launch train_dreambooth_lora.py \
+ --report_to wandb \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a sks dog" \
+ --resolution=64 \
+ --train_batch_size=4 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=5e-6 \
+ --scale_lr \
+ --max_train_steps=1200 \
+ --validation_prompt="a sks dog" \
+ --validation_epochs=25 \
+ --checkpointing_steps=100 \
+ --pre_compute_text_embeddings \
+ --tokenizer_max_length=77 \
+ --text_encoder_use_attention_mask
+```
+
+
+
+
+对于使用 LoRA 和 DreamBooth 的 DeepFloyd IF 第 2 阶段,请注意这些参数:
+
+* `--validation_images`,验证期间用于上采样的图像
+* `--class_labels_conditioning=timesteps`,根据需要额外条件化 UNet,如第 2 阶段中所需
+* `--learning_rate=1e-6`,与第 1 阶段相比使用较低的学习率
+* `--resolution=256`,上采样器的预期分辨率
+
+```bash
+export MODEL_NAME="DeepFloyd/IF-II-L-v1.0"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="dreambooth_dog_upscale"
+export VALIDATION_IMAGES="dog_downsized/image_1.png dog_downsized/image_2.png dog_downsized/image_3.png dog_downsized/image_4.png"
+
+python train_dreambooth_lora.py \
+ --report_to wandb \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a sks dog" \
+ --resolution=256 \
+ --train_batch_size=4 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=1e-6 \
+ --max_train_steps=2000 \
+ --validation_prompt="a sks dog" \
+ --validation_epochs=100 \
+ --checkpointing_steps=500 \
+ --pre_compute_text_embeddings \
+ --tokenizer_max_length=77 \
+ --text_encoder_use_attention_mask \
+ --validation_images $VALIDATION_IMAGES \
+ --class_labels_conditioning=timesteps
+```
+
+
+
+
+对于使用 DreamBooth 的 DeepFloyd IF 第 1 阶段,请注意这些参数:
+
+* `--skip_save_text_encoder`,跳过保存完整 T5 文本编码器与微调模型
+* `--use_8bit_adam`,使用 8 位 Adam 优化器以节省内存,因为
+
+优化器状态的大小在训练完整模型时
+* `--learning_rate=1e-7`,对于完整模型训练应使用非常低的学习率,否则模型质量会下降(您可以使用更高的学习率和更大的批次大小)
+
+使用8位Adam和批次大小为4进行训练,完整模型可以在约48GB内存下训练。
+
+```bash
+export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="dreambooth_if"
+
+accelerate launch train_dreambooth.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=64 \
+ --train_batch_size=4 \
+ --gradient_accumulation_steps=1 \
+ --learning_rate=1e-7 \
+ --max_train_steps=150 \
+ --validation_prompt "a photo of sks dog" \
+ --validation_steps 25 \
+ --text_encoder_use_attention_mask \
+ --tokenizer_max_length 77 \
+ --pre_compute_text_embeddings \
+ --use_8bit_adam \
+ --set_grads_to_none \
+ --skip_save_text_encoder \
+ --push_to_hub
+```
+
+
+
+
+对于DeepFloyd IF的第二阶段DreamBooth,请注意这些参数:
+
+* `--learning_rate=5e-6`,使用较低的学习率和较小的有效批次大小
+* `--resolution=256`,上采样器的预期分辨率
+* `--train_batch_size=2` 和 `--gradient_accumulation_steps=6`,为了有效训练包含面部的图像,需要更大的批次大小
+
+```bash
+export MODEL_NAME="DeepFloyd/IF-II-L-v1.0"
+export INSTANCE_DIR="dog"
+export OUTPUT_DIR="dreambooth_dog_upscale"
+export VALIDATION_IMAGES="dog_downsized/image_1.png dog_downsized/image_2.png dog_downsized/image_3.png dog_downsized/image_4.png"
+
+accelerate launch train_dreambooth.py \
+ --report_to wandb \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --instance_data_dir=$INSTANCE_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --instance_prompt="a sks dog" \
+ --resolution=256 \
+ --train_batch_size=2 \
+ --gradient_accumulation_steps=6 \
+ --learning_rate=5e-6 \
+ --max_train_steps=2000 \
+ --validation_prompt="a sks dog" \
+ --validation_steps=150 \
+ --checkpointing_steps=500 \
+ --pre_compute_text_embeddings \
+ --tokenizer_max_length=77 \
+ --text_encoder_use_attention_mask \
+ --validation_images $VALIDATION_IMAGES \
+ --class_labels_conditioning timesteps \
+ --push_to_hub
+```
+
+
+
+
+### 训练技巧
+
+训练DeepFloyd IF模型可能具有挑战性,但以下是我们发现有用的技巧:
+
+- LoRA对于训练第一阶段模型已足够,因为模型的低分辨率使得表示更精细的细节变得困难,无论如何。
+- 对于常见或简单的对象,您不一定需要微调上采样器。确保传递给上采样器的提示被调整以移除实例提示中的新令牌。例如,如果您第一阶段提示是"a sks dog",那么您第二阶段的提示应该是"a dog"。
+- 对于更精细的细节,如面部,完全训练
+使用阶段2上采样器比使用LoRA训练阶段2模型更好。使用更大的批次大小和较低的学习率也有帮助。
+- 应使用较低的学习率来训练阶段2模型。
+- [`DDPMScheduler`] 比训练脚本中使用的DPMSolver效果更好。
+
+## 下一步
+
+恭喜您训练了您的DreamBooth模型!要了解更多关于如何使用您的新模型的信息,以下指南可能有所帮助:
+- 如果您使用LoRA训练了您的模型,请学习如何[加载DreamBooth](../using-diffusers/loading_adapters)模型进行推理。
\ No newline at end of file
diff --git a/docs/source/zh/training/instructpix2pix.md b/docs/source/zh/training/instructpix2pix.md
new file mode 100644
index 000000000000..1f9f4eb21ec3
--- /dev/null
+++ b/docs/source/zh/training/instructpix2pix.md
@@ -0,0 +1,246 @@
+
+
+# InstructPix2Pix
+
+[InstructPix2Pix](https://hf.co/papers/2211.09800) 是一个基于 Stable Diffusion 训练的模型,用于根据人类提供的指令编辑图像。例如,您的提示可以是“将云变成雨天”,模型将相应编辑输入图像。该模型以文本提示(或编辑指令)和输入图像为条件。
+
+本指南将探索 [train_instruct_pix2pix.py](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py) 训练脚本,帮助您熟悉它,以及如何将其适应您自己的用例。
+
+在运行脚本之前,请确保从源代码安装库:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+然后导航到包含训练脚本的示例文件夹,并安装脚本所需的依赖项:
+
+```bash
+cd examples/instruct_pix2pix
+pip install -r requirements.txt
+```
+
+> [!TIP]
+> 🤗 Accelerate 是一个库,用于帮助您在多个 GPU/TPU 上或使用混合精度进行训练。它将根据您的硬件和环境自动配置训练设置。查看 🤗 Accelerate [快速导览](https://huggingface.co/docs/accelerate/quicktour) 以了解更多信息。
+
+初始化一个 🤗 Accelerate 环境:
+
+```bash
+accelerate config
+```
+
+要设置一个默认的 🤗 Accelerate 环境,无需选择任何配置:
+
+```bash
+accelerate config default
+```
+
+或者,如果您的环境不支持交互式 shell,例如笔记本,您可以使用:
+
+```py
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+最后,如果您想在自己的数据集上训练模型,请查看 [创建用于训练的数据集](create_dataset) 指南,了解如何创建与训练脚本兼容的数据集。
+
+> [!TIP]
+> 以下部分重点介绍了训练脚本中对于理解如何修改它很重要的部分,但并未详细涵盖脚本的每个方面。如果您有兴趣了解更多,请随时阅读 [脚本](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py),并告诉我们如果您有任何问题或疑虑。
+
+## 脚本参数
+
+训练脚本有许多参数可帮助您自定义训练运行。所有
+参数及其描述可在 [`parse_args()`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L65) 函数中找到。大多数参数都提供了默认值,这些值效果相当不错,但如果您愿意,也可以在训练命令中设置自己的值。
+
+例如,要增加输入图像的分辨率:
+
+```bash
+accelerate launch train_instruct_pix2pix.py \
+ --resolution=512 \
+```
+
+许多基本和重要的参数在 [文本到图像](text2image#script-parameters) 训练指南中已有描述,因此本指南仅关注与 InstructPix2Pix 相关的参数:
+
+- `--original_image_column`:编辑前的原始图像
+- `--edited_image_column`:编辑后的图像
+- `--edit_prompt_column`:编辑图像的指令
+- `--conditioning_dropout_prob`:训练期间编辑图像和编辑提示的 dropout 概率,这为一种或两种条件输入启用了无分类器引导(CFG)
+
+## 训练脚本
+
+数据集预处理代码和训练循环可在 [`main()`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L374) 函数中找到。这是您将修改训练脚本以适应自己用例的地方。
+
+与脚本参数类似,[文本到图像](text2image#training-script) 训练指南提供了训练脚本的逐步说明。相反,本指南将查看脚本中与 InstructPix2Pix 相关的部分。
+
+脚本首先修改 UNet 的第一个卷积层中的 [输入通道数](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L445),以适应 InstructPix2Pix 的额外条件图像:
+
+```py
+in_channels = 8
+out_channels = unet.conv_in.out_channels
+unet.register_to_config(in_channels=in_channels)
+
+with torch.no_grad():
+ new_conv_in = nn.Conv2d(
+ in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
+ )
+ new_conv_in.weight.zero_()
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
+ unet.conv_in = new_conv_in
+```
+
+这些 UNet 参数由优化器 [更新](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L545C1-L551C6):
+
+```py
+optimizer = optimizer_cls(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+)
+```
+
+接下来,编辑后的图像和编辑指令被 [预处理](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L624)并被[tokenized](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L610C24-L610C24)。重要的是,对原始图像和编辑后的图像应用相同的图像变换。
+
+```py
+def preprocess_train(examples):
+ preprocessed_images = preprocess_images(examples)
+
+ original_images, edited_images = preprocessed_images.chunk(2)
+ original_images = original_images.reshape(-1, 3, args.resolution, args.resolution)
+ edited_images = edited_images.reshape(-1, 3, args.resolution, args.resolution)
+
+ examples["original_pixel_values"] = original_images
+ examples["edited_pixel_values"] = edited_images
+
+ captions = list(examples[edit_prompt_column])
+ examples["input_ids"] = tokenize_captions(captions)
+ return examples
+```
+
+最后,在[训练循环](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L730)中,它首先将编辑后的图像编码到潜在空间:
+
+```py
+latents = vae.encode(batch["edited_pixel_values"].to(weight_dtype)).latent_dist.sample()
+latents = latents * vae.config.scaling_factor
+```
+
+然后,脚本对原始图像和编辑指令嵌入应用 dropout 以支持 CFG(Classifier-Free Guidance)。这使得模型能够调节编辑指令和原始图像对编辑后图像的影响。
+
+```py
+encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+original_image_embeds = vae.encode(batch["original_pixel_values"].to(weight_dtype)).latent_dist.mode()
+
+if args.conditioning_dropout_prob is not None:
+ random_p = torch.rand(bsz, device=latents.device, generator=generator)
+ prompt_mask = random_p < 2 * args.conditioning_dropout_prob
+ prompt_mask = prompt_mask.reshape(bsz, 1, 1)
+ null_conditioning = text_encoder(tokenize_captions([""]).to(accelerator.device))[0]
+ encoder_hidden_states = torch.where(prompt_mask, null_conditioning, encoder_hidden_states)
+
+ image_mask_dtype = original_image_embeds.dtype
+ image_mask = 1 - (
+ (random_p >= args.conditioning_dropout_prob).to(image_mask_dtype)
+ * (random_p < 3 * args.conditioning_dropout_prob).to(image_mask_dtype)
+ )
+ image_mask = image_mask.reshape(bsz, 1, 1, 1)
+ original_image_embeds = image_mask * original_image_embeds
+```
+
+差不多就是这样了!除了这里描述的不同之处,脚本的其余部分与[文本到图像](text2image#training-script)训练脚本非常相似,所以请随意查看以获取更多细节。如果您想了解更多关于训练循环如何工作的信息,请查看[理解管道、模型和调度器](../using-diffusers/write_own_pipeline)教程,该教程分解了去噪过程的基本模式。
+
+## 启动脚本
+
+一旦您对脚本的更改感到满意,或者如果您对默认配置没问题,您
+准备好启动训练脚本!🚀
+
+本指南使用 [fusing/instructpix2pix-1000-samples](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples) 数据集,这是 [原始数据集](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) 的一个较小版本。您也可以创建并使用自己的数据集(请参阅 [创建用于训练的数据集](create_dataset) 指南)。
+
+将 `MODEL_NAME` 环境变量设置为模型名称(可以是 Hub 上的模型 ID 或本地模型的路径),并将 `DATASET_ID` 设置为 Hub 上数据集的名称。脚本会创建并保存所有组件(特征提取器、调度器、文本编码器、UNet 等)到您的仓库中的一个子文件夹。
+
+> [!TIP]
+> 为了获得更好的结果,尝试使用更大的数据集进行更长时间的训练。我们只在较小规模的数据集上测试过此训练脚本。
+>
+>
+>
+> 要使用 Weights and Biases 监控训练进度,请将 `--report_to=wandb` 参数添加到训练命令中,并使用 `--val_image_url` 指定验证图像,使用 `--validation_prompt` 指定验证提示。这对于调试模型非常有用。
+
+如果您在多个 GPU 上训练,请将 `--multi_gpu` 参数添加到 `accelerate launch` 命令中。
+
+```bash
+accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_ID \
+ --enable_xformers_memory_efficient_attention \
+ --resolution=256 \
+ --random_flip \
+ --train_batch_size=4 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --checkpointing_steps=5000 \
+ --checkpoints_total_limit=1 \
+ --learning_rate=5e-05 \
+ --max_grad_norm=1 \
+ --lr_warmup_steps=0 \
+ --conditioning_dropout_prob=0.05 \
+ --mixed_precision=fp16 \
+ --seed=42 \
+ --push_to_hub
+```
+
+训练完成后,您可以使用您的新 InstructPix2Pix 进行推理:
+
+```py
+import PIL
+import requests
+import torch
+from diffusers import StableDiffusionInstructPix2PixPipeline
+from diffusers.utils import load_image
+
+pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained("your_cool_model", torch_dtype=torch.float16).to("cuda")
+generator = torch.Generator("cuda").manual_seed(0)
+
+image = load_image("https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/test_pix2pix_4.png")
+prompt = "add some ducks to the lake"
+num_inference_steps = 20
+image_guidance_scale = 1.5
+guidance_scale = 10
+
+edited_image = pipeline(
+ prompt,
+ image=image,
+ num_inference_steps=num_inference_steps,
+ image_guidance_scale=image_guidance_scale,
+ guidance_scale=guidance_scale,
+ generator=generator,
+).images[0]
+edited_image.save("edited_image.png")
+```
+
+您应该尝试不同的 `num_inference_steps`、`image_guidance_scale` 和 `guidance_scale` 值,以查看它们如何影响推理速度和质量。指导比例参数
+这些参数尤其重要,因为它们控制原始图像和编辑指令对编辑后图像的影响程度。
+
+## Stable Diffusion XL
+
+Stable Diffusion XL (SDXL) 是一个强大的文本到图像模型,能够生成高分辨率图像,并在其架构中添加了第二个文本编码器。使用 [`train_instruct_pix2pix_sdxl.py`](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py) 脚本来训练 SDXL 模型以遵循图像编辑指令。
+
+SDXL 训练脚本在 [SDXL 训练](sdxl) 指南中有更详细的讨论。
+
+## 后续步骤
+
+恭喜您训练了自己的 InstructPix2Pix 模型!🥳 要了解更多关于该模型的信息,可能有助于:
+
+- 阅读 [Instruction-tuning Stable Diffusion with InstructPix2Pix](https://huggingface.co/blog/instruction-tuning-sd) 博客文章,了解更多我们使用 InstructPix2Pix 进行的一些实验、数据集准备以及不同指令的结果。
\ No newline at end of file
diff --git a/docs/source/zh/training/kandinsky.md b/docs/source/zh/training/kandinsky.md
new file mode 100644
index 000000000000..8ef3524ee7c4
--- /dev/null
+++ b/docs/source/zh/training/kandinsky.md
@@ -0,0 +1,313 @@
+
+
+# Kandinsky 2.2
+
+> [!WARNING]
+> 此脚本是实验性的,容易过拟合并遇到灾难性遗忘等问题。尝试探索不同的超参数以在您的数据集上获得最佳结果。
+
+Kandinsky 2.2 是一个多语言文本到图像模型,能够生成更逼真的图像。该模型包括一个图像先验模型,用于从文本提示创建图像嵌入,以及一个解码器模型,基于先验模型的嵌入生成图像。这就是为什么在 Diffusers 中您会找到两个独立的脚本用于 Kandinsky 2.2,一个用于训练先验模型,另一个用于训练解码器模型。您可以分别训练这两个模型,但为了获得最佳结果,您应该同时训练先验和解码器模型。
+
+根据您的 GPU,您可能需要启用 `gradient_checkpointing`(⚠️ 不支持先验模型!)、`mixed_precision` 和 `gradient_accumulation_steps` 来帮助将模型装入内存并加速训练。您可以通过启用 [xFormers](../optimization/xformers) 的内存高效注意力来进一步减少内存使用(版本 [v0.0.16](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212) 在某些 GPU 上训练时失败,因此您可能需要安装开发版本)。
+
+本指南探讨了 [train_text_to_image_prior.py](https://github.com/huggingface/diffusers/blob/main/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py) 和 [train_text_to_image_decoder.py](https://github.com/huggingface/diffusers/blob/main/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py) 脚本,以帮助您更熟悉它,以及如何根据您的用例进行调整。
+
+在运行脚本之前,请确保从源代码安装库:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+然后导航到包含训练脚本的示例文件夹,并安装脚本所需的依赖项:
+
+```bash
+cd examples/kandinsky2_2/text_to_image
+pip install -r requirements.txt
+```
+
+> [!TIP]
+> 🤗 Accelerate 是一个帮助您在多个 GPU/TPU 上或使用混合精度进行训练的库。它会根据您的硬件和环境自动配置训练设置。查看 🤗 Accelerate 的 [快速入门](https://huggingface.co/docs/accelerate/quicktour
+> ) 了解更多。
+
+初始化一个 🤗 Accelerate 环境:
+
+```bash
+accelerate config
+```
+
+要设置一个默认的 🤗 Accelerate 环境而不选择任何配置:
+
+```bash
+accelerate config default
+```
+
+或者,如果您的环境不支持交互式 shell,比如 notebook,您可以使用:
+
+```py
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+最后,如果您想在自己的数据集上训练模型,请查看 [创建用于训练的数据集](create_dataset) 指南,了解如何创建与训练脚本兼容的数据集。
+
+> [!TIP]
+> 以下部分重点介绍了训练脚本中对于理解如何修改它很重要的部分,但并未详细涵盖脚本的每个方面。如果您有兴趣了解更多,请随时阅读脚本,并让我们知道您有任何疑问或顾虑。
+
+## 脚本参数
+
+训练脚本提供了许多参数来帮助您自定义训练运行。所有参数及其描述都可以在 [`parse_args()`](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py#L190) 函数中找到。训练脚本为每个参数提供了默认值,例如训练批次大小和学习率,但如果您愿意,也可以在训练命令中设置自己的值。
+
+例如,要使用 fp16 格式的混合精度加速训练,请在训练命令中添加 `--mixed_precision` 参数:
+
+```bash
+accelerate launch train_text_to_image_prior.py \
+ --mixed_precision="fp16"
+```
+
+大多数参数与 [文本到图像](text2image#script-parameters) 训练指南中的参数相同,所以让我们直接进入 Kandinsky 训练脚本的 walkthrough!
+
+### Min-SNR 加权
+
+[Min-SNR](https://huggingface.co/papers/2303.09556) 加权策略可以通过重新平衡损失来帮助训练,实现更快的收敛。训练脚本支持预测 `epsilon`(噪声)或 `v_prediction`,但 Min-SNR 与两种预测类型都兼容。此加权策略仅由 PyTorch 支持,在 Flax 训练脚本中不可用。
+
+添加 `--snr_gamma` 参数并将其设置为推荐值 5.0:
+
+```bash
+accelerate launch train_text_to_image_prior.py \
+ --snr_gamma=5.0
+```
+
+## 训练脚本
+
+训练脚本也类似于 [文本到图像](text2image#training-script) 训练指南,但已修改以支持训练 prior 和 decoder 模型。本指南重点介绍 Kandinsky 2.2 训练脚本中独特的代码。
+
+
+
+
+[`main()`](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py#L441) 函数包含代码 f
+或准备数据集和训练模型。
+
+您会立即注意到的主要区别之一是,训练脚本除了调度器和分词器外,还加载了一个 [`~transformers.CLIPImageProcessor`] 用于预处理图像,以及一个 [`~transformers.CLIPVisionModelWithProjection`] 模型用于编码图像:
+
+```py
+noise_scheduler = DDPMScheduler(beta_schedule="squaredcos_cap_v2", prediction_type="sample")
+image_processor = CLIPImageProcessor.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="image_processor"
+)
+tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer")
+
+with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="image_encoder", torch_dtype=weight_dtype
+ ).eval()
+ text_encoder = CLIPTextModelWithProjection.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype
+ ).eval()
+```
+
+Kandinsky 使用一个 [`PriorTransformer`] 来生成图像嵌入,因此您需要设置优化器来学习先验模型的参数。
+
+```py
+prior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
+prior.train()
+optimizer = optimizer_cls(
+ prior.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+)
+```
+
+接下来,输入标题被分词,图像由 [`~transformers.CLIPImageProcessor`] [预处理](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py#L632):
+
+```py
+def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["clip_pixel_values"] = image_processor(images, return_tensors="pt").pixel_values
+ examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
+ return examples
+```
+
+最后,[训练循环](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py#L718) 将输入图像转换为潜在表示,向图像嵌入添加噪声,并进行预测:
+
+```py
+model_pred = prior(
+ noisy_latents,
+ timestep=timesteps,
+ proj_embedding=prompt_embeds,
+ encoder_hidden_states=text_encoder_hidden_states,
+ attention_mask=text_mask,
+).predicted_image_embedding
+```
+
+如果您想了解更多关于训练循环的工作原理,请查看 [理解管道、模型和调度器](../using-diffusers/write_own_pipeline) 教程,该教程分解了去噪过程的基本模式。
+
+
+
+
+The [`main()`](https://github.com/huggingface/di
+ffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py#L440) 函数包含准备数据集和训练模型的代码。
+
+与之前的模型不同,解码器初始化一个 [`VQModel`] 来将潜在变量解码为图像,并使用一个 [`UNet2DConditionModel`]:
+
+```py
+with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ vae = VQModel.from_pretrained(
+ args.pretrained_decoder_model_name_or_path, subfolder="movq", torch_dtype=weight_dtype
+ ).eval()
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ args.pretrained_prior_model_name_or_path, subfolder="image_encoder", torch_dtype=weight_dtype
+ ).eval()
+unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="unet")
+```
+
+接下来,脚本包括几个图像变换和一个用于对图像应用变换并返回像素值的[预处理](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py#L622)函数:
+
+```py
+def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["clip_pixel_values"] = image_processor(images, return_tensors="pt").pixel_values
+ return examples
+```
+
+最后,[训练循环](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py#L706)处理将图像转换为潜在变量、添加噪声和预测噪声残差。
+
+如果您想了解更多关于训练循环如何工作的信息,请查看[理解管道、模型和调度器](../using-diffusers/write_own_pipeline)教程,该教程分解了去噪过程的基本模式。
+
+```py
+model_pred = unet(noisy_latents, timesteps, None, added_cond_kwargs=added_cond_kwargs).sample[:, :4]
+```
+
+
+
+
+## 启动脚本
+
+一旦您完成了所有更改或接受默认配置,就可以启动训练脚本了!🚀
+
+您将在[Naruto BLIP 字幕](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions)数据集上进行训练,以生成您自己的Naruto角色,但您也可以通过遵循[创建用于训练的数据集](create_dataset)指南来创建和训练您自己的数据集。将环境变量 `DATASET_NAME` 设置为Hub上数据集的名称,或者如果您在自己的文件上训练,将环境变量 `TRAIN_DIR` 设置为数据集的路径。
+
+如果您在多个GPU上训练,请在 `accelerate launch` 命令中添加 `--multi_gpu` 参数。
+
+> [!TIP]
+> 要使用Weights & Biases监控训练进度,请在训练命令中添加 `--report_to=wandb` 参数。您还需要
+> 建议在训练命令中添加 `--validation_prompt` 以跟踪结果。这对于调试模型和查看中间结果非常有用。
+
+
+
+
+```bash
+export DATASET_NAME="lambdalabs/naruto-blip-captions"
+
+accelerate launch --mixed_precision="fp16" train_text_to_image_prior.py \
+ --dataset_name=$DATASET_NAME \
+ --resolution=768 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --checkpoints_total_limit=3 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --validation_prompts="A robot naruto, 4k photo" \
+ --report_to="wandb" \
+ --push_to_hub \
+ --output_dir="kandi2-prior-naruto-model"
+```
+
+
+
+
+```bash
+export DATASET_NAME="lambdalabs/naruto-blip-captions"
+
+accelerate launch --mixed_precision="fp16" train_text_to_image_decoder.py \
+ --dataset_name=$DATASET_NAME \
+ --resolution=768 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --checkpoints_total_limit=3 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --validation_prompts="A robot naruto, 4k photo" \
+ --report_to="wandb" \
+ --push_to_hub \
+ --output_dir="kandi2-decoder-naruto-model"
+```
+
+
+
+
+训练完成后,您可以使用新训练的模型进行推理!
+
+
+
+
+```py
+from diffusers import AutoPipelineForText2Image, DiffusionPipeline
+import torch
+
+prior_pipeline = DiffusionPipeline.from_pretrained(output_dir, torch_dtype=torch.float16)
+prior_components = {"prior_" + k: v for k,v in prior_pipeline.components.items()}
+pipeline = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", **prior_components, torch_dtype=torch.float16)
+
+pipe.enable_model_cpu_offload()
+prompt="A robot naruto, 4k photo"
+image = pipeline(prompt=prompt, negative_prompt=negative_prompt).images[0]
+```
+
+> [!TIP]
+> 可以随意将 `kandinsky-community/kandinsky-2-2-decoder` 替换为您自己训练的 decoder 检查点!
+
+
+
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained("path/to/saved/model", torch_dtype=torch.float16)
+pipeline.enable_model_cpu_offload()
+
+prompt="A robot naruto, 4k photo"
+image = pipeline(prompt=prompt).images[0]
+```
+
+对于 decoder 模型,您还可以从保存的检查点进行推理,这对于查看中间结果很有用。在这种情况下,将检查点加载到 UNet 中:
+
+```py
+from diffusers import AutoPipelineForText2Image, UNet2DConditionModel
+
+unet = UNet2DConditionModel.from_pretrained("path/to/saved/model" + "/checkpoint-/unet")
+
+pipeline = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", unet=unet, torch_dtype=torch.float16)
+pipeline.enable_model_cpu_offload()
+
+image = pipeline(prompt="A robot naruto, 4k photo").images[0]
+```
+
+
+
+
+## 后续步骤
+
+恭喜您训练了一个 Kandinsky 2.2 模型!要了解更多关于如何使用您的新模型的信息,以下指南可能会有所帮助:
+
+- 阅读 [Kandinsky](../using-diffusers/kandinsky) 指南,学习如何将其用于各种不同的任务(文本到图像、图像到图像、修复、插值),以及如何与 ControlNet 结合使用。
+- 查看 [DreamBooth](dreambooth) 和 [LoRA](lora) 训练指南,学习如何使用少量示例图像训练个性化的 Kandinsky 模型。这两种训练技术甚至可以结合使用!
\ No newline at end of file
diff --git a/docs/source/zh/training/lora.md b/docs/source/zh/training/lora.md
new file mode 100644
index 000000000000..ce29365450bd
--- /dev/null
+++ b/docs/source/zh/training/lora.md
@@ -0,0 +1,216 @@
+
+
+# LoRA 低秩适配
+
+> [!WARNING]
+> 当前功能处于实验阶段,API可能在未来版本中变更。
+
+[LoRA(大语言模型的低秩适配)](https://hf.co/papers/2106.09685) 是一种轻量级训练技术,能显著减少可训练参数量。其原理是通过向模型注入少量新权重参数,仅训练这些新增参数。这使得LoRA训练速度更快、内存效率更高,并生成更小的模型权重文件(通常仅数百MB),便于存储和分享。LoRA还可与DreamBooth等其他训练技术结合以加速训练过程。
+
+> [!TIP]
+> LoRA具有高度通用性,目前已支持以下应用场景:[DreamBooth](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py)、[Kandinsky 2.2](https://github.com/huggingface/diffusers/blob/main/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py)、[Stable Diffusion XL](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora_sdxl.py)、[文生图](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)以及[Wuerstchen](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py)。
+
+本指南将通过解析[train_text_to_image_lora.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)脚本,帮助您深入理解其工作原理,并掌握如何针对具体需求进行定制化修改。
+
+运行脚本前,请确保从源码安装库:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+进入包含训练脚本的示例目录,并安装所需依赖:
+
+
+
+
+```bash
+cd examples/text_to_image
+pip install -r requirements.txt
+```
+
+
+
+
+```bash
+cd examples/text_to_image
+pip install -r requirements_flax.txt
+```
+
+
+
+
+> [!TIP]
+> 🤗 Accelerate是一个支持多GPU/TPU训练和混合精度计算的库,它能根据硬件环境自动配置训练方案。参阅🤗 Accelerate[快速入门](https://huggingface.co/docs/accelerate/quicktour)了解更多。
+
+初始化🤗 Accelerate环境:
+
+```bash
+accelerate config
+```
+
+若要创建默认配置环境(不进行交互式设置):
+
+```bash
+accelerate config default
+```
+
+若在非交互环境(如Jupyter notebook)中使用:
+
+```py
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+如需训练自定义数据集,请参考[创建训练数据集指南](create_dataset)了解数据准备流程。
+
+> [!TIP]
+> 以下章节重点解析训练脚本中与LoRA相关的核心部分,但不会涵盖所有实现细节。如需完整理解,建议直接阅读[脚本源码](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py),如有疑问欢迎反馈。
+
+## 脚本参数
+
+训练脚本提供众多参数用于定制训练过程。所有参数及其说明均定义在[`parse_args()`](https://github.com/huggingface/diffusers/blob/dd9a5caf61f04d11c0fa9f3947b69ab0010c9a0f/examples/text_to_image/train_text_to_image_lora.py#L85)函数中。多数参数设有默认值,您也可以通过命令行参数覆盖:
+
+例如增加训练轮次:
+
+```bash
+accelerate launch train_text_to_image_lora.py \
+ --num_train_epochs=150 \
+```
+
+基础参数说明可参考[文生图训练指南](text2image#script-parameters),此处重点介绍LoRA相关参数:
+
+- `--rank`:低秩矩阵的内部维度,数值越高可训练参数越多
+- `--learning_rate`:默认学习率为1e-4,但使用LoRA时可适当提高
+
+## 训练脚本实现
+
+数据集预处理和训练循环逻辑位于[`main()`](https://github.com/huggingface/diffusers/blob/dd9a5caf61f04d11c0fa9f3947b69ab0010c9a0f/examples/text_to_image/train_text_to_image_lora.py#L371)函数,如需定制训练流程,可在此处进行修改。
+
+与参数说明类似,训练流程的完整解析请参考[文生图指南](text2image#training-script),下文重点介绍LoRA相关实现。
+
+
+
+
+Diffusers使用[PEFT](https://hf.co/docs/peft)库的[`~peft.LoraConfig`]配置LoRA适配器参数,包括秩(rank)、alpha值以及目标模块。适配器被注入UNet后,通过`lora_layers`筛选出需要优化的LoRA层。
+
+```py
+unet_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+)
+
+unet.add_adapter(unet_lora_config)
+lora_layers = filter(lambda p: p.requires_grad, unet.parameters())
+```
+
+
+
+
+当需要微调文本编码器时(如SDXL模型),Diffusers同样支持通过[PEFT](https://hf.co/docs/peft)库实现。[`~peft.LoraConfig`]配置适配器参数后注入文本编码器,并筛选LoRA层进行训练。
+
+```py
+text_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+)
+
+text_encoder_one.add_adapter(text_lora_config)
+text_encoder_two.add_adapter(text_lora_config)
+text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
+```
+
+
+
+
+[优化器](https://github.com/huggingface/diffusers/blob/e4b8f173b97731686e290b2eb98e7f5df2b1b322/examples/text_to_image/train_text_to_image_lora.py#L529)仅对`lora_layers`参数进行优化:
+
+```py
+optimizer = optimizer_cls(
+ lora_layers,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+)
+```
+
+除LoRA层设置外,该训练脚本与标准train_text_to_image.py基本相同!
+
+## 启动训练
+
+完成所有配置后,即可启动训练脚本!🚀
+
+以下示例使用[Naruto BLIP captions](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions)训练生成火影角色。请设置环境变量`MODEL_NAME`和`DATASET_NAME`指定基础模型和数据集,`OUTPUT_DIR`设置输出目录,`HUB_MODEL_ID`指定Hub存储库名称。脚本运行后将生成以下文件:
+
+- 模型检查点
+- `pytorch_lora_weights.safetensors`(训练好的LoRA权重)
+
+多GPU训练请添加`--multi_gpu`参数。
+
+> [!WARNING]
+> 在11GB显存的2080 Ti显卡上完整训练约需5小时。
+
+```bash
+export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
+export OUTPUT_DIR="/sddata/finetune/lora/naruto"
+export HUB_MODEL_ID="naruto-lora"
+export DATASET_NAME="lambdalabs/naruto-blip-captions"
+
+accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_NAME \
+ --dataloader_num_workers=8 \
+ --resolution=512 \
+ --center_crop \
+ --random_flip \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --max_train_steps=15000 \
+ --learning_rate=1e-04 \
+ --max_grad_norm=1 \
+ --lr_scheduler="cosine" \
+ --lr_warmup_steps=0 \
+ --output_dir=${OUTPUT_DIR} \
+ --push_to_hub \
+ --hub_model_id=${HUB_MODEL_ID} \
+ --report_to=wandb \
+ --checkpointing_steps=500 \
+ --validation_prompt="蓝色眼睛的火影忍者角色" \
+ --seed=1337
+```
+
+训练完成后,您可以通过以下方式进行推理:
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+
+pipeline = AutoPipelineForText2Image.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+pipeline.load_lora_weights("path/to/lora/model", weight_name="pytorch_lora_weights.safetensors")
+image = pipeline("A naruto with blue eyes").images[0]
+```
+
+## 后续步骤
+
+恭喜完成LoRA模型训练!如需进一步了解模型使用方法,可参考以下指南:
+
+- 学习如何加载[不同格式的LoRA权重](../using-diffusers/loading_adapters#LoRA)(如Kohya或TheLastBen训练的模型)
+- 掌握使用PEFT进行[多LoRA组合推理](../tutorials/using_peft_for_inference)的技巧
\ No newline at end of file
diff --git a/docs/source/zh/training/overview.md b/docs/source/zh/training/overview.md
new file mode 100644
index 000000000000..ebf814aefe44
--- /dev/null
+++ b/docs/source/zh/training/overview.md
@@ -0,0 +1,60 @@
+
+
+# 概述
+
+🤗 Diffusers 提供了一系列训练脚本供您训练自己的diffusion模型。您可以在 [diffusers/examples](https://github.com/huggingface/diffusers/tree/main/examples) 找到所有训练脚本。
+
+每个训练脚本具有以下特点:
+
+- **独立完整**:训练脚本不依赖任何本地文件,所有运行所需的包都通过 `requirements.txt` 文件安装
+- **易于调整**:这些脚本是针对特定任务的训练示例,并不能开箱即用地适用于所有训练场景。您可能需要根据具体用例调整脚本。为此,我们完全公开了数据预处理代码和训练循环,方便您进行修改
+- **新手友好**:脚本设计注重易懂性和入门友好性,而非包含最新最优方法以获得最具竞争力的结果。我们有意省略了过于复杂的训练方法
+- **单一用途**:每个脚本仅针对一个任务设计,确保代码可读性和可理解性
+
+当前提供的训练脚本包括:
+
+| 训练类型 | 支持SDXL | 支持LoRA | 支持Flax |
+|---|---|---|---|
+| [unconditional image generation](https://github.com/huggingface/diffusers/tree/main/examples/unconditional_image_generation) [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) | | | |
+| [text-to-image](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) | 👍 | 👍 | 👍 |
+| [textual inversion](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb) | | | 👍 |
+| [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth) [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb) | 👍 | 👍 | 👍 |
+| [ControlNet](https://github.com/huggingface/diffusers/tree/main/examples/controlnet) | 👍 | | 👍 |
+| [InstructPix2Pix](https://github.com/huggingface/diffusers/tree/main/examples/instruct_pix2pix) | 👍 | | |
+| [Custom Diffusion](https://github.com/huggingface/diffusers/tree/main/examples/custom_diffusion) | | | |
+| [T2I-Adapters](https://github.com/huggingface/diffusers/tree/main/examples/t2i_adapter) | 👍 | | |
+| [Kandinsky 2.2](https://github.com/huggingface/diffusers/tree/main/examples/kandinsky2_2/text_to_image) | | 👍 | |
+| [Wuerstchen](https://github.com/huggingface/diffusers/tree/main/examples/wuerstchen/text_to_image) | | 👍 | |
+
+这些示例处于**积极维护**状态,如果遇到问题请随时提交issue。如果您认为应该添加其他训练示例,欢迎创建[功能请求](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=)与我们讨论,我们将评估其是否符合独立完整、易于调整、新手友好和单一用途的标准。
+
+## 安装
+
+请按照以下步骤在新虚拟环境中从源码安装库,确保能成功运行最新版本的示例脚本:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+然后进入具体训练脚本目录(例如[DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth)),安装对应的`requirements.txt`文件。部分脚本针对SDXL、LoRA或Flax有特定要求文件,使用时请确保安装对应文件。
+
+```bash
+cd examples/dreambooth
+pip install -r requirements.txt
+# 如需用DreamBooth训练SDXL
+pip install -r requirements_sdxl.txt
+```
+
+为加速训练并降低内存消耗,我们建议:
+
+- 使用PyTorch 2.0或更高版本,自动启用[缩放点积注意力](../optimization/fp16#scaled-dot-product-attention)(无需修改训练代码)
+- 安装[xFormers](../optimization/xformers)以启用内存高效注意力机制
\ No newline at end of file
diff --git a/docs/source/zh/training/text2image.md b/docs/source/zh/training/text2image.md
new file mode 100644
index 000000000000..4465adbe2ad7
--- /dev/null
+++ b/docs/source/zh/training/text2image.md
@@ -0,0 +1,260 @@
+
+
+# 文生图
+
+> [!WARNING]
+> 文生图训练脚本目前处于实验阶段,容易出现过拟合和灾难性遗忘等问题。建议尝试不同超参数以获得最佳数据集适配效果。
+
+Stable Diffusion 等文生图模型能够根据文本提示生成对应图像。
+
+模型训练对硬件要求较高,但启用 `gradient_checkpointing` 和 `mixed_precision` 后,可在单块24GB显存GPU上完成训练。如需更大批次或更快训练速度,建议使用30GB以上显存的GPU设备。通过启用 [xFormers](../optimization/xformers) 内存高效注意力机制可降低显存占用。JAX/Flax 训练方案也支持TPU/GPU高效训练,但不支持梯度检查点、梯度累积和xFormers。使用Flax训练时建议配备30GB以上显存GPU或TPU v3。
+
+本指南将详解 [train_text_to_image.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) 训练脚本,助您掌握其原理并适配自定义需求。
+
+运行脚本前请确保已从源码安装库:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+然后进入包含训练脚本的示例目录,安装对应依赖:
+
+
+
+```bash
+cd examples/text_to_image
+pip install -r requirements.txt
+```
+
+
+```bash
+cd examples/text_to_image
+pip install -r requirements_flax.txt
+```
+
+
+
+> [!TIP]
+> 🤗 Accelerate 是支持多GPU/TPU训练和混合精度的工具库,能根据硬件环境自动配置训练参数。参阅 🤗 Accelerate [快速入门](https://huggingface.co/docs/accelerate/quicktour) 了解更多。
+
+初始化 🤗 Accelerate 环境:
+
+```bash
+accelerate config
+```
+
+要创建默认配置环境(不进行交互式选择):
+
+```bash
+accelerate config default
+```
+
+若环境不支持交互式shell(如notebook),可使用:
+
+```py
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+最后,如需在自定义数据集上训练,请参阅 [创建训练数据集](create_dataset) 指南了解如何准备适配脚本的数据集。
+
+## 脚本参数
+
+> [!TIP]
+> 以下重点介绍脚本中影响训练效果的关键参数,如需完整参数说明可查阅 [脚本源码](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py)。如有疑问欢迎反馈。
+
+训练脚本提供丰富参数供自定义训练流程,所有参数及说明详见 [`parse_args()`](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L193) 函数。该函数为每个参数提供默认值(如批次大小、学习率等),也可通过命令行参数覆盖。
+
+例如使用fp16混合精度加速训练:
+
+```bash
+accelerate launch train_text_to_image.py \
+ --mixed_precision="fp16"
+```
+
+基础重要参数包括:
+
+- `--pretrained_model_name_or_path`: Hub模型名称或本地预训练模型路径
+- `--dataset_name`: Hub数据集名称或本地训练数据集路径
+- `--image_column`: 数据集中图像列名
+- `--caption_column`: 数据集中文本列名
+- `--output_dir`: 模型保存路径
+- `--push_to_hub`: 是否将训练模型推送至Hub
+- `--checkpointing_steps`: 模型检查点保存步数;训练中断时可添加 `--resume_from_checkpoint` 从该检查点恢复训练
+
+### Min-SNR加权策略
+
+[Min-SNR](https://huggingface.co/papers/2303.09556) 加权策略通过重新平衡损失函数加速模型收敛。训练脚本支持预测 `epsilon`(噪声)或 `v_prediction`,而Min-SNR兼容两种预测类型。该策略仅限PyTorch版本,Flax训练脚本不支持。
+
+添加 `--snr_gamma` 参数并设为推荐值5.0:
+
+```bash
+accelerate launch train_text_to_image.py \
+ --snr_gamma=5.0
+```
+
+可通过此 [Weights and Biases](https://wandb.ai/sayakpaul/text2image-finetune-minsnr) 报告比较不同 `snr_gamma` 值的损失曲面。小数据集上Min-SNR效果可能不如大数据集显著。
+
+## 训练脚本解析
+
+数据集预处理代码和训练循环位于 [`main()`](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L490) 函数,自定义修改需在此处进行。
+
+`train_text_to_image` 脚本首先 [加载调度器](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L543) 和分词器,此处可替换其他调度器:
+
+```py
+noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+)
+```
+
+接着 [加载UNet模型](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L619):
+
+```py
+load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+model.register_to_config(**load_model.config)
+
+model.load_state_dict(load_model.state_dict())
+```
+
+随后对数据集的文本和图像列进行预处理。[`tokenize_captions`](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L724) 函数处理文本分词,[`train_transforms`](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L742) 定义图像增强策略,二者集成于 `preprocess_train`:
+
+```py
+def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["input_ids"] = tokenize_captions(examples)
+ return examples
+```
+
+最后,[训练循环](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L878) 处理剩余流程:图像编码为潜空间、添加噪声、计算文本嵌入条件、更新模型参数、保存并推送模型至Hub。想深入了解训练循环原理,可参阅 [理解管道、模型与调度器](../using-diffusers/write_own_pipeline) 教程,该教程解析了去噪过程的核心逻辑。
+
+## 启动脚本
+
+完成所有配置后,即可启动训练脚本!🚀
+
+
+
+
+以 [火影忍者BLIP标注数据集](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) 为例训练生成火影角色。设置环境变量 `MODEL_NAME` 和 `dataset_name` 指定模型和数据集(Hub或本地路径)。多GPU训练需在 `accelerate launch` 命令中添加 `--multi_gpu` 参数。
+
+> [!TIP]
+> 使用本地数据集时,设置 `TRAIN_DIR` 和 `OUTPUT_DIR` 环境变量为数据集路径和模型保存路径。
+
+```bash
+export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
+export dataset_name="lambdalabs/naruto-blip-captions"
+
+accelerate launch --mixed_precision="fp16" train_text_to_image.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$dataset_name \
+ --use_ema \
+ --resolution=512 --center_crop --random_flip \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --enable_xformers_memory_efficient_attention \
+ --lr_scheduler="constant" --lr_warmup_steps=0 \
+ --output_dir="sd-naruto-model" \
+ --push_to_hub
+```
+
+
+
+
+Flax训练方案在TPU/GPU上效率更高(由 [@duongna211](https://github.com/duongna21) 实现),TPU性能更优但GPU表现同样出色。
+
+设置环境变量 `MODEL_NAME` 和 `dataset_name` 指定模型和数据集(Hub或本地路径)。
+
+> [!TIP]
+> 使用本地数据集时,设置 `TRAIN_DIR` 和 `OUTPUT_DIR` 环境变量为数据集路径和模型保存路径。
+
+```bash
+export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
+export dataset_name="lambdalabs/naruto-blip-captions"
+
+python train_text_to_image_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$dataset_name \
+ --resolution=512 --center_crop --random_flip \
+ --train_batch_size=1 \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --output_dir="sd-naruto-model" \
+ --push_to_hub
+```
+
+
+
+
+训练完成后,即可使用新模型进行推理:
+
+
+
+
+```py
+from diffusers import StableDiffusionPipeline
+import torch
+
+pipeline = StableDiffusionPipeline.from_pretrained("path/to/saved_model", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
+
+image = pipeline(prompt="yoda").images[0]
+image.save("yoda-naruto.png")
+```
+
+
+
+
+```py
+import jax
+import numpy as np
+from flax.jax_utils import replicate
+from flax.training.common_utils import shard
+from diffusers import FlaxStableDiffusionPipeline
+
+pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("path/to/saved_model", dtype=jax.numpy.bfloat16)
+
+prompt = "yoda naruto"
+prng_seed = jax.random.PRNGKey(0)
+num_inference_steps = 50
+
+num_samples = jax.device_count()
+prompt = num_samples * [prompt]
+prompt_ids = pipeline.prepare_inputs(prompt)
+
+# 分片输入和随机数
+params = replicate(params)
+prng_seed = jax.random.split(prng_seed, jax.device_count())
+prompt_ids = shard(prompt_ids)
+
+images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
+images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
+image.save("yoda-naruto.png")
+```
+
+
+
+
+## 后续步骤
+
+恭喜完成文生图模型训练!如需进一步使用模型,以下指南可能有所帮助:
+
+- 了解如何加载 [LoRA权重](../using-diffusers/loading_adapters#LoRA) 进行推理(如果训练时使用了LoRA)
+- 在 [文生图](../using-diffusers/conditional_image_generation) 任务指南中,了解引导尺度等参数或提示词加权等技术如何控制生成效果
\ No newline at end of file
diff --git a/docs/source/zh/training/text_inversion.md b/docs/source/zh/training/text_inversion.md
new file mode 100644
index 000000000000..eda9f911441b
--- /dev/null
+++ b/docs/source/zh/training/text_inversion.md
@@ -0,0 +1,287 @@
+
+
+# 文本反转(Textual Inversion)
+
+[文本反转](https://hf.co/papers/2208.01618)是一种训练技术,仅需少量示例图像即可个性化图像生成模型。该技术通过学习和更新文本嵌入(新嵌入会绑定到提示中必须使用的特殊词汇)来匹配您提供的示例图像。
+
+如果在显存有限的GPU上训练,建议在训练命令中启用`gradient_checkpointing`和`mixed_precision`参数。您还可以通过[xFormers](../optimization/xformers)使用内存高效注意力机制来减少内存占用。JAX/Flax训练也支持在TPU和GPU上进行高效训练,但不支持梯度检查点或xFormers。在配置与PyTorch相同的情况下,Flax训练脚本的速度至少应快70%!
+
+本指南将探索[textual_inversion.py](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py)脚本,帮助您更熟悉其工作原理,并了解如何根据自身需求进行调整。
+
+运行脚本前,请确保从源码安装库:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+进入包含训练脚本的示例目录,并安装所需依赖:
+
+
+
+
+```bash
+cd examples/textual_inversion
+pip install -r requirements.txt
+```
+
+
+
+
+```bash
+cd examples/textual_inversion
+pip install -r requirements_flax.txt
+```
+
+
+
+
+> [!TIP]
+> 🤗 Accelerate 是一个帮助您在多GPU/TPU或混合精度环境下训练的工具库。它会根据硬件和环境自动配置训练设置。查看🤗 Accelerate [快速入门](https://huggingface.co/docs/accelerate/quicktour)了解更多。
+
+初始化🤗 Accelerate环境:
+
+```bash
+accelerate config
+```
+
+要设置默认的🤗 Accelerate环境(不选择任何配置):
+
+```bash
+accelerate config default
+```
+
+如果您的环境不支持交互式shell(如notebook),可以使用:
+
+```py
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+最后,如果想在自定义数据集上训练模型,请参阅[创建训练数据集](create_dataset)指南,了解如何创建适用于训练脚本的数据集。
+
+> [!TIP]
+> 以下部分重点介绍训练脚本中需要理解的关键修改点,但未涵盖脚本所有细节。如需深入了解,可随时查阅[脚本源码](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py),如有疑问欢迎反馈。
+
+## 脚本参数
+
+训练脚本包含众多参数,便于您定制训练过程。所有参数及其说明都列在[`parse_args()`](https://github.com/huggingface/diffusers/blob/839c2a5ece0af4e75530cb520d77bc7ed8acf474/examples/textual_inversion/textual_inversion.py#L176)函数中。Diffusers为每个参数提供了默认值(如训练批次大小和学习率),但您可以通过训练命令自由调整这些值。
+
+例如,将梯度累积步数增加到默认值1以上:
+
+```bash
+accelerate launch textual_inversion.py \
+ --gradient_accumulation_steps=4
+```
+
+其他需要指定的基础重要参数包括:
+
+- `--pretrained_model_name_or_path`:Hub上的模型名称或本地预训练模型路径
+- `--train_data_dir`:包含训练数据集(示例图像)的文件夹路径
+- `--output_dir`:训练模型保存位置
+- `--push_to_hub`:是否将训练好的模型推送至Hub
+- `--checkpointing_steps`:训练过程中保存检查点的频率;若训练意外中断,可通过在命令中添加`--resume_from_checkpoint`从该检查点恢复训练
+- `--num_vectors`:学习嵌入的向量数量;增加此参数可提升模型效果,但会提高训练成本
+- `--placeholder_token`:绑定学习嵌入的特殊词汇(推理时需在提示中使用该词)
+- `--initializer_token`:大致描述训练目标的单字词汇(如物体或风格)
+- `--learnable_property`:训练目标是学习新"风格"(如梵高画风)还是"物体"(如您的宠物狗)
+
+## 训练脚本
+
+与其他训练脚本不同,textual_inversion.py包含自定义数据集类[`TextualInversionDataset`](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L487),用于创建数据集。您可以自定义图像尺寸、占位符词汇、插值方法、是否裁剪图像等。如需修改数据集创建方式,可调整`TextualInversionDataset`类。
+
+接下来,在[`main()`](https://github.com/huggingface/diffusers/blob/839c2a5ece0af4e75530cb520d77bc7ed8acf474/examples/textual_inversion/textual_inversion.py#L573)函数中可找到数据集预处理代码和训练循环。
+
+脚本首先加载[tokenizer](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L616)、[scheduler和模型](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L622):
+
+```py
+# 加载tokenizer
+if args.tokenizer_name:
+ tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+elif args.pretrained_model_name_or_path:
+ tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+
+# 加载scheduler和模型
+noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+)
+vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+)
+```
+
+随后将特殊[占位符词汇](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L632)加入tokenizer,并调整嵌入层以适配新词汇。
+
+接着,脚本通过`TextualInversionDataset`[创建数据集](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L716):
+
+```py
+train_dataset = TextualInversionDataset(
+ data_root=args.train_data_dir,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ placeholder_token=(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids))),
+ repeats=args.repeats,
+ learnable_property=args.learnable_property,
+ center_crop=args.center_crop,
+ set="train",
+)
+train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+)
+```
+
+最后,[训练循环](https://github.com/huggingface/diffusers/blob/b81c69e489aad3a0ba73798c459a33990dc4379c/examples/textual_inversion/textual_inversion.py#L784)处理从预测噪声残差到更新特殊占位符词汇嵌入权重的所有流程。
+
+如需深入了解训练循环工作原理,请参阅[理解管道、模型与调度器](../using-diffusers/write_own_pipeline)教程,该教程解析了去噪过程的基本模式。
+
+## 启动脚本
+
+完成所有修改或确认默认配置后,即可启动训练脚本!🚀
+
+本指南将下载[猫玩具](https://huggingface.co/datasets/diffusers/cat_toy_example)的示例图像并存储在目录中。当然,您也可以创建和使用自己的数据集(参见[创建训练数据集](create_dataset)指南)。
+
+```py
+from huggingface_hub import snapshot_download
+
+local_dir = "./cat"
+snapshot_download(
+ "diffusers/cat_toy_example", local_dir=local_dir, repo_type="dataset", ignore_patterns=".gitattributes"
+)
+```
+
+设置环境变量`MODEL_NAME`为Hub上的模型ID或本地模型路径,`DATA_DIR`为刚下载的猫图像路径。脚本会将以下文件保存至您的仓库:
+
+- `learned_embeds.bin`:与示例图像对应的学习嵌入向量
+- `token_identifier.txt`:特殊占位符词汇
+- `type_of_concept.txt`:训练概念类型("object"或"style")
+
+> [!WARNING]
+> 在单块V100 GPU上完整训练约需1小时。
+
+启动脚本前还有最后一步。如果想实时观察训练过程,可以定期保存生成图像。在训练命令中添加以下参数:
+
+```bash
+--validation_prompt="A train"
+--num_validation_images=4
+--validation_steps=100
+```
+
+
+
+
+```bash
+export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
+export DATA_DIR="./cat"
+
+accelerate launch textual_inversion.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+ --placeholder_token="" \
+ --initializer_token="toy" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=4 \
+ --max_train_steps=3000 \
+ --learning_rate=5.0e-04 \
+ --scale_lr \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --output_dir="textual_inversion_cat" \
+ --push_to_hub
+```
+
+
+
+
+```bash
+export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
+export DATA_DIR="./cat"
+
+python textual_inversion_flax.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --train_data_dir=$DATA_DIR \
+ --learnable_property="object" \
+ --placeholder_token="" \
+ --initializer_token="toy" \
+ --resolution=512 \
+ --train_batch_size=1 \
+ --max_train_steps=3000 \
+ --learning_rate=5.0e-04 \
+ --scale_lr \
+ --output_dir="textual_inversion_cat" \
+ --push_to_hub
+```
+
+
+
+
+训练完成后,可以像这样使用新模型进行推理:
+
+
+
+
+```py
+from diffusers import StableDiffusionPipeline
+import torch
+
+pipeline = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
+pipeline.load_textual_inversion("sd-concepts-library/cat-toy")
+image = pipeline("A train", num_inference_steps=50).images[0]
+image.save("cat-train.png")
+```
+
+
+
+
+Flax不支持[`~loaders.TextualInversionLoaderMixin.load_textual_inversion`]方法,但textual_inversion_flax.py脚本会在训练后[保存](https://github.com/huggingface/diffusers/blob/c0f058265161178f2a88849e92b37ffdc81f1dcc/examples/textual_inversion/textual_inversion_flax.py#L636C2-L636C2)学习到的嵌入作为模型的一部分。这意味着您可以像使用其他Flax模型一样进行推理:
+
+```py
+import jax
+import numpy as np
+from flax.jax_utils import replicate
+from flax.training.common_utils import shard
+from diffusers import FlaxStableDiffusionPipeline
+
+model_path = "path-to-your-trained-model"
+pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)
+
+prompt = "A train"
+prng_seed = jax.random.PRNGKey(0)
+num_inference_steps = 50
+
+num_samples = jax.device_count()
+prompt = num_samples * [prompt]
+prompt_ids = pipeline.prepare_inputs(prompt)
+
+# 分片输入和随机数生成器
+params = replicate(params)
+prng_seed = jax.random.split(prng_seed, jax.device_count())
+prompt_ids = shard(prompt_ids)
+
+images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
+images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
+image.save("cat-train.png")
+```
+
+
+
+
+## 后续步骤
+
+恭喜您成功训练了自己的文本反转模型!🎉 如需了解更多使用技巧,以下指南可能会有所帮助:
+
+- 学习如何[加载文本反转嵌入](../using-diffusers/loading_adapters),并将其用作负面嵌入
+- 学习如何将[文本反转](textual_inversion_inference)应用于Stable Diffusion 1/2和Stable Diffusion XL的推理
diff --git a/docs/source/zh/training/wuerstchen.md b/docs/source/zh/training/wuerstchen.md
new file mode 100644
index 000000000000..c80cc944a3d8
--- /dev/null
+++ b/docs/source/zh/training/wuerstchen.md
@@ -0,0 +1,182 @@
+
+
+# Wuerstchen
+
+[Wuerstchen](https://hf.co/papers/2306.00637) 模型通过将潜在空间压缩 42 倍,在不影响图像质量的情况下大幅降低计算成本并加速推理。在训练过程中,Wuerstchen 使用两个模型(VQGAN + 自动编码器)来压缩潜在表示,然后第三个模型(文本条件潜在扩散模型)在这个高度压缩的空间上进行条件化以生成图像。
+
+为了将先验模型放入 GPU 内存并加速训练,尝试分别启用 `gradient_accumulation_steps`、`gradient_checkpointing` 和 `mixed_precision`。
+
+本指南探讨 [train_text_to_image_prior.py](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/train_text_to_image_prior.py) 脚本,帮助您更熟悉它,以及如何根据您的用例进行适配。
+
+在运行脚本之前,请确保从源代码安装库:
+
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+然后导航到包含训练脚本的示例文件夹,并安装脚本所需的依赖项:
+
+```bash
+cd examples/wuerstchen/text_to_image
+pip install -r requirements.txt
+```
+
+> [!TIP]
+> 🤗 Accelerate 是一个帮助您在多个 GPU/TPU 上或使用混合精度进行训练的库。它会根据您的硬件和环境自动配置训练设置。查看 🤗 Accelerate [快速入门](https://huggingface.co/docs/accelerate/quicktour) 以了解更多信息。
+
+初始化一个 🤗 Accelerate 环境:
+
+```bash
+accelerate config
+```
+
+要设置一个默认的 🤗 Accelerate 环境而不选择任何配置:
+
+```bash
+accelerate config default
+```
+
+或者,如果您的环境不支持交互式 shell,例如笔记本,您可以使用:
+
+```py
+from accelerate.utils import write_basic_config
+
+write_basic_config()
+```
+
+最后,如果您想在自己的数据集上训练模型,请查看 [创建训练数据集](create_dataset) 指南,了解如何创建与训练脚本兼容的数据集。
+
+> [!TIP]
+> 以下部分重点介绍了训练脚本中对于理解如何修改它很重要的部分,但并未涵盖 [脚本](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/train_text_to_image_prior.py) 的详细信息。如果您有兴趣了解更多,请随时阅读脚本,并告诉我们您是否有任何问题或疑虑。
+
+## 脚本参数
+
+训练脚本提供了许多参数来帮助您自定义训练运行。所有参数及其描述都可以在 [`parse_args()`](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L192) 函数中找到。它为每个参数提供了默认值,例如训练批次大小和学习率,但如果您愿意,也可以在训练命令中设置自己的值。
+
+例如,要使用 fp16 格式的混合精度加速训练,请在训练命令中添加 `--mixed_precision` 参数:
+
+```bash
+accelerate launch train_text_to_image_prior.py \
+ --mixed_precision="fp16"
+```
+
+大多数参数与 [文本到图像](text2image#script-parameters) 训练指南中的参数相同,因此让我们直接深入 Wuerstchen 训练脚本!
+
+## 训练脚本
+
+训练脚本也与 [文本到图像](text2image#training-script) 训练指南类似,但已修改以支持 Wuerstchen。本指南重点介绍 Wuerstchen 训练脚本中独特的代码。
+
+[`main()`](https://github.com/huggingface/diffusers/blob/6e68c71503682c8693cb5b06a4da4911dfd655ee/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L441) 函数首先初始化图像编码器 - 一个 [EfficientNet](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py) - 以及通常的调度器和分词器。
+
+```py
+with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ pretrained_checkpoint_file = hf_hub_download("dome272/wuerstchen", filename="model_v2_stage_b.pt")
+ state_dict = torch.load(pretrained_checkpoint_file, map_location="cpu")
+ image_encoder = EfficientNetEncoder()
+ image_encoder.load_state_dict(state_dict["effnet_state_dict"])
+ image_encoder.eval()
+```
+
+您还将加载 [`WuerstchenPrior`] 模型以进行优化。
+
+```py
+prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
+
+optimizer = optimizer_cls(
+ prior.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+)
+```
+
+接下来,您将对图像应用一些 [transforms](https://github.com/huggingface/diffusers/blob/65ef7a0c5c594b4f84092e328fbdd73183613b30/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L656) 并对标题进行 [tokenize](https://github.com/huggingface/diffusers/blob/65ef7a0c5c594b4f84092e328fbdd73183613b30/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L637):
+
+```py
+def preprocess_train(examples):
+ images = [image.conver
+t("RGB") for image in examples[image_column]]
+ examples["effnet_pixel_values"] = [effnet_transforms(image) for image in images]
+ examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
+ return examples
+```
+
+最后,[训练循环](https://github.com/huggingface/diffusers/blob/65ef7a0c5c594b4f84092e328fbdd73183613b30/examples/wuerstchen/text_to_image/train_text_to_image_prior.py#L656)处理使用`EfficientNetEncoder`将图像压缩到潜在空间,向潜在表示添加噪声,并使用[`WuerstchenPrior`]模型预测噪声残差。
+
+```py
+pred_noise = prior(noisy_latents, timesteps, prompt_embeds)
+```
+
+如果您想了解更多关于训练循环的工作原理,请查看[理解管道、模型和调度器](../using-diffusers/write_own_pipeline)教程,该教程分解了去噪过程的基本模式。
+
+## 启动脚本
+
+一旦您完成了所有更改或对默认配置满意,就可以启动训练脚本了!🚀
+
+设置`DATASET_NAME`环境变量为Hub中的数据集名称。本指南使用[Naruto BLIP captions](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions)数据集,但您也可以创建和训练自己的数据集(参见[创建用于训练的数据集](create_dataset)指南)。
+
+> [!TIP]
+> 要使用Weights & Biases监控训练进度,请在训练命令中添加`--report_to=wandb`参数。您还需要在训练命令中添加`--validation_prompt`以跟踪结果。这对于调试模型和查看中间结果非常有用。
+
+```bash
+export DATASET_NAME="lambdalabs/naruto-blip-captions"
+
+accelerate launch train_text_to_image_prior.py \
+ --mixed_precision="fp16" \
+ --dataset_name=$DATASET_NAME \
+ --resolution=768 \
+ --train_batch_size=4 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --dataloader_num_workers=4 \
+ --max_train_steps=15000 \
+ --learning_rate=1e-05 \
+ --max_grad_norm=1 \
+ --checkpoints_total_limit=3 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --validation_prompts="A robot naruto, 4k photo" \
+ --report_to="wandb" \
+ --push_to_hub \
+ --output_dir="wuerstchen-prior-naruto-model"
+```
+
+训练完成后,您可以使用新训练的模型进行推理!
+
+```py
+import torch
+from diffusers import AutoPipelineForText2Image
+from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
+
+pipeline = AutoPipelineForText2Image.from_pretrained("path/to/saved/model", torch_dtype=torch.float16).to("cuda")
+
+caption = "A cute bird naruto holding a shield"
+images = pipeline(
+ caption,
+ width=1024,
+ height=1536,
+ prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
+ prior_guidance_scale=4.0,
+ num_images_per_prompt=2,
+).images
+```
+
+## 下一步
+
+恭喜您训练了一个Wuerstchen模型!要了解更多关于如何使用您的新模型的信息,请参
+以下内容可能有所帮助:
+
+- 查看 [Wuerstchen](../api/pipelines/wuerstchen#text-to-image-generation) API 文档,了解更多关于如何使用该管道进行文本到图像生成及其限制的信息。
\ No newline at end of file
diff --git a/docs/source/zh/consisid.md b/docs/source/zh/using-diffusers/consisid.md
similarity index 99%
rename from docs/source/zh/consisid.md
rename to docs/source/zh/using-diffusers/consisid.md
index 2f404499fc69..018c5e706fb7 100644
--- a/docs/source/zh/consisid.md
+++ b/docs/source/zh/using-diffusers/consisid.md
@@ -1,4 +1,4 @@
-
+
+# 加载调度器与模型
+
+[[open-in-colab]]
+
+Diffusion管道是由可互换的调度器(schedulers)和模型(models)组成的集合,可通过混合搭配来定制特定用例的流程。调度器封装了整个去噪过程(如去噪步数和寻找去噪样本的算法),其本身不包含可训练参数,因此内存占用极低。模型则主要负责从含噪输入到较纯净样本的前向传播过程。
+
+本指南将展示如何加载调度器和模型来自定义流程。我们将全程使用[stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5)检查点,首先加载基础管道:
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
+).to("cuda")
+```
+
+通过`pipeline.scheduler`属性可查看当前管道使用的调度器:
+
+```python
+pipeline.scheduler
+PNDMScheduler {
+ "_class_name": "PNDMScheduler",
+ "_diffusers_version": "0.21.4",
+ "beta_end": 0.012,
+ "beta_schedule": "scaled_linear",
+ "beta_start": 0.00085,
+ "clip_sample": false,
+ "num_train_timesteps": 1000,
+ "set_alpha_to_one": false,
+ "skip_prk_steps": true,
+ "steps_offset": 1,
+ "timestep_spacing": "leading",
+ "trained_betas": null
+}
+```
+
+## 加载调度器
+
+调度器通过配置文件定义,同一配置文件可被多种调度器共享。使用[`SchedulerMixin.from_pretrained`]方法加载时,需指定`subfolder`参数以定位配置文件在仓库中的正确子目录。
+
+例如加载[`DDIMScheduler`]:
+
+```python
+from diffusers import DDIMScheduler, DiffusionPipeline
+
+ddim = DDIMScheduler.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="scheduler")
+```
+
+然后将新调度器传入管道:
+
+```python
+pipeline = DiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", scheduler=ddim, torch_dtype=torch.float16, use_safetensors=True
+).to("cuda")
+```
+
+## 调度器对比
+
+不同调度器各有优劣,难以定量评估哪个最适合您的流程。通常需要在去噪速度与质量之间权衡。我们建议尝试多种调度器以找到最佳方案。通过`pipeline.scheduler.compatibles`属性可查看兼容当前管道的所有调度器。
+
+下面我们使用相同提示词和随机种子,对比[`LMSDiscreteScheduler`]、[`EulerDiscreteScheduler`]、[`EulerAncestralDiscreteScheduler`]和[`DPMSolverMultistepScheduler`]的表现:
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
+).to("cuda")
+
+prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
+generator = torch.Generator(device="cuda").manual_seed(8)
+```
+
+使用[`~ConfigMixin.from_config`]方法加载不同调度器的配置来切换管道调度器:
+
+
+
+
+[`LMSDiscreteScheduler`]通常能生成比默认调度器更高质量的图像。
+
+```python
+from diffusers import LMSDiscreteScheduler
+
+pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+
+
+
+[`EulerDiscreteScheduler`]仅需30步即可生成高质量图像。
+
+```python
+from diffusers import EulerDiscreteScheduler
+
+pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+
+
+
+[`EulerAncestralDiscreteScheduler`]同样可在30步内生成高质量图像。
+
+```python
+from diffusers import EulerAncestralDiscreteScheduler
+
+pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+
+
+
+[`DPMSolverMultistepScheduler`]在速度与质量间取得平衡,仅需20步即可生成优质图像。
+
+```python
+from diffusers import DPMSolverMultistepScheduler
+
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+
+
+
+
+
+
+
LMSDiscreteScheduler
+
+
+
+
EulerDiscreteScheduler
+
+
+
+
+
+
EulerAncestralDiscreteScheduler
+
+
+
+
DPMSolverMultistepScheduler
+
+
+
+多数生成图像质量相近,实际选择需根据具体场景测试多种调度器进行比较。
+
+### Flax调度器
+
+对比Flax调度器时,需额外将调度器状态加载到模型参数中。例如将[`FlaxStableDiffusionPipeline`]的默认调度器切换为超高效的[`FlaxDPMSolverMultistepScheduler`]:
+
+> [!警告]
+> [`FlaxLMSDiscreteScheduler`]和[`FlaxDDPMScheduler`]目前暂不兼容[`FlaxStableDiffusionPipeline`]。
+
+```python
+import jax
+import numpy as np
+from flax.jax_utils import replicate
+from flax.training.common_utils import shard
+from diffusers import FlaxStableDiffusionPipeline, FlaxDPMSolverMultistepScheduler
+
+scheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ subfolder="scheduler"
+)
+pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ scheduler=scheduler,
+ variant="bf16",
+ dtype=jax.numpy.bfloat16,
+)
+params["scheduler"] = scheduler_state
+```
+
+利用Flax对TPU的兼容性实现并行图像生成。需为每个设备复制模型参数,并分配输入数据:
+
+```python
+# 每个并行设备生成1张图像(TPUv2-8/TPUv3-8支持8设备并行)
+prompt = "一张宇航员在火星上骑马的高清照片,高分辨率,高画质。"
+num_samples = jax.device_count()
+prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)
+
+prng_seed = jax.random.PRNGKey(0)
+num_inference_steps = 25
+
+# 分配输入和随机种子
+params = replicate(params)
+prng_seed = jax.random.split(prng_seed, jax.device_count())
+prompt_ids = shard(prompt_ids)
+
+images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
+images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
+```
+
+## 模型加载
+
+通过[`ModelMixin.from_pretrained`]方法加载模型,该方法会下载并缓存模型权重和配置的最新版本。若本地缓存已存在最新文件,则直接复用缓存而非重复下载。
+
+通过`subfolder`参数可从子目录加载模型。例如[stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5)的模型权重存储在[unet](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main/unet)子目录中:
+
+```python
+from diffusers import UNet2DConditionModel
+
+unet = UNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", use_safetensors=True)
+```
+
+也可直接从[仓库](https://huggingface.co/google/ddpm-cifar10-32/tree/main)加载:
+
+```python
+from diffusers import UNet2DModel
+
+unet = UNet2DModel.from_pretrained("google/ddpm-cifar10-32", use_safetensors=True)
+```
+
+加载和保存模型变体时,需在[`ModelMixin.from_pretrained`]和[`ModelMixin.save_pretrained`]中指定`variant`参数:
+
+```python
+from diffusers import UNet2DConditionModel
+
+unet = UNet2DConditionModel.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", variant="non_ema", use_safetensors=True
+)
+unet.save_pretrained("./local-unet", variant="non_ema")
+```
+
+使用[`~ModelMixin.from_pretrained`]的`torch_dtype`参数指定模型加载精度:
+
+```python
+from diffusers import AutoModel
+
+unet = AutoModel.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
+)
+```
+
+也可使用[torch.Tensor.to](https://docs.pytorch.org/docs/stable/generated/torch.Tensor.to.html)方法即时转换精度,但会转换所有权重(不同于`torch_dtype`参数会保留`_keep_in_fp32_modules`中的层)。这对某些必须保持fp32精度的层尤为重要(参见[示例](https://github.com/huggingface/diffusers/blob/f864a9a352fa4a220d860bfdd1782e3e5af96382/src/diffusers/models/transformers/transformer_wan.py#L374))。
diff --git a/examples/README.md b/examples/README.md
index 7cdf25999ac3..5c9a34582e87 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -1,5 +1,5 @@
22.9.0-py37h89c1867_1\n",
+ "\n",
+ "\n",
+ "\n",
+ "Downloading and Extracting Packages\n",
+ "conda-22.9.0 | 960 KB | : 100% 1.0/1 [00:00<00:00, 4.15it/s]\n",
+ "Preparing transaction: / \b\bdone\n",
+ "Verifying transaction: \\ \b\bdone\n",
+ "Executing transaction: / \b\bdone\n",
+ "Retrieving notices: ...working... done\n"
+ ]
+ }
+ ],
+ "source": [
+ "!conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia\n",
+ "# !conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QDS6FPZ0Tu5b"
+ },
+ "source": [
+ "Need to remove a pathspec for colab that specifies the incorrect cuda version."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "dq1lxR10TtrR",
+ "outputId": "ed9c5a71-b449-418f-abb7-072b74e7f6c8"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "rm: cannot remove '/usr/local/conda-meta/pinned': No such file or directory\n"
+ ]
+ }
+ ],
+ "source": [
+ "!rm /usr/local/conda-meta/pinned"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Z1L3DdZOJB30"
+ },
+ "source": [
+ "Install torch geometric (used in the model later)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "D5ukfCOWfjzK",
+ "outputId": "8437485a-5aa6-4d53-8f7f-23517ac1ace6"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n",
+ "Solving environment: | \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n",
+ "\n",
+ "## Package Plan ##\n",
+ "\n",
+ " environment location: /usr/local\n",
+ "\n",
+ " added / updated specs:\n",
+ " - pytorch-geometric=1.7.2\n",
+ "\n",
+ "\n",
+ "The following packages will be downloaded:\n",
+ "\n",
+ " package | build\n",
+ " ---------------------------|-----------------\n",
+ " decorator-4.4.2 | py_0 11 KB conda-forge\n",
+ " googledrivedownloader-0.4 | pyhd3deb0d_1 7 KB conda-forge\n",
+ " jinja2-3.1.2 | pyhd8ed1ab_1 99 KB conda-forge\n",
+ " joblib-1.2.0 | pyhd8ed1ab_0 205 KB conda-forge\n",
+ " markupsafe-2.1.1 | py37h540881e_1 22 KB conda-forge\n",
+ " networkx-2.5.1 | pyhd8ed1ab_0 1.2 MB conda-forge\n",
+ " pandas-1.2.3 | py37hdc94413_0 11.8 MB conda-forge\n",
+ " pyparsing-3.0.9 | pyhd8ed1ab_0 79 KB conda-forge\n",
+ " python-dateutil-2.8.2 | pyhd8ed1ab_0 240 KB conda-forge\n",
+ " python-louvain-0.15 | pyhd8ed1ab_1 13 KB conda-forge\n",
+ " pytorch-cluster-1.5.9 |py37_torch_1.8.0_cu111 1.2 MB rusty1s\n",
+ " pytorch-geometric-1.7.2 |py37_torch_1.8.0_cu111 445 KB rusty1s\n",
+ " pytorch-scatter-2.0.8 |py37_torch_1.8.0_cu111 6.1 MB rusty1s\n",
+ " pytorch-sparse-0.6.12 |py37_torch_1.8.0_cu111 2.9 MB rusty1s\n",
+ " pytorch-spline-conv-1.2.1 |py37_torch_1.8.0_cu111 736 KB rusty1s\n",
+ " pytz-2022.4 | pyhd8ed1ab_0 232 KB conda-forge\n",
+ " scikit-learn-1.0.2 | py37hf9e9bfc_0 7.8 MB conda-forge\n",
+ " scipy-1.7.3 | py37hf2a6cf1_0 21.8 MB conda-forge\n",
+ " setuptools-59.8.0 | py37h89c1867_1 1.0 MB conda-forge\n",
+ " threadpoolctl-3.1.0 | pyh8a188c0_0 18 KB conda-forge\n",
+ " ------------------------------------------------------------\n",
+ " Total: 55.9 MB\n",
+ "\n",
+ "The following NEW packages will be INSTALLED:\n",
+ "\n",
+ " decorator conda-forge/noarch::decorator-4.4.2-py_0 None\n",
+ " googledrivedownlo~ conda-forge/noarch::googledrivedownloader-0.4-pyhd3deb0d_1 None\n",
+ " jinja2 conda-forge/noarch::jinja2-3.1.2-pyhd8ed1ab_1 None\n",
+ " joblib conda-forge/noarch::joblib-1.2.0-pyhd8ed1ab_0 None\n",
+ " markupsafe conda-forge/linux-64::markupsafe-2.1.1-py37h540881e_1 None\n",
+ " networkx conda-forge/noarch::networkx-2.5.1-pyhd8ed1ab_0 None\n",
+ " pandas conda-forge/linux-64::pandas-1.2.3-py37hdc94413_0 None\n",
+ " pyparsing conda-forge/noarch::pyparsing-3.0.9-pyhd8ed1ab_0 None\n",
+ " python-dateutil conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 None\n",
+ " python-louvain conda-forge/noarch::python-louvain-0.15-pyhd8ed1ab_1 None\n",
+ " pytorch-cluster rusty1s/linux-64::pytorch-cluster-1.5.9-py37_torch_1.8.0_cu111 None\n",
+ " pytorch-geometric rusty1s/linux-64::pytorch-geometric-1.7.2-py37_torch_1.8.0_cu111 None\n",
+ " pytorch-scatter rusty1s/linux-64::pytorch-scatter-2.0.8-py37_torch_1.8.0_cu111 None\n",
+ " pytorch-sparse rusty1s/linux-64::pytorch-sparse-0.6.12-py37_torch_1.8.0_cu111 None\n",
+ " pytorch-spline-co~ rusty1s/linux-64::pytorch-spline-conv-1.2.1-py37_torch_1.8.0_cu111 None\n",
+ " pytz conda-forge/noarch::pytz-2022.4-pyhd8ed1ab_0 None\n",
+ " scikit-learn conda-forge/linux-64::scikit-learn-1.0.2-py37hf9e9bfc_0 None\n",
+ " scipy conda-forge/linux-64::scipy-1.7.3-py37hf2a6cf1_0 None\n",
+ " threadpoolctl conda-forge/noarch::threadpoolctl-3.1.0-pyh8a188c0_0 None\n",
+ "\n",
+ "The following packages will be DOWNGRADED:\n",
+ "\n",
+ " setuptools 65.3.0-py37h89c1867_0 --> 59.8.0-py37h89c1867_1 None\n",
+ "\n",
+ "\n",
+ "\n",
+ "Downloading and Extracting Packages\n",
+ "scikit-learn-1.0.2 | 7.8 MB | : 100% 1.0/1 [00:01<00:00, 1.37s/it] \n",
+ "pytorch-scatter-2.0. | 6.1 MB | : 100% 1.0/1 [00:06<00:00, 6.18s/it]\n",
+ "pytorch-geometric-1. | 445 KB | : 100% 1.0/1 [00:02<00:00, 2.53s/it]\n",
+ "scipy-1.7.3 | 21.8 MB | : 100% 1.0/1 [00:03<00:00, 3.06s/it]\n",
+ "python-dateutil-2.8. | 240 KB | : 100% 1.0/1 [00:00<00:00, 21.48it/s]\n",
+ "pytorch-spline-conv- | 736 KB | : 100% 1.0/1 [00:01<00:00, 1.00s/it]\n",
+ "pytorch-sparse-0.6.1 | 2.9 MB | : 100% 1.0/1 [00:07<00:00, 7.51s/it]\n",
+ "pyparsing-3.0.9 | 79 KB | : 100% 1.0/1 [00:00<00:00, 26.32it/s]\n",
+ "pytorch-cluster-1.5. | 1.2 MB | : 100% 1.0/1 [00:02<00:00, 2.78s/it]\n",
+ "jinja2-3.1.2 | 99 KB | : 100% 1.0/1 [00:00<00:00, 20.28it/s]\n",
+ "decorator-4.4.2 | 11 KB | : 100% 1.0/1 [00:00<00:00, 21.57it/s]\n",
+ "joblib-1.2.0 | 205 KB | : 100% 1.0/1 [00:00<00:00, 15.04it/s]\n",
+ "pytz-2022.4 | 232 KB | : 100% 1.0/1 [00:00<00:00, 10.21it/s]\n",
+ "python-louvain-0.15 | 13 KB | : 100% 1.0/1 [00:00<00:00, 3.34it/s]\n",
+ "googledrivedownloade | 7 KB | : 100% 1.0/1 [00:00<00:00, 3.33it/s]\n",
+ "threadpoolctl-3.1.0 | 18 KB | : 100% 1.0/1 [00:00<00:00, 29.40it/s]\n",
+ "markupsafe-2.1.1 | 22 KB | : 100% 1.0/1 [00:00<00:00, 28.62it/s]\n",
+ "pandas-1.2.3 | 11.8 MB | : 100% 1.0/1 [00:02<00:00, 2.08s/it] \n",
+ "networkx-2.5.1 | 1.2 MB | : 100% 1.0/1 [00:01<00:00, 1.39s/it]\n",
+ "setuptools-59.8.0 | 1.0 MB | : 100% 1.0/1 [00:00<00:00, 4.25it/s]\n",
+ "Preparing transaction: / \b\b- \b\b\\ \b\bdone\n",
+ "Verifying transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n",
+ "Executing transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n",
+ "Retrieving notices: ...working... done\n"
+ ]
+ }
+ ],
+ "source": [
+ "!conda install -c rusty1s pytorch-geometric=1.7.2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ppxv6Mdkalbc"
+ },
+ "source": [
+ "### Install Diffusers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "mgQA_XN-XGY2",
+ "outputId": "85392615-b6a4-4052-9d2a-79604be62c94"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/content\n",
+ "Cloning into 'diffusers'...\n",
+ "remote: Enumerating objects: 9298, done.\u001b[K\n",
+ "remote: Counting objects: 100% (40/40), done.\u001b[K\n",
+ "remote: Compressing objects: 100% (23/23), done.\u001b[K\n",
+ "remote: Total 9298 (delta 17), reused 23 (delta 11), pack-reused 9258\u001b[K\n",
+ "Receiving objects: 100% (9298/9298), 7.38 MiB | 5.28 MiB/s, done.\n",
+ "Resolving deltas: 100% (6168/6168), done.\n",
+ " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m757.0/757.0 kB\u001b[0m \u001b[31m52.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m163.5/163.5 kB\u001b[0m \u001b[31m21.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.8/40.8 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m596.3/596.3 kB\u001b[0m \u001b[31m51.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h Building wheel for diffusers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m432.7/432.7 kB\u001b[0m \u001b[31m36.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.3/5.3 MB\u001b[0m \u001b[31m90.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m35.3/35.3 MB\u001b[0m \u001b[31m39.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.1/115.1 kB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m948.0/948.0 kB\u001b[0m \u001b[31m63.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.2/212.2 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m95.8/95.8 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.8/140.8 kB\u001b[0m \u001b[31m18.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m104.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m148.0/148.0 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m231.3/231.3 kB\u001b[0m \u001b[31m30.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m94.8/94.8 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0m"
+ ]
+ }
+ ],
+ "source": [
+ "%cd /content\n",
+ "\n",
+ "# install latest HF diffusers (will update to the release once added)\n",
+ "!git clone https://github.com/huggingface/diffusers.git\n",
+ "!pip install -q /content/diffusers\n",
+ "\n",
+ "# dependencies for diffusers\n",
+ "!pip install -q datasets transformers"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "LZO6AJKuJKO8"
+ },
+ "source": [
+ "Check that torch is installed correctly and utilizing the GPU in the colab"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 53
},
+ "id": "gZt7BNi1e1PA",
+ "outputId": "a0e1832c-9c02-49aa-cff8-1339e6cdc889"
+ },
+ "outputs": [
{
- "cell_type": "markdown",
- "metadata": {
- "id": "JzDHaPU7I9Sn"
- },
- "source": [
- "Install pytorch requirements (this takes a few minutes, go grab yourself a coffee 🤗)"
- ]
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "True\n"
+ ]
},
{
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "JMxRjHhL7w8V",
- "outputId": "6ed511b3-9262-49e8-b340-08e76b05ebd8"
+ "data": {
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "string"
},
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\bdone\n",
- "Solving environment: \\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n",
- "\n",
- "## Package Plan ##\n",
- "\n",
- " environment location: /usr/local\n",
- "\n",
- " added / updated specs:\n",
- " - cudatoolkit=11.1\n",
- " - pytorch\n",
- " - torchaudio\n",
- " - torchvision\n",
- "\n",
- "\n",
- "The following packages will be downloaded:\n",
- "\n",
- " package | build\n",
- " ---------------------------|-----------------\n",
- " conda-22.9.0 | py37h89c1867_1 960 KB conda-forge\n",
- " ------------------------------------------------------------\n",
- " Total: 960 KB\n",
- "\n",
- "The following packages will be UPDATED:\n",
- "\n",
- " conda 4.14.0-py37h89c1867_0 --> 22.9.0-py37h89c1867_1\n",
- "\n",
- "\n",
- "\n",
- "Downloading and Extracting Packages\n",
- "conda-22.9.0 | 960 KB | : 100% 1.0/1 [00:00<00:00, 4.15it/s]\n",
- "Preparing transaction: / \b\bdone\n",
- "Verifying transaction: \\ \b\bdone\n",
- "Executing transaction: / \b\bdone\n",
- "Retrieving notices: ...working... done\n"
- ]
- }
- ],
- "source": [
- "!conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia\n",
- "# !conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge"
+ "text/plain": [
+ "'1.8.2'"
]
- },
- {
- "cell_type": "markdown",
- "source": [
- "Need to remove a pathspec for colab that specifies the incorrect cuda version."
- ],
- "metadata": {
- "id": "QDS6FPZ0Tu5b"
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "\n",
+ "\n",
+ "print(torch.cuda.is_available())\n",
+ "torch.__version__"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "KLE7CqlfJNUO"
+ },
+ "source": [
+ "### Install Chemistry-specific Dependencies\n",
+ "\n",
+ "Install RDKit, a tool for working with and visualizing chemsitry in python (you use this to visualize the generate models later)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "0CPv_NvehRz3",
+ "outputId": "6ee0ae4e-4511-4816-de29-22b1c21d49bc"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Collecting rdkit\n",
+ " Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.8/36.8 MB\u001b[0m \u001b[31m34.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.7/site-packages (from rdkit) (9.2.0)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from rdkit) (1.21.6)\n",
+ "Installing collected packages: rdkit\n",
+ "Successfully installed rdkit-2022.3.5\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0m"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install rdkit"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "88GaDbDPxJ5I"
+ },
+ "source": [
+ "### Get viewer from nglview\n",
+ "\n",
+ "The model you will use outputs a position matrix tensor. This pytorch geometric data object will have many features (positions, known features, edge features -- all tensors).\n",
+ "The data we give to the model will also have a rdmol object (which can extract features to geometric if needed).\n",
+ "The rdmol in this object is a source of ground truth for the generated molecules.\n",
+ "\n",
+ "You will use one rendering function from nglviewer later!\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
+ },
+ "id": "jcl8GCS2mz6t",
+ "outputId": "99b5cc40-67bb-4d8e-faa0-47d7cb33e98f"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Collecting nglview\n",
+ " Downloading nglview-3.0.3.tar.gz (5.7 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.7/5.7 MB\u001b[0m \u001b[31m91.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from nglview) (1.21.6)\n",
+ "Collecting jupyterlab-widgets\n",
+ " Downloading jupyterlab_widgets-3.0.3-py3-none-any.whl (384 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m384.1/384.1 kB\u001b[0m \u001b[31m40.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting ipywidgets>=7\n",
+ " Downloading ipywidgets-8.0.2-py3-none-any.whl (134 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.4/134.4 kB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting widgetsnbextension~=4.0\n",
+ " Downloading widgetsnbextension-4.0.3-py3-none-any.whl (2.0 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m84.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting ipython>=6.1.0\n",
+ " Downloading ipython-7.34.0-py3-none-any.whl (793 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m793.8/793.8 kB\u001b[0m \u001b[31m60.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting ipykernel>=4.5.1\n",
+ " Downloading ipykernel-6.16.0-py3-none-any.whl (138 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m138.4/138.4 kB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting traitlets>=4.3.1\n",
+ " Downloading traitlets-5.4.0-py3-none-any.whl (107 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m107.1/107.1 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/site-packages (from ipykernel>=4.5.1->ipywidgets>=7->nglview) (21.3)\n",
+ "Collecting pyzmq>=17\n",
+ " Downloading pyzmq-24.0.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m68.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting matplotlib-inline>=0.1\n",
+ " Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)\n",
+ "Collecting tornado>=6.1\n",
+ " Downloading tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.0/424.0 kB\u001b[0m \u001b[31m41.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting nest-asyncio\n",
+ " Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)\n",
+ "Collecting debugpy>=1.0\n",
+ " Downloading debugpy-1.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.8 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m83.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting psutil\n",
+ " Downloading psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (281 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m281.3/281.3 kB\u001b[0m \u001b[31m33.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting jupyter-client>=6.1.12\n",
+ " Downloading jupyter_client-7.4.2-py3-none-any.whl (132 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.2/132.2 kB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting pickleshare\n",
+ " Downloading pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)\n",
+ "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (59.8.0)\n",
+ "Collecting backcall\n",
+ " Downloading backcall-0.2.0-py2.py3-none-any.whl (11 kB)\n",
+ "Collecting pexpect>4.3\n",
+ " Downloading pexpect-4.8.0-py2.py3-none-any.whl (59 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.0/59.0 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting pygments\n",
+ " Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m70.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting jedi>=0.16\n",
+ " Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m83.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0\n",
+ " Downloading prompt_toolkit-3.0.31-py3-none-any.whl (382 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m382.3/382.3 kB\u001b[0m \u001b[31m40.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (4.4.2)\n",
+ "Collecting parso<0.9.0,>=0.8.0\n",
+ " Downloading parso-0.8.3-py2.py3-none-any.whl (100 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.8/100.8 kB\u001b[0m \u001b[31m14.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (2.8.2)\n",
+ "Collecting entrypoints\n",
+ " Downloading entrypoints-0.4-py3-none-any.whl (5.3 kB)\n",
+ "Collecting jupyter-core>=4.9.2\n",
+ " Downloading jupyter_core-4.11.1-py3-none-any.whl (88 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m88.4/88.4 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting ptyprocess>=0.5\n",
+ " Downloading ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)\n",
+ "Collecting wcwidth\n",
+ " Downloading wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)\n",
+ "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets>=7->nglview) (3.0.9)\n",
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (1.16.0)\n",
+ "Building wheels for collected packages: nglview\n",
+ " Building wheel for nglview (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for nglview: filename=nglview-3.0.3-py3-none-any.whl size=8057538 sha256=b7e1071bb91822e48515bf27f4e6b197c6e85e06b90912b3439edc8be1e29514\n",
+ " Stored in directory: /root/.cache/pip/wheels/01/0c/49/c6f79d8edba8fe89752bf20de2d99040bfa57db0548975c5d5\n",
+ "Successfully built nglview\n",
+ "Installing collected packages: wcwidth, ptyprocess, pickleshare, backcall, widgetsnbextension, traitlets, tornado, pyzmq, pygments, psutil, prompt-toolkit, pexpect, parso, nest-asyncio, jupyterlab-widgets, entrypoints, debugpy, matplotlib-inline, jupyter-core, jedi, jupyter-client, ipython, ipykernel, ipywidgets, nglview\n",
+ "Successfully installed backcall-0.2.0 debugpy-1.6.3 entrypoints-0.4 ipykernel-6.16.0 ipython-7.34.0 ipywidgets-8.0.2 jedi-0.18.1 jupyter-client-7.4.2 jupyter-core-4.11.1 jupyterlab-widgets-3.0.3 matplotlib-inline-0.1.6 nest-asyncio-1.5.6 nglview-3.0.3 parso-0.8.3 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.31 psutil-5.9.2 ptyprocess-0.7.0 pygments-2.13.0 pyzmq-24.0.1 tornado-6.2 traitlets-5.4.0 wcwidth-0.2.5 widgetsnbextension-4.0.3\n",
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
+ "\u001b[0m"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.colab-display-data+json": {
+ "pip_warning": {
+ "packages": [
+ "pexpect",
+ "pickleshare",
+ "wcwidth"
+ ]
+ }
}
- },
- {
- "cell_type": "code",
- "source": [
- "!rm /usr/local/conda-meta/pinned"
- ],
- "metadata": {
- "id": "dq1lxR10TtrR",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "ed9c5a71-b449-418f-abb7-072b74e7f6c8"
- },
- "execution_count": null,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "rm: cannot remove '/usr/local/conda-meta/pinned': No such file or directory\n"
- ]
- }
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Z1L3DdZOJB30"
- },
- "source": [
- "Install torch geometric (used in the model later)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "D5ukfCOWfjzK",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "8437485a-5aa6-4d53-8f7f-23517ac1ace6"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Collecting package metadata (current_repodata.json): - \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n",
- "Solving environment: | \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n",
- "\n",
- "## Package Plan ##\n",
- "\n",
- " environment location: /usr/local\n",
- "\n",
- " added / updated specs:\n",
- " - pytorch-geometric=1.7.2\n",
- "\n",
- "\n",
- "The following packages will be downloaded:\n",
- "\n",
- " package | build\n",
- " ---------------------------|-----------------\n",
- " decorator-4.4.2 | py_0 11 KB conda-forge\n",
- " googledrivedownloader-0.4 | pyhd3deb0d_1 7 KB conda-forge\n",
- " jinja2-3.1.2 | pyhd8ed1ab_1 99 KB conda-forge\n",
- " joblib-1.2.0 | pyhd8ed1ab_0 205 KB conda-forge\n",
- " markupsafe-2.1.1 | py37h540881e_1 22 KB conda-forge\n",
- " networkx-2.5.1 | pyhd8ed1ab_0 1.2 MB conda-forge\n",
- " pandas-1.2.3 | py37hdc94413_0 11.8 MB conda-forge\n",
- " pyparsing-3.0.9 | pyhd8ed1ab_0 79 KB conda-forge\n",
- " python-dateutil-2.8.2 | pyhd8ed1ab_0 240 KB conda-forge\n",
- " python-louvain-0.15 | pyhd8ed1ab_1 13 KB conda-forge\n",
- " pytorch-cluster-1.5.9 |py37_torch_1.8.0_cu111 1.2 MB rusty1s\n",
- " pytorch-geometric-1.7.2 |py37_torch_1.8.0_cu111 445 KB rusty1s\n",
- " pytorch-scatter-2.0.8 |py37_torch_1.8.0_cu111 6.1 MB rusty1s\n",
- " pytorch-sparse-0.6.12 |py37_torch_1.8.0_cu111 2.9 MB rusty1s\n",
- " pytorch-spline-conv-1.2.1 |py37_torch_1.8.0_cu111 736 KB rusty1s\n",
- " pytz-2022.4 | pyhd8ed1ab_0 232 KB conda-forge\n",
- " scikit-learn-1.0.2 | py37hf9e9bfc_0 7.8 MB conda-forge\n",
- " scipy-1.7.3 | py37hf2a6cf1_0 21.8 MB conda-forge\n",
- " setuptools-59.8.0 | py37h89c1867_1 1.0 MB conda-forge\n",
- " threadpoolctl-3.1.0 | pyh8a188c0_0 18 KB conda-forge\n",
- " ------------------------------------------------------------\n",
- " Total: 55.9 MB\n",
- "\n",
- "The following NEW packages will be INSTALLED:\n",
- "\n",
- " decorator conda-forge/noarch::decorator-4.4.2-py_0 None\n",
- " googledrivedownlo~ conda-forge/noarch::googledrivedownloader-0.4-pyhd3deb0d_1 None\n",
- " jinja2 conda-forge/noarch::jinja2-3.1.2-pyhd8ed1ab_1 None\n",
- " joblib conda-forge/noarch::joblib-1.2.0-pyhd8ed1ab_0 None\n",
- " markupsafe conda-forge/linux-64::markupsafe-2.1.1-py37h540881e_1 None\n",
- " networkx conda-forge/noarch::networkx-2.5.1-pyhd8ed1ab_0 None\n",
- " pandas conda-forge/linux-64::pandas-1.2.3-py37hdc94413_0 None\n",
- " pyparsing conda-forge/noarch::pyparsing-3.0.9-pyhd8ed1ab_0 None\n",
- " python-dateutil conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 None\n",
- " python-louvain conda-forge/noarch::python-louvain-0.15-pyhd8ed1ab_1 None\n",
- " pytorch-cluster rusty1s/linux-64::pytorch-cluster-1.5.9-py37_torch_1.8.0_cu111 None\n",
- " pytorch-geometric rusty1s/linux-64::pytorch-geometric-1.7.2-py37_torch_1.8.0_cu111 None\n",
- " pytorch-scatter rusty1s/linux-64::pytorch-scatter-2.0.8-py37_torch_1.8.0_cu111 None\n",
- " pytorch-sparse rusty1s/linux-64::pytorch-sparse-0.6.12-py37_torch_1.8.0_cu111 None\n",
- " pytorch-spline-co~ rusty1s/linux-64::pytorch-spline-conv-1.2.1-py37_torch_1.8.0_cu111 None\n",
- " pytz conda-forge/noarch::pytz-2022.4-pyhd8ed1ab_0 None\n",
- " scikit-learn conda-forge/linux-64::scikit-learn-1.0.2-py37hf9e9bfc_0 None\n",
- " scipy conda-forge/linux-64::scipy-1.7.3-py37hf2a6cf1_0 None\n",
- " threadpoolctl conda-forge/noarch::threadpoolctl-3.1.0-pyh8a188c0_0 None\n",
- "\n",
- "The following packages will be DOWNGRADED:\n",
- "\n",
- " setuptools 65.3.0-py37h89c1867_0 --> 59.8.0-py37h89c1867_1 None\n",
- "\n",
- "\n",
- "\n",
- "Downloading and Extracting Packages\n",
- "scikit-learn-1.0.2 | 7.8 MB | : 100% 1.0/1 [00:01<00:00, 1.37s/it] \n",
- "pytorch-scatter-2.0. | 6.1 MB | : 100% 1.0/1 [00:06<00:00, 6.18s/it]\n",
- "pytorch-geometric-1. | 445 KB | : 100% 1.0/1 [00:02<00:00, 2.53s/it]\n",
- "scipy-1.7.3 | 21.8 MB | : 100% 1.0/1 [00:03<00:00, 3.06s/it]\n",
- "python-dateutil-2.8. | 240 KB | : 100% 1.0/1 [00:00<00:00, 21.48it/s]\n",
- "pytorch-spline-conv- | 736 KB | : 100% 1.0/1 [00:01<00:00, 1.00s/it]\n",
- "pytorch-sparse-0.6.1 | 2.9 MB | : 100% 1.0/1 [00:07<00:00, 7.51s/it]\n",
- "pyparsing-3.0.9 | 79 KB | : 100% 1.0/1 [00:00<00:00, 26.32it/s]\n",
- "pytorch-cluster-1.5. | 1.2 MB | : 100% 1.0/1 [00:02<00:00, 2.78s/it]\n",
- "jinja2-3.1.2 | 99 KB | : 100% 1.0/1 [00:00<00:00, 20.28it/s]\n",
- "decorator-4.4.2 | 11 KB | : 100% 1.0/1 [00:00<00:00, 21.57it/s]\n",
- "joblib-1.2.0 | 205 KB | : 100% 1.0/1 [00:00<00:00, 15.04it/s]\n",
- "pytz-2022.4 | 232 KB | : 100% 1.0/1 [00:00<00:00, 10.21it/s]\n",
- "python-louvain-0.15 | 13 KB | : 100% 1.0/1 [00:00<00:00, 3.34it/s]\n",
- "googledrivedownloade | 7 KB | : 100% 1.0/1 [00:00<00:00, 3.33it/s]\n",
- "threadpoolctl-3.1.0 | 18 KB | : 100% 1.0/1 [00:00<00:00, 29.40it/s]\n",
- "markupsafe-2.1.1 | 22 KB | : 100% 1.0/1 [00:00<00:00, 28.62it/s]\n",
- "pandas-1.2.3 | 11.8 MB | : 100% 1.0/1 [00:02<00:00, 2.08s/it] \n",
- "networkx-2.5.1 | 1.2 MB | : 100% 1.0/1 [00:01<00:00, 1.39s/it]\n",
- "setuptools-59.8.0 | 1.0 MB | : 100% 1.0/1 [00:00<00:00, 4.25it/s]\n",
- "Preparing transaction: / \b\b- \b\b\\ \b\bdone\n",
- "Verifying transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\bdone\n",
- "Executing transaction: / \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\b\\ \b\b| \b\b/ \b\b- \b\bdone\n",
- "Retrieving notices: ...working... done\n"
- ]
- }
- ],
- "source": [
- "!conda install -c rusty1s pytorch-geometric=1.7.2"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "ppxv6Mdkalbc"
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "!pip install nglview"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8t8_e_uVLdKB"
+ },
+ "source": [
+ "## Create a diffusion model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "G0rMncVtNSqU"
+ },
+ "source": [
+ "### Model class(es)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "L5FEXz5oXkzt"
+ },
+ "source": [
+ "Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-3-P4w5sXkRU"
+ },
+ "outputs": [],
+ "source": [
+ "# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff\n",
+ "# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models\n",
+ "from dataclasses import dataclass\n",
+ "from typing import Callable, Tuple, Union\n",
+ "\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "from torch import Tensor, nn\n",
+ "from torch.nn import Embedding, Linear, Module, ModuleList, Sequential\n",
+ "from torch_geometric.nn import MessagePassing, radius, radius_graph\n",
+ "from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n",
+ "from torch_geometric.utils import dense_to_sparse, to_dense_adj\n",
+ "from torch_scatter import scatter_add\n",
+ "from torch_sparse import SparseTensor, coalesce\n",
+ "\n",
+ "from diffusers.configuration_utils import ConfigMixin, register_to_config\n",
+ "from diffusers.modeling_utils import ModelMixin\n",
+ "from diffusers.utils import BaseOutput"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "EzJQXPN_XrMX"
+ },
+ "source": [
+ "Helper classes"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "oR1Y56QiLY90"
+ },
+ "outputs": [],
+ "source": [
+ "@dataclass\n",
+ "class MoleculeGNNOutput(BaseOutput):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):\n",
+ " Hidden states output. Output of last layer of model.\n",
+ " \"\"\"\n",
+ "\n",
+ " sample: torch.Tensor\n",
+ "\n",
+ "\n",
+ "class MultiLayerPerceptron(nn.Module):\n",
+ " \"\"\"\n",
+ " Multi-layer Perceptron. Note there is no activation or dropout in the last layer.\n",
+ " Args:\n",
+ " input_dim (int): input dimension\n",
+ " hidden_dim (list of int): hidden dimensions\n",
+ " activation (str or function, optional): activation function\n",
+ " dropout (float, optional): dropout rate\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(self, input_dim, hidden_dims, activation=\"relu\", dropout=0):\n",
+ " super(MultiLayerPerceptron, self).__init__()\n",
+ "\n",
+ " self.dims = [input_dim] + hidden_dims\n",
+ " if isinstance(activation, str):\n",
+ " self.activation = getattr(F, activation)\n",
+ " else:\n",
+ " print(f\"Warning, activation passed {activation} is not string and ignored\")\n",
+ " self.activation = None\n",
+ " if dropout > 0:\n",
+ " self.dropout = nn.Dropout(dropout)\n",
+ " else:\n",
+ " self.dropout = None\n",
+ "\n",
+ " self.layers = nn.ModuleList()\n",
+ " for i in range(len(self.dims) - 1):\n",
+ " self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))\n",
+ "\n",
+ " def forward(self, x):\n",
+ " \"\"\"\"\"\"\n",
+ " for i, layer in enumerate(self.layers):\n",
+ " x = layer(x)\n",
+ " if i < len(self.layers) - 1:\n",
+ " if self.activation:\n",
+ " x = self.activation(x)\n",
+ " if self.dropout:\n",
+ " x = self.dropout(x)\n",
+ " return x\n",
+ "\n",
+ "\n",
+ "class ShiftedSoftplus(torch.nn.Module):\n",
+ " def __init__(self):\n",
+ " super(ShiftedSoftplus, self).__init__()\n",
+ " self.shift = torch.log(torch.tensor(2.0)).item()\n",
+ "\n",
+ " def forward(self, x):\n",
+ " return F.softplus(x) - self.shift\n",
+ "\n",
+ "\n",
+ "class CFConv(MessagePassing):\n",
+ " def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):\n",
+ " super(CFConv, self).__init__(aggr=\"add\")\n",
+ " self.lin1 = Linear(in_channels, num_filters, bias=False)\n",
+ " self.lin2 = Linear(num_filters, out_channels)\n",
+ " self.nn = mlp\n",
+ " self.cutoff = cutoff\n",
+ " self.smooth = smooth\n",
+ "\n",
+ " self.reset_parameters()\n",
+ "\n",
+ " def reset_parameters(self):\n",
+ " torch.nn.init.xavier_uniform_(self.lin1.weight)\n",
+ " torch.nn.init.xavier_uniform_(self.lin2.weight)\n",
+ " self.lin2.bias.data.fill_(0)\n",
+ "\n",
+ " def forward(self, x, edge_index, edge_length, edge_attr):\n",
+ " if self.smooth:\n",
+ " C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)\n",
+ " C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0) # Modification: cutoff\n",
+ " else:\n",
+ " C = (edge_length <= self.cutoff).float()\n",
+ " W = self.nn(edge_attr) * C.view(-1, 1)\n",
+ "\n",
+ " x = self.lin1(x)\n",
+ " x = self.propagate(edge_index, x=x, W=W)\n",
+ " x = self.lin2(x)\n",
+ " return x\n",
+ "\n",
+ " def message(self, x_j: torch.Tensor, W) -> torch.Tensor:\n",
+ " return x_j * W\n",
+ "\n",
+ "\n",
+ "class InteractionBlock(torch.nn.Module):\n",
+ " def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):\n",
+ " super(InteractionBlock, self).__init__()\n",
+ " mlp = Sequential(\n",
+ " Linear(num_gaussians, num_filters),\n",
+ " ShiftedSoftplus(),\n",
+ " Linear(num_filters, num_filters),\n",
+ " )\n",
+ " self.conv = CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth)\n",
+ " self.act = ShiftedSoftplus()\n",
+ " self.lin = Linear(hidden_channels, hidden_channels)\n",
+ "\n",
+ " def forward(self, x, edge_index, edge_length, edge_attr):\n",
+ " x = self.conv(x, edge_index, edge_length, edge_attr)\n",
+ " x = self.act(x)\n",
+ " x = self.lin(x)\n",
+ " return x\n",
+ "\n",
+ "\n",
+ "class SchNetEncoder(Module):\n",
+ " def __init__(\n",
+ " self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False\n",
+ " ):\n",
+ " super().__init__()\n",
+ "\n",
+ " self.hidden_channels = hidden_channels\n",
+ " self.num_filters = num_filters\n",
+ " self.num_interactions = num_interactions\n",
+ " self.cutoff = cutoff\n",
+ "\n",
+ " self.embedding = Embedding(100, hidden_channels, max_norm=10.0)\n",
+ "\n",
+ " self.interactions = ModuleList()\n",
+ " for _ in range(num_interactions):\n",
+ " block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)\n",
+ " self.interactions.append(block)\n",
+ "\n",
+ " def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):\n",
+ " if embed_node:\n",
+ " assert z.dim() == 1 and z.dtype == torch.long\n",
+ " h = self.embedding(z)\n",
+ " else:\n",
+ " h = z\n",
+ " for interaction in self.interactions:\n",
+ " h = h + interaction(h, edge_index, edge_length, edge_attr)\n",
+ "\n",
+ " return h\n",
+ "\n",
+ "\n",
+ "class GINEConv(MessagePassing):\n",
+ " \"\"\"\n",
+ " Custom class of the graph isomorphism operator from the \"How Powerful are Graph Neural Networks?\n",
+ " https://huggingface.co/papers/1810.00826 paper. Note that this implementation has the added option of a custom activation.\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation=\"softplus\", **kwargs):\n",
+ " super(GINEConv, self).__init__(aggr=\"add\", **kwargs)\n",
+ " self.nn = mlp\n",
+ " self.initial_eps = eps\n",
+ "\n",
+ " if isinstance(activation, str):\n",
+ " self.activation = getattr(F, activation)\n",
+ " else:\n",
+ " self.activation = None\n",
+ "\n",
+ " if train_eps:\n",
+ " self.eps = torch.nn.Parameter(torch.Tensor([eps]))\n",
+ " else:\n",
+ " self.register_buffer(\"eps\", torch.Tensor([eps]))\n",
+ "\n",
+ " def forward(\n",
+ " self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None\n",
+ " ) -> torch.Tensor:\n",
+ " \"\"\"\"\"\"\n",
+ " if isinstance(x, torch.Tensor):\n",
+ " x: OptPairTensor = (x, x)\n",
+ "\n",
+ " # Node and edge feature dimensionalites need to match.\n",
+ " if isinstance(edge_index, torch.Tensor):\n",
+ " assert edge_attr is not None\n",
+ " assert x[0].size(-1) == edge_attr.size(-1)\n",
+ " elif isinstance(edge_index, SparseTensor):\n",
+ " assert x[0].size(-1) == edge_index.size(-1)\n",
+ "\n",
+ " # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n",
+ " out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n",
+ "\n",
+ " x_r = x[1]\n",
+ " if x_r is not None:\n",
+ " out += (1 + self.eps) * x_r\n",
+ "\n",
+ " return self.nn(out)\n",
+ "\n",
+ " def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:\n",
+ " if self.activation:\n",
+ " return self.activation(x_j + edge_attr)\n",
+ " else:\n",
+ " return x_j + edge_attr\n",
+ "\n",
+ " def __repr__(self):\n",
+ " return \"{}(nn={})\".format(self.__class__.__name__, self.nn)\n",
+ "\n",
+ "\n",
+ "class GINEncoder(torch.nn.Module):\n",
+ " def __init__(self, hidden_dim, num_convs=3, activation=\"relu\", short_cut=True, concat_hidden=False):\n",
+ " super().__init__()\n",
+ "\n",
+ " self.hidden_dim = hidden_dim\n",
+ " self.num_convs = num_convs\n",
+ " self.short_cut = short_cut\n",
+ " self.concat_hidden = concat_hidden\n",
+ " self.node_emb = nn.Embedding(100, hidden_dim)\n",
+ "\n",
+ " if isinstance(activation, str):\n",
+ " self.activation = getattr(F, activation)\n",
+ " else:\n",
+ " self.activation = None\n",
+ "\n",
+ " self.convs = nn.ModuleList()\n",
+ " for i in range(self.num_convs):\n",
+ " self.convs.append(\n",
+ " GINEConv(\n",
+ " MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),\n",
+ " activation=activation,\n",
+ " )\n",
+ " )\n",
+ "\n",
+ " def forward(self, z, edge_index, edge_attr):\n",
+ " \"\"\"\n",
+ " Input:\n",
+ " data: (torch_geometric.data.Data): batched graph edge_index: bond indices of the original graph (num_node,\n",
+ " hidden) edge_attr: edge feature tensor with shape (num_edge, hidden)\n",
+ " Output:\n",
+ " node_feature: graph feature\n",
+ " \"\"\"\n",
+ "\n",
+ " node_attr = self.node_emb(z) # (num_node, hidden)\n",
+ "\n",
+ " hiddens = []\n",
+ " conv_input = node_attr # (num_node, hidden)\n",
+ "\n",
+ " for conv_idx, conv in enumerate(self.convs):\n",
+ " hidden = conv(conv_input, edge_index, edge_attr)\n",
+ " if conv_idx < len(self.convs) - 1 and self.activation is not None:\n",
+ " hidden = self.activation(hidden)\n",
+ " assert hidden.shape == conv_input.shape\n",
+ " if self.short_cut and hidden.shape == conv_input.shape:\n",
+ " hidden += conv_input\n",
+ "\n",
+ " hiddens.append(hidden)\n",
+ " conv_input = hidden\n",
+ "\n",
+ " if self.concat_hidden:\n",
+ " node_feature = torch.cat(hiddens, dim=-1)\n",
+ " else:\n",
+ " node_feature = hiddens[-1]\n",
+ "\n",
+ " return node_feature\n",
+ "\n",
+ "\n",
+ "class MLPEdgeEncoder(Module):\n",
+ " def __init__(self, hidden_dim=100, activation=\"relu\"):\n",
+ " super().__init__()\n",
+ " self.hidden_dim = hidden_dim\n",
+ " self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)\n",
+ " self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)\n",
+ "\n",
+ " @property\n",
+ " def out_channels(self):\n",
+ " return self.hidden_dim\n",
+ "\n",
+ " def forward(self, edge_length, edge_type):\n",
+ " \"\"\"\n",
+ " Input:\n",
+ " edge_length: The length of edges, shape=(E, 1). edge_type: The type pf edges, shape=(E,)\n",
+ " Returns:\n",
+ " edge_attr: The representation of edges. (E, 2 * num_gaussians)\n",
+ " \"\"\"\n",
+ " d_emb = self.mlp(edge_length) # (num_edge, hidden_dim)\n",
+ " edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim)\n",
+ " return d_emb * edge_attr # (num_edge, hidden)\n",
+ "\n",
+ "\n",
+ "def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):\n",
+ " h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]\n",
+ " h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1) # (E, 2H)\n",
+ " return h_pair\n",
+ "\n",
+ "\n",
+ "def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " num_nodes: Number of atoms.\n",
+ " edge_index: Bond indices of the original graph.\n",
+ " edge_type: Bond types of the original graph.\n",
+ " order: Extension order.\n",
+ " Returns:\n",
+ " new_edge_index: Extended edge indices. new_edge_type: Extended edge types.\n",
+ " \"\"\"\n",
+ "\n",
+ " def binarize(x):\n",
+ " return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))\n",
+ "\n",
+ " def get_higher_order_adj_matrix(adj, order):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " adj: (N, N)\n",
+ " type_mat: (N, N)\n",
+ " Returns:\n",
+ " Following attributes will be updated:\n",
+ " - edge_index\n",
+ " - edge_type\n",
+ " Following attributes will be added to the data object:\n",
+ " - bond_edge_index: Original edge_index.\n",
+ " \"\"\"\n",
+ " adj_mats = [\n",
+ " torch.eye(adj.size(0), dtype=torch.long, device=adj.device),\n",
+ " binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)),\n",
+ " ]\n",
+ "\n",
+ " for i in range(2, order + 1):\n",
+ " adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))\n",
+ " order_mat = torch.zeros_like(adj)\n",
+ "\n",
+ " for i in range(1, order + 1):\n",
+ " order_mat += (adj_mats[i] - adj_mats[i - 1]) * i\n",
+ "\n",
+ " return order_mat\n",
+ "\n",
+ " num_types = 22\n",
+ " # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}\n",
+ " # from rdkit.Chem.rdchem import BondType as BT\n",
+ " N = num_nodes\n",
+ " adj = to_dense_adj(edge_index).squeeze(0)\n",
+ " adj_order = get_higher_order_adj_matrix(adj, order) # (N, N)\n",
+ "\n",
+ " type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N)\n",
+ " type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))\n",
+ " assert (type_mat * type_highorder == 0).all()\n",
+ " type_new = type_mat + type_highorder\n",
+ "\n",
+ " new_edge_index, new_edge_type = dense_to_sparse(type_new)\n",
+ " _, edge_order = dense_to_sparse(adj_order)\n",
+ "\n",
+ " # data.bond_edge_index = data.edge_index # Save original edges\n",
+ " new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data\n",
+ "\n",
+ " return new_edge_index, new_edge_type\n",
+ "\n",
+ "\n",
+ "def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None):\n",
+ " assert edge_type.dim() == 1\n",
+ " N = pos.size(0)\n",
+ "\n",
+ " bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N]))\n",
+ "\n",
+ " if is_sidechain is None:\n",
+ " rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch) # (2, E_r)\n",
+ " else:\n",
+ " # fetch sidechain and its batch index\n",
+ " is_sidechain = is_sidechain.bool()\n",
+ " dummy_index = torch.arange(pos.size(0), device=pos.device)\n",
+ " sidechain_pos = pos[is_sidechain]\n",
+ " sidechain_index = dummy_index[is_sidechain]\n",
+ " sidechain_batch = batch[is_sidechain]\n",
+ "\n",
+ " assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)\n",
+ " r_edge_index_x = assign_index[1]\n",
+ " r_edge_index_y = assign_index[0]\n",
+ " r_edge_index_y = sidechain_index[r_edge_index_y]\n",
+ "\n",
+ " rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y)) # (2, E)\n",
+ " rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x)) # (2, E)\n",
+ " rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1) # (2, 2E)\n",
+ " # delete self loop\n",
+ " rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]\n",
+ "\n",
+ " rgraph_adj = torch.sparse.LongTensor(\n",
+ " rgraph_edge_index,\n",
+ " torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,\n",
+ " torch.Size([N, N]),\n",
+ " )\n",
+ "\n",
+ " composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T)\n",
+ "\n",
+ " new_edge_index = composed_adj.indices()\n",
+ " new_edge_type = composed_adj.values().long()\n",
+ "\n",
+ " return new_edge_index, new_edge_type\n",
+ "\n",
+ "\n",
+ "def extend_graph_order_radius(\n",
+ " num_nodes,\n",
+ " pos,\n",
+ " edge_index,\n",
+ " edge_type,\n",
+ " batch,\n",
+ " order=3,\n",
+ " cutoff=10.0,\n",
+ " extend_order=True,\n",
+ " extend_radius=True,\n",
+ " is_sidechain=None,\n",
+ "):\n",
+ " if extend_order:\n",
+ " edge_index, edge_type = _extend_graph_order(\n",
+ " num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order\n",
+ " )\n",
+ "\n",
+ " if extend_radius:\n",
+ " edge_index, edge_type = _extend_to_radius_graph(\n",
+ " pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain\n",
+ " )\n",
+ "\n",
+ " return edge_index, edge_type\n",
+ "\n",
+ "\n",
+ "def get_distance(pos, edge_index):\n",
+ " return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)\n",
+ "\n",
+ "\n",
+ "def graph_field_network(score_d, pos, edge_index, edge_length):\n",
+ " \"\"\"\n",
+ " Transformation to make the epsilon predicted from the diffusion model roto-translational equivariant. See equations\n",
+ " 5-7 of the GeoDiff Paper https://huggingface.co/papers/2203.02923\n",
+ " \"\"\"\n",
+ " N = pos.size(0)\n",
+ " dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]]) # (E, 3)\n",
+ " score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add(\n",
+ " -dd_dr * score_d, edge_index[1], dim=0, dim_size=N\n",
+ " ) # (N, 3)\n",
+ " return score_pos\n",
+ "\n",
+ "\n",
+ "def clip_norm(vec, limit, p=2):\n",
+ " norm = torch.norm(vec, dim=-1, p=2, keepdim=True)\n",
+ " denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))\n",
+ " return vec * denom\n",
+ "\n",
+ "\n",
+ "def is_local_edge(edge_type):\n",
+ " return edge_type > 0"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QWrHJFcYXyUB"
+ },
+ "source": [
+ "Main model class!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "MCeZA1qQXzoK"
+ },
+ "outputs": [],
+ "source": [
+ "class MoleculeGNN(ModelMixin, ConfigMixin):\n",
+ " @register_to_config\n",
+ " def __init__(\n",
+ " self,\n",
+ " hidden_dim=128,\n",
+ " num_convs=6,\n",
+ " num_convs_local=4,\n",
+ " cutoff=10.0,\n",
+ " mlp_act=\"relu\",\n",
+ " edge_order=3,\n",
+ " edge_encoder=\"mlp\",\n",
+ " smooth_conv=True,\n",
+ " ):\n",
+ " super().__init__()\n",
+ " self.cutoff = cutoff\n",
+ " self.edge_encoder = edge_encoder\n",
+ " self.edge_order = edge_order\n",
+ "\n",
+ " \"\"\"\n",
+ " edge_encoder: Takes both edge type and edge length as input and outputs a vector [Note]: node embedding is done\n",
+ " in SchNetEncoder\n",
+ " \"\"\"\n",
+ " self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n",
+ " self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n",
+ "\n",
+ " \"\"\"\n",
+ " The graph neural network that extracts node-wise features.\n",
+ " \"\"\"\n",
+ " self.encoder_global = SchNetEncoder(\n",
+ " hidden_channels=hidden_dim,\n",
+ " num_filters=hidden_dim,\n",
+ " num_interactions=num_convs,\n",
+ " edge_channels=self.edge_encoder_global.out_channels,\n",
+ " cutoff=cutoff,\n",
+ " smooth=smooth_conv,\n",
+ " )\n",
+ " self.encoder_local = GINEncoder(\n",
+ " hidden_dim=hidden_dim,\n",
+ " num_convs=num_convs_local,\n",
+ " )\n",
+ "\n",
+ " \"\"\"\n",
+ " `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs\n",
+ " gradients w.r.t. edge_length (out_dim = 1).\n",
+ " \"\"\"\n",
+ " self.grad_global_dist_mlp = MultiLayerPerceptron(\n",
+ " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n",
+ " )\n",
+ "\n",
+ " self.grad_local_dist_mlp = MultiLayerPerceptron(\n",
+ " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n",
+ " )\n",
+ "\n",
+ " \"\"\"\n",
+ " Incorporate parameters together\n",
+ " \"\"\"\n",
+ " self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])\n",
+ " self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])\n",
+ "\n",
+ " def _forward(\n",
+ " self,\n",
+ " atom_type,\n",
+ " pos,\n",
+ " bond_index,\n",
+ " bond_type,\n",
+ " batch,\n",
+ " time_step, # NOTE, model trained without timestep performed best\n",
+ " edge_index=None,\n",
+ " edge_type=None,\n",
+ " edge_length=None,\n",
+ " return_edges=False,\n",
+ " extend_order=True,\n",
+ " extend_radius=True,\n",
+ " is_sidechain=None,\n",
+ " ):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " atom_type: Types of atoms, (N, ).\n",
+ " bond_index: Indices of bonds (not extended, not radius-graph), (2, E).\n",
+ " bond_type: Bond types, (E, ).\n",
+ " batch: Node index to graph index, (N, ).\n",
+ " \"\"\"\n",
+ " N = atom_type.size(0)\n",
+ " if edge_index is None or edge_type is None or edge_length is None:\n",
+ " edge_index, edge_type = extend_graph_order_radius(\n",
+ " num_nodes=N,\n",
+ " pos=pos,\n",
+ " edge_index=bond_index,\n",
+ " edge_type=bond_type,\n",
+ " batch=batch,\n",
+ " order=self.edge_order,\n",
+ " cutoff=self.cutoff,\n",
+ " extend_order=extend_order,\n",
+ " extend_radius=extend_radius,\n",
+ " is_sidechain=is_sidechain,\n",
+ " )\n",
+ " edge_length = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1)\n",
+ " local_edge_mask = is_local_edge(edge_type) # (E, )\n",
+ "\n",
+ " # with the parameterization of NCSNv2\n",
+ " # DDPM loss implicit handle the noise variance scale conditioning\n",
+ " sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device) # (E, 1)\n",
+ "\n",
+ " # Encoding global\n",
+ " edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n",
+ "\n",
+ " # Global\n",
+ " node_attr_global = self.encoder_global(\n",
+ " z=atom_type,\n",
+ " edge_index=edge_index,\n",
+ " edge_length=edge_length,\n",
+ " edge_attr=edge_attr_global,\n",
+ " )\n",
+ " # Assemble pairwise features\n",
+ " h_pair_global = assemble_atom_pair_feature(\n",
+ " node_attr=node_attr_global,\n",
+ " edge_index=edge_index,\n",
+ " edge_attr=edge_attr_global,\n",
+ " ) # (E_global, 2H)\n",
+ " # Invariant features of edges (radius graph, global)\n",
+ " edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1)\n",
+ "\n",
+ " # Encoding local\n",
+ " edge_attr_local = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n",
+ " # edge_attr += temb_edge\n",
+ "\n",
+ " # Local\n",
+ " node_attr_local = self.encoder_local(\n",
+ " z=atom_type,\n",
+ " edge_index=edge_index[:, local_edge_mask],\n",
+ " edge_attr=edge_attr_local[local_edge_mask],\n",
+ " )\n",
+ " # Assemble pairwise features\n",
+ " h_pair_local = assemble_atom_pair_feature(\n",
+ " node_attr=node_attr_local,\n",
+ " edge_index=edge_index[:, local_edge_mask],\n",
+ " edge_attr=edge_attr_local[local_edge_mask],\n",
+ " ) # (E_local, 2H)\n",
+ "\n",
+ " # Invariant features of edges (bond graph, local)\n",
+ " if isinstance(sigma_edge, torch.Tensor):\n",
+ " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (\n",
+ " 1.0 / sigma_edge[local_edge_mask]\n",
+ " ) # (E_local, 1)\n",
+ " else:\n",
+ " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge) # (E_local, 1)\n",
+ "\n",
+ " if return_edges:\n",
+ " return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask\n",
+ " else:\n",
+ " return edge_inv_global, edge_inv_local\n",
+ "\n",
+ " def forward(\n",
+ " self,\n",
+ " sample,\n",
+ " timestep: Union[torch.Tensor, float, int],\n",
+ " return_dict: bool = True,\n",
+ " sigma=1.0,\n",
+ " global_start_sigma=0.5,\n",
+ " w_global=1.0,\n",
+ " extend_order=False,\n",
+ " extend_radius=True,\n",
+ " clip_local=None,\n",
+ " clip_global=1000.0,\n",
+ " ) -> Union[MoleculeGNNOutput, Tuple]:\n",
+ " r\"\"\"\n",
+ " Args:\n",
+ " sample: packed torch geometric object\n",
+ " timestep (`torch.Tensor` or `float` or `int): TODO verify type and shape (batch) timesteps\n",
+ " return_dict (`bool`, *optional*, defaults to `True`):\n",
+ " Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\n",
+ " Returns:\n",
+ " [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if\n",
+ " `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.\n",
+ " \"\"\"\n",
+ "\n",
+ " # unpack sample\n",
+ " atom_type = sample.atom_type\n",
+ " bond_index = sample.edge_index\n",
+ " bond_type = sample.edge_type\n",
+ " num_graphs = sample.num_graphs\n",
+ " pos = sample.pos\n",
+ "\n",
+ " timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)\n",
+ "\n",
+ " edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(\n",
+ " atom_type=atom_type,\n",
+ " pos=sample.pos,\n",
+ " bond_index=bond_index,\n",
+ " bond_type=bond_type,\n",
+ " batch=sample.batch,\n",
+ " time_step=timesteps,\n",
+ " return_edges=True,\n",
+ " extend_order=extend_order,\n",
+ " extend_radius=extend_radius,\n",
+ " ) # (E_global, 1), (E_local, 1)\n",
+ "\n",
+ " # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff\n",
+ " node_eq_local = graph_field_network(\n",
+ " edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]\n",
+ " )\n",
+ " if clip_local is not None:\n",
+ " node_eq_local = clip_norm(node_eq_local, limit=clip_local)\n",
+ "\n",
+ " # Global\n",
+ " if sigma < global_start_sigma:\n",
+ " edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())\n",
+ " node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length)\n",
+ " node_eq_global = clip_norm(node_eq_global, limit=clip_global)\n",
+ " else:\n",
+ " node_eq_global = 0\n",
+ "\n",
+ " # Sum\n",
+ " eps_pos = node_eq_local + node_eq_global * w_global\n",
+ "\n",
+ " if not return_dict:\n",
+ " return (-eps_pos,)\n",
+ "\n",
+ " return MoleculeGNNOutput(sample=torch.Tensor(-eps_pos).to(pos.device))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "CCIrPYSJj9wd"
+ },
+ "source": [
+ "### Load pretrained model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "YdrAr6Ch--Ab"
+ },
+ "source": [
+ "#### Load a model\n",
+ "The model used is a design an\n",
+ "equivariant convolutional layer, named graph field network (GFN).\n",
+ "\n",
+ "The warning about `betas` and `alphas` can be ignored, those were moved to the scheduler."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 172,
+ "referenced_widgets": [
+ "d90f304e9560472eacfbdd11e46765eb",
+ "1c6246f15b654f4daa11c9bcf997b78c",
+ "c2321b3bff6f490ca12040a20308f555",
+ "b7feb522161f4cf4b7cc7c1a078ff12d",
+ "e2d368556e494ae7ae4e2e992af2cd4f",
+ "bbef741e76ec41b7ab7187b487a383df",
+ "561f742d418d4721b0670cc8dd62e22c",
+ "872915dd1bb84f538c44e26badabafdd",
+ "d022575f1fa2446d891650897f187b4d",
+ "fdc393f3468c432aa0ada05e238a5436",
+ "2c9362906e4b40189f16d14aa9a348da",
+ "6010fc8daa7a44d5aec4b830ec2ebaa1",
+ "7e0bb1b8d65249d3974200686b193be2",
+ "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a",
+ "6526646be5ed415c84d1245b040e629b",
+ "24d31fc3576e43dd9f8301d2ef3a37ab",
+ "2918bfaadc8d4b1a9832522c40dfefb8",
+ "a4bfdca35cc54dae8812720f1b276a08",
+ "e4901541199b45c6a18824627692fc39",
+ "f915cf874246446595206221e900b2fe",
+ "a9e388f22a9742aaaf538e22575c9433",
+ "42f6c3db29d7484ba6b4f73590abd2f4"
+ ]
+ },
+ "id": "DyCo0nsqjbml",
+ "outputId": "d6bce9d5-c51e-43a4-e680-e1e81bdfaf45"
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "d90f304e9560472eacfbdd11e46765eb",
+ "version_major": 2,
+ "version_minor": 0
},
- "source": [
- "### Install Diffusers"
+ "text/plain": [
+ "Downloading: 0%| | 0.00/3.27M [00:00, ?B/s]"
]
+ },
+ "metadata": {},
+ "output_type": "display_data"
},
{
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "mgQA_XN-XGY2",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "85392615-b6a4-4052-9d2a-79604be62c94"
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "6010fc8daa7a44d5aec4b830ec2ebaa1",
+ "version_major": 2,
+ "version_minor": 0
},
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "/content\n",
- "Cloning into 'diffusers'...\n",
- "remote: Enumerating objects: 9298, done.\u001b[K\n",
- "remote: Counting objects: 100% (40/40), done.\u001b[K\n",
- "remote: Compressing objects: 100% (23/23), done.\u001b[K\n",
- "remote: Total 9298 (delta 17), reused 23 (delta 11), pack-reused 9258\u001b[K\n",
- "Receiving objects: 100% (9298/9298), 7.38 MiB | 5.28 MiB/s, done.\n",
- "Resolving deltas: 100% (6168/6168), done.\n",
- " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
- " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
- " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m757.0/757.0 kB\u001b[0m \u001b[31m52.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m163.5/163.5 kB\u001b[0m \u001b[31m21.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.8/40.8 kB\u001b[0m \u001b[31m5.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m596.3/596.3 kB\u001b[0m \u001b[31m51.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25h Building wheel for diffusers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
- "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m432.7/432.7 kB\u001b[0m \u001b[31m36.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.3/5.3 MB\u001b[0m \u001b[31m90.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m35.3/35.3 MB\u001b[0m \u001b[31m39.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.1/115.1 kB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m948.0/948.0 kB\u001b[0m \u001b[31m63.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.2/212.2 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m95.8/95.8 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.8/140.8 kB\u001b[0m \u001b[31m18.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m104.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m148.0/148.0 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m231.3/231.3 kB\u001b[0m \u001b[31m30.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m94.8/94.8 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
- "\u001b[0m"
- ]
- }
- ],
- "source": [
- "%cd /content\n",
- "\n",
- "# install latest HF diffusers (will update to the release once added)\n",
- "!git clone https://github.com/huggingface/diffusers.git\n",
- "!pip install -q /content/diffusers\n",
- "\n",
- "# dependencies for diffusers\n",
- "!pip install -q datasets transformers"
+ "text/plain": [
+ "Downloading: 0%| | 0.00/401 [00:00, ?B/s]"
]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "The config attributes {'type': 'diffusion', 'network': 'dualenc', 'beta_schedule': 'sigmoid', 'beta_start': 1e-07, 'beta_end': 0.002, 'num_diffusion_timesteps': 5000} were passed to MoleculeGNN, but are not expected and will be ignored. Please verify your config.json configuration file.\n",
+ "Some weights of the model checkpoint at fusing/gfn-molecule-gen-drugs were not used when initializing MoleculeGNN: ['betas', 'alphas']\n",
+ "- This IS expected if you are initializing MoleculeGNN from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing MoleculeGNN from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
+ ]
+ }
+ ],
+ "source": [
+ "DEVICE = \"cuda\"\n",
+ "model = MoleculeGNN.from_pretrained(\"fusing/gfn-molecule-gen-drugs\").to(DEVICE)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HdclRaqoUWUD"
+ },
+ "source": [
+ "The warnings above are because the pre-trained model was uploaded before cleaning the code!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "PlOkPySoJ1m9"
+ },
+ "source": [
+ "#### Create scheduler\n",
+ "Note, other schedulers are used in the paper for slightly improved performance over DDPM."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "nNHnIk9CkAb2"
+ },
+ "outputs": [],
+ "source": [
+ "from diffusers import DDPMScheduler"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "RnDJdDBztjFF"
+ },
+ "outputs": [],
+ "source": [
+ "num_timesteps = 1000\n",
+ "scheduler = DDPMScheduler(\n",
+ " num_train_timesteps=num_timesteps, beta_schedule=\"sigmoid\", beta_start=1e-7, beta_end=2e-3, clip_sample=False\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1vh3fpSAflkL"
+ },
+ "source": [
+ "### Get a dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "B6qzaGjVKFVk"
+ },
+ "source": [
+ "Grab a google tool so we can upload our data directly. Note you need to download the data from ***this [file](https://huggingface.co/datasets/fusing/geodiff-example-data/blob/main/data/molecules.pkl)***\n",
+ "\n",
+ "(direct downloading from the hub does not yet work for this datatype)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "jbLl3EJdgj3x"
+ },
+ "outputs": [],
+ "source": [
+ "# from google.colab import files"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "E591lVuTgxPE"
+ },
+ "outputs": [],
+ "source": [
+ "# uploaded = files.upload()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "KUNxfK3ln98Q"
+ },
+ "source": [
+ "Load the dataset with torch."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "7L4iOShTpcQX",
+ "outputId": "7f2dcd29-493e-44de-98d1-3ad50f109a4a"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "--2022-10-12 18:32:19-- https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n",
+ "Resolving huggingface.co (huggingface.co)... 44.195.102.200, 52.5.54.249, 54.210.225.113, ...\n",
+ "Connecting to huggingface.co (huggingface.co)|44.195.102.200|:443... connected.\n",
+ "HTTP request sent, awaiting response... 200 OK\n",
+ "Length: 127774 (125K) [application/octet-stream]\n",
+ "Saving to: ‘molecules.pkl’\n",
+ "\n",
+ "molecules.pkl 100%[===================>] 124.78K 180KB/s in 0.7s \n",
+ "\n",
+ "2022-10-12 18:32:20 (180 KB/s) - ‘molecules.pkl’ saved [127774/127774]\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "\n",
+ "\n",
+ "!wget https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n",
+ "dataset = torch.load(\"/content/molecules.pkl\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QZcmy1EvKQRk"
+ },
+ "source": [
+ "Print out one entry of the dataset, it contains molecular formulas, atom types, positions, and more."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
},
+ "id": "JVjz6iH_H6Eh",
+ "outputId": "898cb0cf-a0b3-411b-fd4c-bea1fbfd17fe"
+ },
+ "outputs": [
{
- "cell_type": "markdown",
- "metadata": {
- "id": "LZO6AJKuJKO8"
- },
- "source": [
- "Check that torch is installed correctly and utilizing the GPU in the colab"
+ "data": {
+ "text/plain": [
+ "Data(atom_type=[51], bond_edge_index=[2, 108], edge_index=[2, 598], edge_order=[598], edge_type=[598], idx=[1], is_bond=[598], num_nodes_per_graph=[1], num_pos_ref=[1], nx=, pos=[51, 3], pos_ref=[255, 3], rdmol=, smiles=\"CC1CCCN(C(=O)C2CCN(S(=O)(=O)c3cccc4nonc34)CC2)C1\")"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "gZt7BNi1e1PA",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 53
- },
- "outputId": "a0e1832c-9c02-49aa-cff8-1339e6cdc889"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "True\n"
- ]
- },
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- "'1.8.2'"
- ],
- "application/vnd.google.colaboratory.intrinsic+json": {
- "type": "string"
- }
- },
- "metadata": {},
- "execution_count": 8
- }
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dataset[0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "vHNiZAUxNgoy"
+ },
+ "source": [
+ "## Run the diffusion process"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jZ1KZrxKqENg"
+ },
+ "source": [
+ "#### Helper Functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "s240tYueqKKf"
+ },
+ "outputs": [],
+ "source": [
+ "import copy\n",
+ "import os\n",
+ "\n",
+ "from torch_geometric.data import Batch, Data\n",
+ "from torch_scatter import scatter_mean\n",
+ "from tqdm import tqdm\n",
+ "\n",
+ "\n",
+ "def repeat_data(data: Data, num_repeat) -> Batch:\n",
+ " datas = [copy.deepcopy(data) for i in range(num_repeat)]\n",
+ " return Batch.from_data_list(datas)\n",
+ "\n",
+ "\n",
+ "def repeat_batch(batch: Batch, num_repeat) -> Batch:\n",
+ " datas = batch.to_data_list()\n",
+ " new_data = []\n",
+ " for i in range(num_repeat):\n",
+ " new_data += copy.deepcopy(datas)\n",
+ " return Batch.from_data_list(new_data)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "AMnQTk0eqT7Z"
+ },
+ "source": [
+ "#### Constants"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "WYGkzqgzrHmF"
+ },
+ "outputs": [],
+ "source": [
+ "num_samples = 1 # solutions per molecule\n",
+ "num_molecules = 3\n",
+ "\n",
+ "DEVICE = \"cuda\"\n",
+ "sampling_type = \"ddpm_noisy\" #'' # paper also uses \"generalize\" and \"ld\"\n",
+ "# constants for inference\n",
+ "w_global = 0.5 # 0,.3 for qm9\n",
+ "global_start_sigma = 0.5\n",
+ "eta = 1.0\n",
+ "clip_local = None\n",
+ "clip_pos = None\n",
+ "\n",
+ "# constants for data handling\n",
+ "save_traj = False\n",
+ "save_data = False\n",
+ "output_dir = \"/content/\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-xD5bJ3SqM7t"
+ },
+ "source": [
+ "#### Generate samples!\n",
+ "Note that the 3d representation of a molecule is referred to as the **conformation**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "x9xuLUNg26z1",
+ "outputId": "236d2a60-09ed-4c4d-97c1-6e3c0f2d26c4"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " after removing the cwd from sys.path.\n",
+ "100%|██████████| 5/5 [00:55<00:00, 11.06s/it]\n"
+ ]
+ }
+ ],
+ "source": [
+ "import pickle\n",
+ "\n",
+ "\n",
+ "results = []\n",
+ "\n",
+ "# define sigmas\n",
+ "sigmas = torch.tensor(1.0 - scheduler.alphas_cumprod).sqrt() / torch.tensor(scheduler.alphas_cumprod).sqrt()\n",
+ "sigmas = sigmas.to(DEVICE)\n",
+ "\n",
+ "for count, data in enumerate(tqdm(dataset)):\n",
+ " num_samples = max(data.pos_ref.size(0) // data.num_nodes, 1)\n",
+ "\n",
+ " data_input = data.clone()\n",
+ " data_input[\"pos_ref\"] = None\n",
+ " batch = repeat_data(data_input, num_samples).to(DEVICE)\n",
+ "\n",
+ " # initial configuration\n",
+ " pos_init = torch.randn(batch.num_nodes, 3).to(DEVICE)\n",
+ "\n",
+ " # for logging animation of denoising\n",
+ " pos_traj = []\n",
+ " with torch.no_grad():\n",
+ " # scale initial sample\n",
+ " pos = pos_init * sigmas[-1]\n",
+ " for t in scheduler.timesteps:\n",
+ " batch.pos = pos\n",
+ "\n",
+ " # generate geometry with model, then filter it\n",
+ " epsilon = model.forward(batch, t, sigma=sigmas[t], return_dict=False)[0]\n",
+ "\n",
+ " # Update\n",
+ " reconstructed_pos = scheduler.step(epsilon, t, pos)[\"prev_sample\"].to(DEVICE)\n",
+ "\n",
+ " pos = reconstructed_pos\n",
+ "\n",
+ " if torch.isnan(pos).any():\n",
+ " print(\"NaN detected. Please restart.\")\n",
+ " raise FloatingPointError()\n",
+ "\n",
+ " # recenter graph of positions for next iteration\n",
+ " pos = pos - scatter_mean(pos, batch.batch, dim=0)[batch.batch]\n",
+ "\n",
+ " # optional clipping\n",
+ " if clip_pos is not None:\n",
+ " pos = torch.clamp(pos, min=-clip_pos, max=clip_pos)\n",
+ " pos_traj.append(pos.clone().cpu())\n",
+ "\n",
+ " pos_gen = pos.cpu()\n",
+ " if save_traj:\n",
+ " pos_gen_traj = pos_traj.cpu()\n",
+ " data.pos_gen = torch.stack(pos_gen_traj)\n",
+ " else:\n",
+ " data.pos_gen = pos_gen\n",
+ " results.append(data)\n",
+ "\n",
+ "\n",
+ "if save_data:\n",
+ " save_path = os.path.join(output_dir, \"samples_all.pkl\")\n",
+ "\n",
+ " with open(save_path, \"wb\") as f:\n",
+ " pickle.dump(results, f)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "fSApwSaZNndW"
+ },
+ "source": [
+ "## Render the results!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "d47Zxo2OKdgZ"
+ },
+ "source": [
+ "This function allows us to render 3d in colab."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "e9Cd0kCAv9b8"
+ },
+ "outputs": [],
+ "source": [
+ "from google.colab import output\n",
+ "\n",
+ "\n",
+ "output.enable_custom_widget_manager()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "RjaVuR15NqzF"
+ },
+ "source": [
+ "### Helper functions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "28rBYa9NKhlz"
+ },
+ "source": [
+ "Here is a helper function for copying the generated tensors into a format used by RDKit & NGLViewer."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "LKdKdwxcyTQ6"
+ },
+ "outputs": [],
+ "source": [
+ "from copy import deepcopy\n",
+ "\n",
+ "\n",
+ "def set_rdmol_positions(rdkit_mol, pos):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n",
+ " pos: (N_atoms, 3)\n",
+ " \"\"\"\n",
+ " mol = deepcopy(rdkit_mol)\n",
+ " set_rdmol_positions_(mol, pos)\n",
+ " return mol\n",
+ "\n",
+ "\n",
+ "def set_rdmol_positions_(mol, pos):\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n",
+ " pos: (N_atoms, 3)\n",
+ " \"\"\"\n",
+ " for i in range(pos.shape[0]):\n",
+ " mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())\n",
+ " return mol"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "NuE10hcpKmzK"
+ },
+ "source": [
+ "Process the generated data to make it easy to view."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "KieVE1vc0_Vs",
+ "outputId": "6faa185d-b1bc-47e8-be18-30d1e557e7c8"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "collect 5 generated molecules in `mols`\n"
+ ]
+ }
+ ],
+ "source": [
+ "# the model can generate multiple conformations per 2d geometry\n",
+ "num_gen = results[0][\"pos_gen\"].shape[0]\n",
+ "\n",
+ "# init storage objects\n",
+ "mols_gen = []\n",
+ "mols_orig = []\n",
+ "for to_process in results:\n",
+ " # store the reference 3d position\n",
+ " to_process[\"pos_ref\"] = to_process[\"pos_ref\"].reshape(-1, to_process[\"rdmol\"].GetNumAtoms(), 3)\n",
+ "\n",
+ " # store the generated 3d position\n",
+ " to_process[\"pos_gen\"] = to_process[\"pos_gen\"].reshape(-1, to_process[\"rdmol\"].GetNumAtoms(), 3)\n",
+ "\n",
+ " # copy data to new object\n",
+ " new_mol = set_rdmol_positions(to_process.rdmol, to_process[\"pos_gen\"][0])\n",
+ "\n",
+ " # append results\n",
+ " mols_gen.append(new_mol)\n",
+ " mols_orig.append(to_process.rdmol)\n",
+ "\n",
+ "print(f\"collect {len(mols_gen)} generated molecules in `mols`\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tin89JwMKp4v"
+ },
+ "source": [
+ "Import tools to visualize the 2d chemical diagram of the molecule."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yqV6gllSZn38"
+ },
+ "outputs": [],
+ "source": [
+ "from IPython.display import SVG, display\n",
+ "from rdkit import Chem\n",
+ "from rdkit.Chem.Draw import rdMolDraw2D as MD2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "TFNKmGddVoOk"
+ },
+ "source": [
+ "Select molecule to visualize"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "KzuwLlrrVaGc"
+ },
+ "outputs": [],
+ "source": [
+ "idx = 0\n",
+ "assert idx < len(results), \"selected molecule that was not generated\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hkb8w0_SNtU8"
+ },
+ "source": [
+ "### Viewing"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "I3R4QBQeKttN"
+ },
+ "source": [
+ "This 2D rendering is the equivalent of the **input to the model**!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 321
+ },
+ "id": "gkQRWjraaKex",
+ "outputId": "9c3d1a91-a51d-475d-9e34-2be2459abc47"
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": [
+ "\n",
+ "\n",
+ " \n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ " "
],
- "source": [
- "import torch\n",
- "print(torch.cuda.is_available())\n",
- "torch.__version__"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "KLE7CqlfJNUO"
- },
- "source": [
- "### Install Chemistry-specific Dependencies\n",
- "\n",
- "Install RDKit, a tool for working with and visualizing chemsitry in python (you use this to visualize the generate models later)."
+ "text/plain": [
+ ""
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "0CPv_NvehRz3",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "6ee0ae4e-4511-4816-de29-22b1c21d49bc"
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "mc = Chem.MolFromSmiles(dataset[0][\"smiles\"])\n",
+ "molSize = (450, 300)\n",
+ "drawer = MD2.MolDraw2DSVG(molSize[0], molSize[1])\n",
+ "drawer.DrawMolecule(mc)\n",
+ "drawer.FinishDrawing()\n",
+ "svg = drawer.GetDrawingText()\n",
+ "display(SVG(svg.replace(\"svg:\", \"\")))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "z4FDMYMxKw2I"
+ },
+ "source": [
+ "Generate the 3d molecule!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 17,
+ "referenced_widgets": [
+ "695ab5bbf30a4ab19df1f9f33469f314",
+ "eac6a8dcdc9d4335a2e51031793ead29"
+ ]
+ },
+ "id": "aT1Bkb8YxJfV",
+ "outputId": "b98870ae-049d-4386-b676-166e9526bda2"
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "695ab5bbf30a4ab19df1f9f33469f314",
+ "version_major": 2,
+ "version_minor": 0
},
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
- "Collecting rdkit\n",
- " Downloading rdkit-2022.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (36.8 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m36.8/36.8 MB\u001b[0m \u001b[31m34.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hRequirement already satisfied: Pillow in /usr/local/lib/python3.7/site-packages (from rdkit) (9.2.0)\n",
- "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from rdkit) (1.21.6)\n",
- "Installing collected packages: rdkit\n",
- "Successfully installed rdkit-2022.3.5\n",
- "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
- "\u001b[0m"
- ]
+ "text/plain": []
+ },
+ "metadata": {
+ "application/vnd.jupyter.widget-view+json": {
+ "colab": {
+ "custom_widget_manager": {
+ "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js"
}
- ],
- "source": [
- "!pip install rdkit"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "88GaDbDPxJ5I"
+ }
+ }
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from nglview import show_rdkit as show"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 337,
+ "referenced_widgets": [
+ "be446195da2b4ff2aec21ec5ff963a54",
+ "c6596896148b4a8a9c57963b67c7782f",
+ "2489b5e5648541fbbdceadb05632a050",
+ "01e0ba4e5da04914b4652b8d58565d7b",
+ "c30e6c2f3e2a44dbbb3d63bd519acaa4",
+ "f31c6e40e9b2466a9064a2669933ecd5",
+ "19308ccac642498ab8b58462e3f1b0bb",
+ "4a081cdc2ec3421ca79dd933b7e2b0c4",
+ "e5c0d75eb5e1447abd560c8f2c6017e1",
+ "5146907ef6764654ad7d598baebc8b58",
+ "144ec959b7604a2cabb5ca46ae5e5379",
+ "abce2a80e6304df3899109c6d6cac199",
+ "65195cb7a4134f4887e9dd19f3676462"
+ ]
+ },
+ "id": "pxtq8I-I18C-",
+ "outputId": "72ed63ac-d2ec-4f5c-a0b1-4e7c1840a4e7"
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "be446195da2b4ff2aec21ec5ff963a54",
+ "version_major": 2,
+ "version_minor": 0
},
- "source": [
- "### Get viewer from nglview\n",
- "\n",
- "The model you will use outputs a position matrix tensor. This pytorch geometric data object will have many features (positions, known features, edge features -- all tensors).\n",
- "The data we give to the model will also have a rdmol object (which can extract features to geometric if needed).\n",
- "The rdmol in this object is a source of ground truth for the generated molecules.\n",
- "\n",
- "You will use one rendering function from nglviewer later!\n",
- "\n"
+ "text/plain": [
+ "NGLWidget()"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "jcl8GCS2mz6t",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 1000
- },
- "outputId": "99b5cc40-67bb-4d8e-faa0-47d7cb33e98f"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
- "Collecting nglview\n",
- " Downloading nglview-3.0.3.tar.gz (5.7 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.7/5.7 MB\u001b[0m \u001b[31m91.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
- " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
- " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
- "Requirement already satisfied: numpy in /usr/local/lib/python3.7/site-packages (from nglview) (1.21.6)\n",
- "Collecting jupyterlab-widgets\n",
- " Downloading jupyterlab_widgets-3.0.3-py3-none-any.whl (384 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m384.1/384.1 kB\u001b[0m \u001b[31m40.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting ipywidgets>=7\n",
- " Downloading ipywidgets-8.0.2-py3-none-any.whl (134 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.4/134.4 kB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting widgetsnbextension~=4.0\n",
- " Downloading widgetsnbextension-4.0.3-py3-none-any.whl (2.0 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m84.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting ipython>=6.1.0\n",
- " Downloading ipython-7.34.0-py3-none-any.whl (793 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m793.8/793.8 kB\u001b[0m \u001b[31m60.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting ipykernel>=4.5.1\n",
- " Downloading ipykernel-6.16.0-py3-none-any.whl (138 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m138.4/138.4 kB\u001b[0m \u001b[31m20.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting traitlets>=4.3.1\n",
- " Downloading traitlets-5.4.0-py3-none-any.whl (107 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m107.1/107.1 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.7/site-packages (from ipykernel>=4.5.1->ipywidgets>=7->nglview) (21.3)\n",
- "Collecting pyzmq>=17\n",
- " Downloading pyzmq-24.0.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.1 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m68.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting matplotlib-inline>=0.1\n",
- " Downloading matplotlib_inline-0.1.6-py3-none-any.whl (9.4 kB)\n",
- "Collecting tornado>=6.1\n",
- " Downloading tornado-6.2-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.0/424.0 kB\u001b[0m \u001b[31m41.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting nest-asyncio\n",
- " Downloading nest_asyncio-1.5.6-py3-none-any.whl (5.2 kB)\n",
- "Collecting debugpy>=1.0\n",
- " Downloading debugpy-1.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.8 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m83.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting psutil\n",
- " Downloading psutil-5.9.2-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (281 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m281.3/281.3 kB\u001b[0m \u001b[31m33.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting jupyter-client>=6.1.12\n",
- " Downloading jupyter_client-7.4.2-py3-none-any.whl (132 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m132.2/132.2 kB\u001b[0m \u001b[31m19.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting pickleshare\n",
- " Downloading pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)\n",
- "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (59.8.0)\n",
- "Collecting backcall\n",
- " Downloading backcall-0.2.0-py2.py3-none-any.whl (11 kB)\n",
- "Collecting pexpect>4.3\n",
- " Downloading pexpect-4.8.0-py2.py3-none-any.whl (59 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.0/59.0 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting pygments\n",
- " Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m70.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting jedi>=0.16\n",
- " Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m83.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0\n",
- " Downloading prompt_toolkit-3.0.31-py3-none-any.whl (382 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m382.3/382.3 kB\u001b[0m \u001b[31m40.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hRequirement already satisfied: decorator in /usr/local/lib/python3.7/site-packages (from ipython>=6.1.0->ipywidgets>=7->nglview) (4.4.2)\n",
- "Collecting parso<0.9.0,>=0.8.0\n",
- " Downloading parso-0.8.3-py2.py3-none-any.whl (100 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m100.8/100.8 kB\u001b[0m \u001b[31m14.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.7/site-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (2.8.2)\n",
- "Collecting entrypoints\n",
- " Downloading entrypoints-0.4-py3-none-any.whl (5.3 kB)\n",
- "Collecting jupyter-core>=4.9.2\n",
- " Downloading jupyter_core-4.11.1-py3-none-any.whl (88 kB)\n",
- "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m88.4/88.4 kB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
- "\u001b[?25hCollecting ptyprocess>=0.5\n",
- " Downloading ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)\n",
- "Collecting wcwidth\n",
- " Downloading wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)\n",
- "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/site-packages (from packaging->ipykernel>=4.5.1->ipywidgets>=7->nglview) (3.0.9)\n",
- "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets>=7->nglview) (1.16.0)\n",
- "Building wheels for collected packages: nglview\n",
- " Building wheel for nglview (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
- " Created wheel for nglview: filename=nglview-3.0.3-py3-none-any.whl size=8057538 sha256=b7e1071bb91822e48515bf27f4e6b197c6e85e06b90912b3439edc8be1e29514\n",
- " Stored in directory: /root/.cache/pip/wheels/01/0c/49/c6f79d8edba8fe89752bf20de2d99040bfa57db0548975c5d5\n",
- "Successfully built nglview\n",
- "Installing collected packages: wcwidth, ptyprocess, pickleshare, backcall, widgetsnbextension, traitlets, tornado, pyzmq, pygments, psutil, prompt-toolkit, pexpect, parso, nest-asyncio, jupyterlab-widgets, entrypoints, debugpy, matplotlib-inline, jupyter-core, jedi, jupyter-client, ipython, ipykernel, ipywidgets, nglview\n",
- "Successfully installed backcall-0.2.0 debugpy-1.6.3 entrypoints-0.4 ipykernel-6.16.0 ipython-7.34.0 ipywidgets-8.0.2 jedi-0.18.1 jupyter-client-7.4.2 jupyter-core-4.11.1 jupyterlab-widgets-3.0.3 matplotlib-inline-0.1.6 nest-asyncio-1.5.6 nglview-3.0.3 parso-0.8.3 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.31 psutil-5.9.2 ptyprocess-0.7.0 pygments-2.13.0 pyzmq-24.0.1 tornado-6.2 traitlets-5.4.0 wcwidth-0.2.5 widgetsnbextension-4.0.3\n",
- "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
- "\u001b[0m"
- ]
- },
- {
- "output_type": "display_data",
- "data": {
- "application/vnd.colab-display-data+json": {
- "pip_warning": {
- "packages": [
- "pexpect",
- "pickleshare",
- "wcwidth"
- ]
- }
- }
- },
- "metadata": {}
+ },
+ "metadata": {
+ "application/vnd.jupyter.widget-view+json": {
+ "colab": {
+ "custom_widget_manager": {
+ "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js"
}
- ],
- "source": [
- "!pip install nglview"
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Create a diffusion model"
- ],
- "metadata": {
- "id": "8t8_e_uVLdKB"
+ }
}
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Model class(es)"
- ],
- "metadata": {
- "id": "G0rMncVtNSqU"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "Imports"
- ],
- "metadata": {
- "id": "L5FEXz5oXkzt"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "# Model adapted from GeoDiff https://github.com/MinkaiXu/GeoDiff\n",
- "# Model inspired by https://github.com/DeepGraphLearning/torchdrug/tree/master/torchdrug/models\n",
- "from dataclasses import dataclass\n",
- "from typing import Callable, Tuple, Union\n",
- "\n",
- "import numpy as np\n",
- "import torch\n",
- "import torch.nn.functional as F\n",
- "from torch import Tensor, nn\n",
- "from torch.nn import Embedding, Linear, Module, ModuleList, Sequential\n",
- "\n",
- "from torch_geometric.nn import MessagePassing, radius, radius_graph\n",
- "from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n",
- "from torch_geometric.utils import dense_to_sparse, to_dense_adj\n",
- "from torch_scatter import scatter_add\n",
- "from torch_sparse import SparseTensor, coalesce\n",
- "\n",
- "from diffusers.configuration_utils import ConfigMixin, register_to_config\n",
- "from diffusers.modeling_utils import ModelMixin\n",
- "from diffusers.utils import BaseOutput\n"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# new molecule\n",
+ "show(mols_gen[idx])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "KJr4h2mwXeTo"
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "provenance": []
+ },
+ "gpuClass": "standard",
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.9"
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "01e0ba4e5da04914b4652b8d58565d7b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_e5c0d75eb5e1447abd560c8f2c6017e1",
+ "IPY_MODEL_5146907ef6764654ad7d598baebc8b58"
],
- "metadata": {
- "id": "-3-P4w5sXkRU"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "Helper classes"
+ "layout": "IPY_MODEL_144ec959b7604a2cabb5ca46ae5e5379"
+ }
+ },
+ "144ec959b7604a2cabb5ca46ae5e5379": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "19308ccac642498ab8b58462e3f1b0bb": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "1c6246f15b654f4daa11c9bcf997b78c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_bbef741e76ec41b7ab7187b487a383df",
+ "placeholder": "",
+ "style": "IPY_MODEL_561f742d418d4721b0670cc8dd62e22c",
+ "value": "Downloading: 100%"
+ }
+ },
+ "2489b5e5648541fbbdceadb05632a050": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ButtonModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ButtonModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ButtonView",
+ "button_style": "",
+ "description": "",
+ "disabled": false,
+ "icon": "compress",
+ "layout": "IPY_MODEL_abce2a80e6304df3899109c6d6cac199",
+ "style": "IPY_MODEL_65195cb7a4134f4887e9dd19f3676462",
+ "tooltip": ""
+ }
+ },
+ "24d31fc3576e43dd9f8301d2ef3a37ab": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "2918bfaadc8d4b1a9832522c40dfefb8": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "2c9362906e4b40189f16d14aa9a348da": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "42f6c3db29d7484ba6b4f73590abd2f4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "4a081cdc2ec3421ca79dd933b7e2b0c4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "SliderStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "SliderStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": "",
+ "handle_color": null
+ }
+ },
+ "5146907ef6764654ad7d598baebc8b58": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "IntSliderModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "IntSliderModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "IntSliderView",
+ "continuous_update": true,
+ "description": "",
+ "description_tooltip": null,
+ "disabled": false,
+ "layout": "IPY_MODEL_19308ccac642498ab8b58462e3f1b0bb",
+ "max": 0,
+ "min": 0,
+ "orientation": "horizontal",
+ "readout": true,
+ "readout_format": "d",
+ "step": 1,
+ "style": "IPY_MODEL_4a081cdc2ec3421ca79dd933b7e2b0c4",
+ "value": 0
+ }
+ },
+ "561f742d418d4721b0670cc8dd62e22c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "6010fc8daa7a44d5aec4b830ec2ebaa1": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_7e0bb1b8d65249d3974200686b193be2",
+ "IPY_MODEL_ba98aa6d6a884e4ab8bbb5dfb5e4cf7a",
+ "IPY_MODEL_6526646be5ed415c84d1245b040e629b"
],
- "metadata": {
- "id": "EzJQXPN_XrMX"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "@dataclass\n",
- "class MoleculeGNNOutput(BaseOutput):\n",
- " \"\"\"\n",
- " Args:\n",
- " sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):\n",
- " Hidden states output. Output of last layer of model.\n",
- " \"\"\"\n",
- "\n",
- " sample: torch.Tensor\n",
- "\n",
- "\n",
- "class MultiLayerPerceptron(nn.Module):\n",
- " \"\"\"\n",
- " Multi-layer Perceptron. Note there is no activation or dropout in the last layer.\n",
- " Args:\n",
- " input_dim (int): input dimension\n",
- " hidden_dim (list of int): hidden dimensions\n",
- " activation (str or function, optional): activation function\n",
- " dropout (float, optional): dropout rate\n",
- " \"\"\"\n",
- "\n",
- " def __init__(self, input_dim, hidden_dims, activation=\"relu\", dropout=0):\n",
- " super(MultiLayerPerceptron, self).__init__()\n",
- "\n",
- " self.dims = [input_dim] + hidden_dims\n",
- " if isinstance(activation, str):\n",
- " self.activation = getattr(F, activation)\n",
- " else:\n",
- " print(f\"Warning, activation passed {activation} is not string and ignored\")\n",
- " self.activation = None\n",
- " if dropout > 0:\n",
- " self.dropout = nn.Dropout(dropout)\n",
- " else:\n",
- " self.dropout = None\n",
- "\n",
- " self.layers = nn.ModuleList()\n",
- " for i in range(len(self.dims) - 1):\n",
- " self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1]))\n",
- "\n",
- " def forward(self, x):\n",
- " \"\"\"\"\"\"\n",
- " for i, layer in enumerate(self.layers):\n",
- " x = layer(x)\n",
- " if i < len(self.layers) - 1:\n",
- " if self.activation:\n",
- " x = self.activation(x)\n",
- " if self.dropout:\n",
- " x = self.dropout(x)\n",
- " return x\n",
- "\n",
- "\n",
- "class ShiftedSoftplus(torch.nn.Module):\n",
- " def __init__(self):\n",
- " super(ShiftedSoftplus, self).__init__()\n",
- " self.shift = torch.log(torch.tensor(2.0)).item()\n",
- "\n",
- " def forward(self, x):\n",
- " return F.softplus(x) - self.shift\n",
- "\n",
- "\n",
- "class CFConv(MessagePassing):\n",
- " def __init__(self, in_channels, out_channels, num_filters, mlp, cutoff, smooth):\n",
- " super(CFConv, self).__init__(aggr=\"add\")\n",
- " self.lin1 = Linear(in_channels, num_filters, bias=False)\n",
- " self.lin2 = Linear(num_filters, out_channels)\n",
- " self.nn = mlp\n",
- " self.cutoff = cutoff\n",
- " self.smooth = smooth\n",
- "\n",
- " self.reset_parameters()\n",
- "\n",
- " def reset_parameters(self):\n",
- " torch.nn.init.xavier_uniform_(self.lin1.weight)\n",
- " torch.nn.init.xavier_uniform_(self.lin2.weight)\n",
- " self.lin2.bias.data.fill_(0)\n",
- "\n",
- " def forward(self, x, edge_index, edge_length, edge_attr):\n",
- " if self.smooth:\n",
- " C = 0.5 * (torch.cos(edge_length * np.pi / self.cutoff) + 1.0)\n",
- " C = C * (edge_length <= self.cutoff) * (edge_length >= 0.0) # Modification: cutoff\n",
- " else:\n",
- " C = (edge_length <= self.cutoff).float()\n",
- " W = self.nn(edge_attr) * C.view(-1, 1)\n",
- "\n",
- " x = self.lin1(x)\n",
- " x = self.propagate(edge_index, x=x, W=W)\n",
- " x = self.lin2(x)\n",
- " return x\n",
- "\n",
- " def message(self, x_j: torch.Tensor, W) -> torch.Tensor:\n",
- " return x_j * W\n",
- "\n",
- "\n",
- "class InteractionBlock(torch.nn.Module):\n",
- " def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff, smooth):\n",
- " super(InteractionBlock, self).__init__()\n",
- " mlp = Sequential(\n",
- " Linear(num_gaussians, num_filters),\n",
- " ShiftedSoftplus(),\n",
- " Linear(num_filters, num_filters),\n",
- " )\n",
- " self.conv = CFConv(hidden_channels, hidden_channels, num_filters, mlp, cutoff, smooth)\n",
- " self.act = ShiftedSoftplus()\n",
- " self.lin = Linear(hidden_channels, hidden_channels)\n",
- "\n",
- " def forward(self, x, edge_index, edge_length, edge_attr):\n",
- " x = self.conv(x, edge_index, edge_length, edge_attr)\n",
- " x = self.act(x)\n",
- " x = self.lin(x)\n",
- " return x\n",
- "\n",
- "\n",
- "class SchNetEncoder(Module):\n",
- " def __init__(\n",
- " self, hidden_channels=128, num_filters=128, num_interactions=6, edge_channels=100, cutoff=10.0, smooth=False\n",
- " ):\n",
- " super().__init__()\n",
- "\n",
- " self.hidden_channels = hidden_channels\n",
- " self.num_filters = num_filters\n",
- " self.num_interactions = num_interactions\n",
- " self.cutoff = cutoff\n",
- "\n",
- " self.embedding = Embedding(100, hidden_channels, max_norm=10.0)\n",
- "\n",
- " self.interactions = ModuleList()\n",
- " for _ in range(num_interactions):\n",
- " block = InteractionBlock(hidden_channels, edge_channels, num_filters, cutoff, smooth)\n",
- " self.interactions.append(block)\n",
- "\n",
- " def forward(self, z, edge_index, edge_length, edge_attr, embed_node=True):\n",
- " if embed_node:\n",
- " assert z.dim() == 1 and z.dtype == torch.long\n",
- " h = self.embedding(z)\n",
- " else:\n",
- " h = z\n",
- " for interaction in self.interactions:\n",
- " h = h + interaction(h, edge_index, edge_length, edge_attr)\n",
- "\n",
- " return h\n",
- "\n",
- "\n",
- "class GINEConv(MessagePassing):\n",
- " \"\"\"\n",
- " Custom class of the graph isomorphism operator from the \"How Powerful are Graph Neural Networks?\n",
- " https://arxiv.org/abs/1810.00826 paper. Note that this implementation has the added option of a custom activation.\n",
- " \"\"\"\n",
- "\n",
- " def __init__(self, mlp: Callable, eps: float = 0.0, train_eps: bool = False, activation=\"softplus\", **kwargs):\n",
- " super(GINEConv, self).__init__(aggr=\"add\", **kwargs)\n",
- " self.nn = mlp\n",
- " self.initial_eps = eps\n",
- "\n",
- " if isinstance(activation, str):\n",
- " self.activation = getattr(F, activation)\n",
- " else:\n",
- " self.activation = None\n",
- "\n",
- " if train_eps:\n",
- " self.eps = torch.nn.Parameter(torch.Tensor([eps]))\n",
- " else:\n",
- " self.register_buffer(\"eps\", torch.Tensor([eps]))\n",
- "\n",
- " def forward(\n",
- " self, x: Union[Tensor, OptPairTensor], edge_index: Adj, edge_attr: OptTensor = None, size: Size = None\n",
- " ) -> torch.Tensor:\n",
- " \"\"\"\"\"\"\n",
- " if isinstance(x, torch.Tensor):\n",
- " x: OptPairTensor = (x, x)\n",
- "\n",
- " # Node and edge feature dimensionalites need to match.\n",
- " if isinstance(edge_index, torch.Tensor):\n",
- " assert edge_attr is not None\n",
- " assert x[0].size(-1) == edge_attr.size(-1)\n",
- " elif isinstance(edge_index, SparseTensor):\n",
- " assert x[0].size(-1) == edge_index.size(-1)\n",
- "\n",
- " # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n",
- " out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n",
- "\n",
- " x_r = x[1]\n",
- " if x_r is not None:\n",
- " out += (1 + self.eps) * x_r\n",
- "\n",
- " return self.nn(out)\n",
- "\n",
- " def message(self, x_j: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:\n",
- " if self.activation:\n",
- " return self.activation(x_j + edge_attr)\n",
- " else:\n",
- " return x_j + edge_attr\n",
- "\n",
- " def __repr__(self):\n",
- " return \"{}(nn={})\".format(self.__class__.__name__, self.nn)\n",
- "\n",
- "\n",
- "class GINEncoder(torch.nn.Module):\n",
- " def __init__(self, hidden_dim, num_convs=3, activation=\"relu\", short_cut=True, concat_hidden=False):\n",
- " super().__init__()\n",
- "\n",
- " self.hidden_dim = hidden_dim\n",
- " self.num_convs = num_convs\n",
- " self.short_cut = short_cut\n",
- " self.concat_hidden = concat_hidden\n",
- " self.node_emb = nn.Embedding(100, hidden_dim)\n",
- "\n",
- " if isinstance(activation, str):\n",
- " self.activation = getattr(F, activation)\n",
- " else:\n",
- " self.activation = None\n",
- "\n",
- " self.convs = nn.ModuleList()\n",
- " for i in range(self.num_convs):\n",
- " self.convs.append(\n",
- " GINEConv(\n",
- " MultiLayerPerceptron(hidden_dim, [hidden_dim, hidden_dim], activation=activation),\n",
- " activation=activation,\n",
- " )\n",
- " )\n",
- "\n",
- " def forward(self, z, edge_index, edge_attr):\n",
- " \"\"\"\n",
- " Input:\n",
- " data: (torch_geometric.data.Data): batched graph edge_index: bond indices of the original graph (num_node,\n",
- " hidden) edge_attr: edge feature tensor with shape (num_edge, hidden)\n",
- " Output:\n",
- " node_feature: graph feature\n",
- " \"\"\"\n",
- "\n",
- " node_attr = self.node_emb(z) # (num_node, hidden)\n",
- "\n",
- " hiddens = []\n",
- " conv_input = node_attr # (num_node, hidden)\n",
- "\n",
- " for conv_idx, conv in enumerate(self.convs):\n",
- " hidden = conv(conv_input, edge_index, edge_attr)\n",
- " if conv_idx < len(self.convs) - 1 and self.activation is not None:\n",
- " hidden = self.activation(hidden)\n",
- " assert hidden.shape == conv_input.shape\n",
- " if self.short_cut and hidden.shape == conv_input.shape:\n",
- " hidden += conv_input\n",
- "\n",
- " hiddens.append(hidden)\n",
- " conv_input = hidden\n",
- "\n",
- " if self.concat_hidden:\n",
- " node_feature = torch.cat(hiddens, dim=-1)\n",
- " else:\n",
- " node_feature = hiddens[-1]\n",
- "\n",
- " return node_feature\n",
- "\n",
- "\n",
- "class MLPEdgeEncoder(Module):\n",
- " def __init__(self, hidden_dim=100, activation=\"relu\"):\n",
- " super().__init__()\n",
- " self.hidden_dim = hidden_dim\n",
- " self.bond_emb = Embedding(100, embedding_dim=self.hidden_dim)\n",
- " self.mlp = MultiLayerPerceptron(1, [self.hidden_dim, self.hidden_dim], activation=activation)\n",
- "\n",
- " @property\n",
- " def out_channels(self):\n",
- " return self.hidden_dim\n",
- "\n",
- " def forward(self, edge_length, edge_type):\n",
- " \"\"\"\n",
- " Input:\n",
- " edge_length: The length of edges, shape=(E, 1). edge_type: The type pf edges, shape=(E,)\n",
- " Returns:\n",
- " edge_attr: The representation of edges. (E, 2 * num_gaussians)\n",
- " \"\"\"\n",
- " d_emb = self.mlp(edge_length) # (num_edge, hidden_dim)\n",
- " edge_attr = self.bond_emb(edge_type) # (num_edge, hidden_dim)\n",
- " return d_emb * edge_attr # (num_edge, hidden)\n",
- "\n",
- "\n",
- "def assemble_atom_pair_feature(node_attr, edge_index, edge_attr):\n",
- " h_row, h_col = node_attr[edge_index[0]], node_attr[edge_index[1]]\n",
- " h_pair = torch.cat([h_row * h_col, edge_attr], dim=-1) # (E, 2H)\n",
- " return h_pair\n",
- "\n",
- "\n",
- "def _extend_graph_order(num_nodes, edge_index, edge_type, order=3):\n",
- " \"\"\"\n",
- " Args:\n",
- " num_nodes: Number of atoms.\n",
- " edge_index: Bond indices of the original graph.\n",
- " edge_type: Bond types of the original graph.\n",
- " order: Extension order.\n",
- " Returns:\n",
- " new_edge_index: Extended edge indices. new_edge_type: Extended edge types.\n",
- " \"\"\"\n",
- "\n",
- " def binarize(x):\n",
- " return torch.where(x > 0, torch.ones_like(x), torch.zeros_like(x))\n",
- "\n",
- " def get_higher_order_adj_matrix(adj, order):\n",
- " \"\"\"\n",
- " Args:\n",
- " adj: (N, N)\n",
- " type_mat: (N, N)\n",
- " Returns:\n",
- " Following attributes will be updated:\n",
- " - edge_index\n",
- " - edge_type\n",
- " Following attributes will be added to the data object:\n",
- " - bond_edge_index: Original edge_index.\n",
- " \"\"\"\n",
- " adj_mats = [\n",
- " torch.eye(adj.size(0), dtype=torch.long, device=adj.device),\n",
- " binarize(adj + torch.eye(adj.size(0), dtype=torch.long, device=adj.device)),\n",
- " ]\n",
- "\n",
- " for i in range(2, order + 1):\n",
- " adj_mats.append(binarize(adj_mats[i - 1] @ adj_mats[1]))\n",
- " order_mat = torch.zeros_like(adj)\n",
- "\n",
- " for i in range(1, order + 1):\n",
- " order_mat += (adj_mats[i] - adj_mats[i - 1]) * i\n",
- "\n",
- " return order_mat\n",
- "\n",
- " num_types = 22\n",
- " # given from len(BOND_TYPES), where BOND_TYPES = {t: i for i, t in enumerate(BT.names.values())}\n",
- " # from rdkit.Chem.rdchem import BondType as BT\n",
- " N = num_nodes\n",
- " adj = to_dense_adj(edge_index).squeeze(0)\n",
- " adj_order = get_higher_order_adj_matrix(adj, order) # (N, N)\n",
- "\n",
- " type_mat = to_dense_adj(edge_index, edge_attr=edge_type).squeeze(0) # (N, N)\n",
- " type_highorder = torch.where(adj_order > 1, num_types + adj_order - 1, torch.zeros_like(adj_order))\n",
- " assert (type_mat * type_highorder == 0).all()\n",
- " type_new = type_mat + type_highorder\n",
- "\n",
- " new_edge_index, new_edge_type = dense_to_sparse(type_new)\n",
- " _, edge_order = dense_to_sparse(adj_order)\n",
- "\n",
- " # data.bond_edge_index = data.edge_index # Save original edges\n",
- " new_edge_index, new_edge_type = coalesce(new_edge_index, new_edge_type.long(), N, N) # modify data\n",
- "\n",
- " return new_edge_index, new_edge_type\n",
- "\n",
- "\n",
- "def _extend_to_radius_graph(pos, edge_index, edge_type, cutoff, batch, unspecified_type_number=0, is_sidechain=None):\n",
- " assert edge_type.dim() == 1\n",
- " N = pos.size(0)\n",
- "\n",
- " bgraph_adj = torch.sparse.LongTensor(edge_index, edge_type, torch.Size([N, N]))\n",
- "\n",
- " if is_sidechain is None:\n",
- " rgraph_edge_index = radius_graph(pos, r=cutoff, batch=batch) # (2, E_r)\n",
- " else:\n",
- " # fetch sidechain and its batch index\n",
- " is_sidechain = is_sidechain.bool()\n",
- " dummy_index = torch.arange(pos.size(0), device=pos.device)\n",
- " sidechain_pos = pos[is_sidechain]\n",
- " sidechain_index = dummy_index[is_sidechain]\n",
- " sidechain_batch = batch[is_sidechain]\n",
- "\n",
- " assign_index = radius(x=pos, y=sidechain_pos, r=cutoff, batch_x=batch, batch_y=sidechain_batch)\n",
- " r_edge_index_x = assign_index[1]\n",
- " r_edge_index_y = assign_index[0]\n",
- " r_edge_index_y = sidechain_index[r_edge_index_y]\n",
- "\n",
- " rgraph_edge_index1 = torch.stack((r_edge_index_x, r_edge_index_y)) # (2, E)\n",
- " rgraph_edge_index2 = torch.stack((r_edge_index_y, r_edge_index_x)) # (2, E)\n",
- " rgraph_edge_index = torch.cat((rgraph_edge_index1, rgraph_edge_index2), dim=-1) # (2, 2E)\n",
- " # delete self loop\n",
- " rgraph_edge_index = rgraph_edge_index[:, (rgraph_edge_index[0] != rgraph_edge_index[1])]\n",
- "\n",
- " rgraph_adj = torch.sparse.LongTensor(\n",
- " rgraph_edge_index,\n",
- " torch.ones(rgraph_edge_index.size(1)).long().to(pos.device) * unspecified_type_number,\n",
- " torch.Size([N, N]),\n",
- " )\n",
- "\n",
- " composed_adj = (bgraph_adj + rgraph_adj).coalesce() # Sparse (N, N, T)\n",
- "\n",
- " new_edge_index = composed_adj.indices()\n",
- " new_edge_type = composed_adj.values().long()\n",
- "\n",
- " return new_edge_index, new_edge_type\n",
- "\n",
- "\n",
- "def extend_graph_order_radius(\n",
- " num_nodes,\n",
- " pos,\n",
- " edge_index,\n",
- " edge_type,\n",
- " batch,\n",
- " order=3,\n",
- " cutoff=10.0,\n",
- " extend_order=True,\n",
- " extend_radius=True,\n",
- " is_sidechain=None,\n",
- "):\n",
- " if extend_order:\n",
- " edge_index, edge_type = _extend_graph_order(\n",
- " num_nodes=num_nodes, edge_index=edge_index, edge_type=edge_type, order=order\n",
- " )\n",
- "\n",
- " if extend_radius:\n",
- " edge_index, edge_type = _extend_to_radius_graph(\n",
- " pos=pos, edge_index=edge_index, edge_type=edge_type, cutoff=cutoff, batch=batch, is_sidechain=is_sidechain\n",
- " )\n",
- "\n",
- " return edge_index, edge_type\n",
- "\n",
- "\n",
- "def get_distance(pos, edge_index):\n",
- " return (pos[edge_index[0]] - pos[edge_index[1]]).norm(dim=-1)\n",
- "\n",
- "\n",
- "def graph_field_network(score_d, pos, edge_index, edge_length):\n",
- " \"\"\"\n",
- " Transformation to make the epsilon predicted from the diffusion model roto-translational equivariant. See equations\n",
- " 5-7 of the GeoDiff Paper https://arxiv.org/pdf/2203.02923.pdf\n",
- " \"\"\"\n",
- " N = pos.size(0)\n",
- " dd_dr = (1.0 / edge_length) * (pos[edge_index[0]] - pos[edge_index[1]]) # (E, 3)\n",
- " score_pos = scatter_add(dd_dr * score_d, edge_index[0], dim=0, dim_size=N) + scatter_add(\n",
- " -dd_dr * score_d, edge_index[1], dim=0, dim_size=N\n",
- " ) # (N, 3)\n",
- " return score_pos\n",
- "\n",
- "\n",
- "def clip_norm(vec, limit, p=2):\n",
- " norm = torch.norm(vec, dim=-1, p=2, keepdim=True)\n",
- " denom = torch.where(norm > limit, limit / norm, torch.ones_like(norm))\n",
- " return vec * denom\n",
- "\n",
- "\n",
- "def is_local_edge(edge_type):\n",
- " return edge_type > 0\n"
+ "layout": "IPY_MODEL_24d31fc3576e43dd9f8301d2ef3a37ab"
+ }
+ },
+ "65195cb7a4134f4887e9dd19f3676462": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ButtonStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ButtonStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "button_color": null,
+ "font_weight": ""
+ }
+ },
+ "6526646be5ed415c84d1245b040e629b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_a9e388f22a9742aaaf538e22575c9433",
+ "placeholder": "",
+ "style": "IPY_MODEL_42f6c3db29d7484ba6b4f73590abd2f4",
+ "value": " 401/401 [00:00<00:00, 13.5kB/s]"
+ }
+ },
+ "695ab5bbf30a4ab19df1f9f33469f314": {
+ "model_module": "nglview-js-widgets",
+ "model_module_version": "3.0.1",
+ "model_name": "ColormakerRegistryModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "nglview-js-widgets",
+ "_model_module_version": "3.0.1",
+ "_model_name": "ColormakerRegistryModel",
+ "_msg_ar": [],
+ "_msg_q": [],
+ "_ready": false,
+ "_view_count": null,
+ "_view_module": "nglview-js-widgets",
+ "_view_module_version": "3.0.1",
+ "_view_name": "ColormakerRegistryView",
+ "layout": "IPY_MODEL_eac6a8dcdc9d4335a2e51031793ead29"
+ }
+ },
+ "7e0bb1b8d65249d3974200686b193be2": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_2918bfaadc8d4b1a9832522c40dfefb8",
+ "placeholder": "",
+ "style": "IPY_MODEL_a4bfdca35cc54dae8812720f1b276a08",
+ "value": "Downloading: 100%"
+ }
+ },
+ "872915dd1bb84f538c44e26badabafdd": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "a4bfdca35cc54dae8812720f1b276a08": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "a9e388f22a9742aaaf538e22575c9433": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "abce2a80e6304df3899109c6d6cac199": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": "34px"
+ }
+ },
+ "b7feb522161f4cf4b7cc7c1a078ff12d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_fdc393f3468c432aa0ada05e238a5436",
+ "placeholder": "",
+ "style": "IPY_MODEL_2c9362906e4b40189f16d14aa9a348da",
+ "value": " 3.27M/3.27M [00:01<00:00, 3.25MB/s]"
+ }
+ },
+ "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_e4901541199b45c6a18824627692fc39",
+ "max": 401,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_f915cf874246446595206221e900b2fe",
+ "value": 401
+ }
+ },
+ "bbef741e76ec41b7ab7187b487a383df": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "be446195da2b4ff2aec21ec5ff963a54": {
+ "model_module": "nglview-js-widgets",
+ "model_module_version": "3.0.1",
+ "model_name": "NGLModel",
+ "state": {
+ "_camera_orientation": [
+ -15.519693580202304,
+ -14.065056548036177,
+ -23.53197484807691,
+ 0,
+ -23.357853515109753,
+ 20.94055073042662,
+ 2.888695042134944,
+ 0,
+ 14.352363398292775,
+ 18.870825741878015,
+ -20.744689572909344,
+ 0,
+ 0.2724999189376831,
+ 0.6940000057220459,
+ -0.3734999895095825,
+ 1
],
- "metadata": {
- "id": "oR1Y56QiLY90"
+ "_camera_str": "orthographic",
+ "_dom_classes": [],
+ "_gui_theme": null,
+ "_ibtn_fullscreen": "IPY_MODEL_2489b5e5648541fbbdceadb05632a050",
+ "_igui": null,
+ "_iplayer": "IPY_MODEL_01e0ba4e5da04914b4652b8d58565d7b",
+ "_model_module": "nglview-js-widgets",
+ "_model_module_version": "3.0.1",
+ "_model_name": "NGLModel",
+ "_ngl_color_dict": {},
+ "_ngl_coordinate_resource": {},
+ "_ngl_full_stage_parameters": {
+ "ambientColor": 14540253,
+ "ambientIntensity": 0.2,
+ "backgroundColor": "white",
+ "cameraEyeSep": 0.3,
+ "cameraFov": 40,
+ "cameraType": "perspective",
+ "clipDist": 10,
+ "clipFar": 100,
+ "clipNear": 0,
+ "fogFar": 100,
+ "fogNear": 50,
+ "hoverTimeout": 0,
+ "impostor": true,
+ "lightColor": 14540253,
+ "lightIntensity": 1,
+ "mousePreset": "default",
+ "panSpeed": 1,
+ "quality": "medium",
+ "rotateSpeed": 2,
+ "sampleLevel": 0,
+ "tooltip": true,
+ "workerDefault": true,
+ "zoomSpeed": 1.2
},
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "Main model class!"
- ],
- "metadata": {
- "id": "QWrHJFcYXyUB"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "class MoleculeGNN(ModelMixin, ConfigMixin):\n",
- " @register_to_config\n",
- " def __init__(\n",
- " self,\n",
- " hidden_dim=128,\n",
- " num_convs=6,\n",
- " num_convs_local=4,\n",
- " cutoff=10.0,\n",
- " mlp_act=\"relu\",\n",
- " edge_order=3,\n",
- " edge_encoder=\"mlp\",\n",
- " smooth_conv=True,\n",
- " ):\n",
- " super().__init__()\n",
- " self.cutoff = cutoff\n",
- " self.edge_encoder = edge_encoder\n",
- " self.edge_order = edge_order\n",
- "\n",
- " \"\"\"\n",
- " edge_encoder: Takes both edge type and edge length as input and outputs a vector [Note]: node embedding is done\n",
- " in SchNetEncoder\n",
- " \"\"\"\n",
- " self.edge_encoder_global = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n",
- " self.edge_encoder_local = MLPEdgeEncoder(hidden_dim, mlp_act) # get_edge_encoder(config)\n",
- "\n",
- " \"\"\"\n",
- " The graph neural network that extracts node-wise features.\n",
- " \"\"\"\n",
- " self.encoder_global = SchNetEncoder(\n",
- " hidden_channels=hidden_dim,\n",
- " num_filters=hidden_dim,\n",
- " num_interactions=num_convs,\n",
- " edge_channels=self.edge_encoder_global.out_channels,\n",
- " cutoff=cutoff,\n",
- " smooth=smooth_conv,\n",
- " )\n",
- " self.encoder_local = GINEncoder(\n",
- " hidden_dim=hidden_dim,\n",
- " num_convs=num_convs_local,\n",
- " )\n",
- "\n",
- " \"\"\"\n",
- " `output_mlp` takes a mixture of two nodewise features and edge features as input and outputs\n",
- " gradients w.r.t. edge_length (out_dim = 1).\n",
- " \"\"\"\n",
- " self.grad_global_dist_mlp = MultiLayerPerceptron(\n",
- " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n",
- " )\n",
- "\n",
- " self.grad_local_dist_mlp = MultiLayerPerceptron(\n",
- " 2 * hidden_dim, [hidden_dim, hidden_dim // 2, 1], activation=mlp_act\n",
- " )\n",
- "\n",
- " \"\"\"\n",
- " Incorporate parameters together\n",
- " \"\"\"\n",
- " self.model_global = nn.ModuleList([self.edge_encoder_global, self.encoder_global, self.grad_global_dist_mlp])\n",
- " self.model_local = nn.ModuleList([self.edge_encoder_local, self.encoder_local, self.grad_local_dist_mlp])\n",
- "\n",
- " def _forward(\n",
- " self,\n",
- " atom_type,\n",
- " pos,\n",
- " bond_index,\n",
- " bond_type,\n",
- " batch,\n",
- " time_step, # NOTE, model trained without timestep performed best\n",
- " edge_index=None,\n",
- " edge_type=None,\n",
- " edge_length=None,\n",
- " return_edges=False,\n",
- " extend_order=True,\n",
- " extend_radius=True,\n",
- " is_sidechain=None,\n",
- " ):\n",
- " \"\"\"\n",
- " Args:\n",
- " atom_type: Types of atoms, (N, ).\n",
- " bond_index: Indices of bonds (not extended, not radius-graph), (2, E).\n",
- " bond_type: Bond types, (E, ).\n",
- " batch: Node index to graph index, (N, ).\n",
- " \"\"\"\n",
- " N = atom_type.size(0)\n",
- " if edge_index is None or edge_type is None or edge_length is None:\n",
- " edge_index, edge_type = extend_graph_order_radius(\n",
- " num_nodes=N,\n",
- " pos=pos,\n",
- " edge_index=bond_index,\n",
- " edge_type=bond_type,\n",
- " batch=batch,\n",
- " order=self.edge_order,\n",
- " cutoff=self.cutoff,\n",
- " extend_order=extend_order,\n",
- " extend_radius=extend_radius,\n",
- " is_sidechain=is_sidechain,\n",
- " )\n",
- " edge_length = get_distance(pos, edge_index).unsqueeze(-1) # (E, 1)\n",
- " local_edge_mask = is_local_edge(edge_type) # (E, )\n",
- "\n",
- " # with the parameterization of NCSNv2\n",
- " # DDPM loss implicit handle the noise variance scale conditioning\n",
- " sigma_edge = torch.ones(size=(edge_index.size(1), 1), device=pos.device) # (E, 1)\n",
- "\n",
- " # Encoding global\n",
- " edge_attr_global = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n",
- "\n",
- " # Global\n",
- " node_attr_global = self.encoder_global(\n",
- " z=atom_type,\n",
- " edge_index=edge_index,\n",
- " edge_length=edge_length,\n",
- " edge_attr=edge_attr_global,\n",
- " )\n",
- " # Assemble pairwise features\n",
- " h_pair_global = assemble_atom_pair_feature(\n",
- " node_attr=node_attr_global,\n",
- " edge_index=edge_index,\n",
- " edge_attr=edge_attr_global,\n",
- " ) # (E_global, 2H)\n",
- " # Invariant features of edges (radius graph, global)\n",
- " edge_inv_global = self.grad_global_dist_mlp(h_pair_global) * (1.0 / sigma_edge) # (E_global, 1)\n",
- "\n",
- " # Encoding local\n",
- " edge_attr_local = self.edge_encoder_global(edge_length=edge_length, edge_type=edge_type) # Embed edges\n",
- " # edge_attr += temb_edge\n",
- "\n",
- " # Local\n",
- " node_attr_local = self.encoder_local(\n",
- " z=atom_type,\n",
- " edge_index=edge_index[:, local_edge_mask],\n",
- " edge_attr=edge_attr_local[local_edge_mask],\n",
- " )\n",
- " # Assemble pairwise features\n",
- " h_pair_local = assemble_atom_pair_feature(\n",
- " node_attr=node_attr_local,\n",
- " edge_index=edge_index[:, local_edge_mask],\n",
- " edge_attr=edge_attr_local[local_edge_mask],\n",
- " ) # (E_local, 2H)\n",
- "\n",
- " # Invariant features of edges (bond graph, local)\n",
- " if isinstance(sigma_edge, torch.Tensor):\n",
- " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (\n",
- " 1.0 / sigma_edge[local_edge_mask]\n",
- " ) # (E_local, 1)\n",
- " else:\n",
- " edge_inv_local = self.grad_local_dist_mlp(h_pair_local) * (1.0 / sigma_edge) # (E_local, 1)\n",
- "\n",
- " if return_edges:\n",
- " return edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask\n",
- " else:\n",
- " return edge_inv_global, edge_inv_local\n",
- "\n",
- " def forward(\n",
- " self,\n",
- " sample,\n",
- " timestep: Union[torch.Tensor, float, int],\n",
- " return_dict: bool = True,\n",
- " sigma=1.0,\n",
- " global_start_sigma=0.5,\n",
- " w_global=1.0,\n",
- " extend_order=False,\n",
- " extend_radius=True,\n",
- " clip_local=None,\n",
- " clip_global=1000.0,\n",
- " ) -> Union[MoleculeGNNOutput, Tuple]:\n",
- " r\"\"\"\n",
- " Args:\n",
- " sample: packed torch geometric object\n",
- " timestep (`torch.Tensor` or `float` or `int): TODO verify type and shape (batch) timesteps\n",
- " return_dict (`bool`, *optional*, defaults to `True`):\n",
- " Whether or not to return a [`~models.molecule_gnn.MoleculeGNNOutput`] instead of a plain tuple.\n",
- " Returns:\n",
- " [`~models.molecule_gnn.MoleculeGNNOutput`] or `tuple`: [`~models.molecule_gnn.MoleculeGNNOutput`] if\n",
- " `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.\n",
- " \"\"\"\n",
- "\n",
- " # unpack sample\n",
- " atom_type = sample.atom_type\n",
- " bond_index = sample.edge_index\n",
- " bond_type = sample.edge_type\n",
- " num_graphs = sample.num_graphs\n",
- " pos = sample.pos\n",
- "\n",
- " timesteps = torch.full(size=(num_graphs,), fill_value=timestep, dtype=torch.long, device=pos.device)\n",
- "\n",
- " edge_inv_global, edge_inv_local, edge_index, edge_type, edge_length, local_edge_mask = self._forward(\n",
- " atom_type=atom_type,\n",
- " pos=sample.pos,\n",
- " bond_index=bond_index,\n",
- " bond_type=bond_type,\n",
- " batch=sample.batch,\n",
- " time_step=timesteps,\n",
- " return_edges=True,\n",
- " extend_order=extend_order,\n",
- " extend_radius=extend_radius,\n",
- " ) # (E_global, 1), (E_local, 1)\n",
- "\n",
- " # Important equation in the paper for equivariant features - eqns 5-7 of GeoDiff\n",
- " node_eq_local = graph_field_network(\n",
- " edge_inv_local, pos, edge_index[:, local_edge_mask], edge_length[local_edge_mask]\n",
- " )\n",
- " if clip_local is not None:\n",
- " node_eq_local = clip_norm(node_eq_local, limit=clip_local)\n",
- "\n",
- " # Global\n",
- " if sigma < global_start_sigma:\n",
- " edge_inv_global = edge_inv_global * (1 - local_edge_mask.view(-1, 1).float())\n",
- " node_eq_global = graph_field_network(edge_inv_global, pos, edge_index, edge_length)\n",
- " node_eq_global = clip_norm(node_eq_global, limit=clip_global)\n",
- " else:\n",
- " node_eq_global = 0\n",
- "\n",
- " # Sum\n",
- " eps_pos = node_eq_local + node_eq_global * w_global\n",
- "\n",
- " if not return_dict:\n",
- " return (-eps_pos,)\n",
- "\n",
- " return MoleculeGNNOutput(sample=torch.Tensor(-eps_pos).to(pos.device))"
- ],
- "metadata": {
- "id": "MCeZA1qQXzoK"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "CCIrPYSJj9wd"
- },
- "source": [
- "### Load pretrained model"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "YdrAr6Ch--Ab"
- },
- "source": [
- "#### Load a model\n",
- "The model used is a design an\n",
- "equivariant convolutional layer, named graph field network (GFN).\n",
- "\n",
- "The warning about `betas` and `alphas` can be ignored, those were moved to the scheduler."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "DyCo0nsqjbml",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 172,
- "referenced_widgets": [
- "d90f304e9560472eacfbdd11e46765eb",
- "1c6246f15b654f4daa11c9bcf997b78c",
- "c2321b3bff6f490ca12040a20308f555",
- "b7feb522161f4cf4b7cc7c1a078ff12d",
- "e2d368556e494ae7ae4e2e992af2cd4f",
- "bbef741e76ec41b7ab7187b487a383df",
- "561f742d418d4721b0670cc8dd62e22c",
- "872915dd1bb84f538c44e26badabafdd",
- "d022575f1fa2446d891650897f187b4d",
- "fdc393f3468c432aa0ada05e238a5436",
- "2c9362906e4b40189f16d14aa9a348da",
- "6010fc8daa7a44d5aec4b830ec2ebaa1",
- "7e0bb1b8d65249d3974200686b193be2",
- "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a",
- "6526646be5ed415c84d1245b040e629b",
- "24d31fc3576e43dd9f8301d2ef3a37ab",
- "2918bfaadc8d4b1a9832522c40dfefb8",
- "a4bfdca35cc54dae8812720f1b276a08",
- "e4901541199b45c6a18824627692fc39",
- "f915cf874246446595206221e900b2fe",
- "a9e388f22a9742aaaf538e22575c9433",
- "42f6c3db29d7484ba6b4f73590abd2f4"
- ]
+ "_ngl_msg_archive": [
+ {
+ "args": [
+ {
+ "binary": false,
+ "data": "HETATM 1 C1 UNL 1 -0.025 3.128 2.316 1.00 0.00 C \nHETATM 2 H1 UNL 1 0.183 3.657 2.823 1.00 0.00 H \nHETATM 3 C2 UNL 1 0.590 3.559 0.963 1.00 0.00 C \nHETATM 4 C3 UNL 1 0.056 4.479 0.406 1.00 0.00 C \nHETATM 5 C4 UNL 1 -0.219 4.802 -1.065 1.00 0.00 C \nHETATM 6 H2 UNL 1 0.686 4.431 -1.575 1.00 0.00 H \nHETATM 7 H3 UNL 1 -0.524 5.217 -1.274 1.00 0.00 H \nHETATM 8 C5 UNL 1 -1.284 3.766 -1.342 1.00 0.00 C \nHETATM 9 N1 UNL 1 -1.073 2.494 -0.580 1.00 0.00 N \nHETATM 10 C6 UNL 1 -1.909 1.494 -0.964 1.00 0.00 C \nHETATM 11 O1 UNL 1 -2.487 1.531 -2.092 1.00 0.00 O \nHETATM 12 C7 UNL 1 -2.232 0.242 -0.130 1.00 0.00 C \nHETATM 13 C8 UNL 1 -2.161 -1.057 -1.037 1.00 0.00 C \nHETATM 14 C9 UNL 1 -0.744 -1.111 -1.610 1.00 0.00 C \nHETATM 15 N2 UNL 1 0.290 -0.917 -0.628 1.00 0.00 N \nHETATM 16 S1 UNL 1 1.717 -1.597 -0.914 1.00 0.00 S \nHETATM 17 O2 UNL 1 1.960 -1.671 -2.338 1.00 0.00 O \nHETATM 18 O3 UNL 1 2.713 -0.968 -0.082 1.00 0.00 O \nHETATM 19 C10 UNL 1 1.425 -3.170 -0.345 1.00 0.00 C \nHETATM 20 C11 UNL 1 1.225 -4.400 -1.271 1.00 0.00 C \nHETATM 21 C12 UNL 1 1.314 -5.913 -0.895 1.00 0.00 C \nHETATM 22 C13 UNL 1 1.823 -6.229 0.386 1.00 0.00 C \nHETATM 23 C14 UNL 1 2.031 -5.110 1.365 1.00 0.00 C \nHETATM 24 N3 UNL 1 1.850 -5.267 2.712 1.00 0.00 N \nHETATM 25 O4 UNL 1 1.382 -4.029 3.126 1.00 0.00 O \nHETATM 26 N4 UNL 1 1.300 -3.023 2.154 1.00 0.00 N \nHETATM 27 C15 UNL 1 1.731 -3.672 1.032 1.00 0.00 C \nHETATM 28 H4 UNL 1 2.380 -6.874 0.436 1.00 0.00 H \nHETATM 29 H5 UNL 1 0.704 -6.526 -1.420 1.00 0.00 H \nHETATM 30 H6 UNL 1 1.144 -4.035 -2.291 1.00 0.00 H \nHETATM 31 C16 UNL 1 0.044 -0.371 0.685 1.00 0.00 C \nHETATM 32 C17 UNL 1 -1.352 -0.045 1.077 1.00 0.00 C \nHETATM 33 H7 UNL 1 -1.395 0.770 1.768 1.00 0.00 H \nHETATM 34 H8 UNL 1 -1.792 -0.941 1.582 1.00 0.00 H \nHETATM 35 H9 UNL 1 0.583 -1.035 1.393 1.00 0.00 H \nHETATM 36 H10 UNL 1 0.664 0.613 0.663 1.00 0.00 H \nHETATM 37 H11 UNL 1 -0.631 -0.267 -2.335 1.00 0.00 H \nHETATM 38 H12 UNL 1 -0.571 -2.046 -2.098 1.00 0.00 H \nHETATM 39 H13 UNL 1 -2.872 -0.992 -1.826 1.00 0.00 H \nHETATM 40 H14 UNL 1 -2.370 -1.924 -0.444 1.00 0.00 H \nHETATM 41 H15 UNL 1 -3.258 0.364 0.197 1.00 0.00 H \nHETATM 42 C18 UNL 1 0.276 2.337 -0.078 1.00 0.00 C \nHETATM 43 H16 UNL 1 0.514 1.371 0.252 1.00 0.00 H \nHETATM 44 H17 UNL 1 0.988 2.413 -0.949 1.00 0.00 H \nHETATM 45 H18 UNL 1 -1.349 3.451 -2.379 1.00 0.00 H \nHETATM 46 H19 UNL 1 -2.224 4.055 -0.958 1.00 0.00 H \nHETATM 47 H20 UNL 1 0.793 5.486 0.669 1.00 0.00 H \nHETATM 48 H21 UNL 1 -0.849 4.974 0.937 1.00 0.00 H \nHETATM 49 H22 UNL 1 1.667 3.431 1.070 1.00 0.00 H \nHETATM 50 H23 UNL 1 0.379 2.143 2.689 1.00 0.00 H \nHETATM 51 H24 UNL 1 -1.094 2.983 2.223 1.00 0.00 H \nCONECT 1 2 3 50 51\nCONECT 3 4 42 49\nCONECT 4 5 47 48\nCONECT 5 6 7 8\nCONECT 8 9 45 46\nCONECT 9 10 42\nCONECT 10 11 11 12\nCONECT 12 13 32 41\nCONECT 13 14 39 40\nCONECT 14 15 37 38\nCONECT 15 16 31\nCONECT 16 17 17 18 18\nCONECT 16 19\nCONECT 19 20 20 27\nCONECT 20 21 30\nCONECT 21 22 22 29\nCONECT 22 23 28\nCONECT 23 24 24 27\nCONECT 24 25\nCONECT 25 26\nCONECT 26 27 27\nCONECT 31 32 35 36\nCONECT 32 33 34\nCONECT 42 43 44\nEND\n",
+ "type": "blob"
+ }
+ ],
+ "kwargs": {
+ "defaultRepresentation": true,
+ "ext": "pdb"
},
- "outputId": "d6bce9d5-c51e-43a4-e680-e1e81bdfaf45"
+ "methodName": "loadFile",
+ "reconstruc_color_scheme": false,
+ "target": "Stage",
+ "type": "call_method"
+ }
+ ],
+ "_ngl_original_stage_parameters": {
+ "ambientColor": 14540253,
+ "ambientIntensity": 0.2,
+ "backgroundColor": "white",
+ "cameraEyeSep": 0.3,
+ "cameraFov": 40,
+ "cameraType": "perspective",
+ "clipDist": 10,
+ "clipFar": 100,
+ "clipNear": 0,
+ "fogFar": 100,
+ "fogNear": 50,
+ "hoverTimeout": 0,
+ "impostor": true,
+ "lightColor": 14540253,
+ "lightIntensity": 1,
+ "mousePreset": "default",
+ "panSpeed": 1,
+ "quality": "medium",
+ "rotateSpeed": 2,
+ "sampleLevel": 0,
+ "tooltip": true,
+ "workerDefault": true,
+ "zoomSpeed": 1.2
},
- "outputs": [
- {
- "output_type": "display_data",
- "data": {
- "text/plain": [
- "Downloading: 0%| | 0.00/3.27M [00:00, ?B/s]"
- ],
- "application/vnd.jupyter.widget-view+json": {
- "version_major": 2,
- "version_minor": 0,
- "model_id": "d90f304e9560472eacfbdd11e46765eb"
- }
+ "_ngl_repr_dict": {
+ "0": {
+ "0": {
+ "params": {
+ "aspectRatio": 1.5,
+ "assembly": "default",
+ "bondScale": 0.3,
+ "bondSpacing": 0.75,
+ "clipCenter": {
+ "x": 0,
+ "y": 0,
+ "z": 0
},
- "metadata": {}
- },
- {
- "output_type": "display_data",
- "data": {
- "text/plain": [
- "Downloading: 0%| | 0.00/401 [00:00, ?B/s]"
- ],
- "application/vnd.jupyter.widget-view+json": {
- "version_major": 2,
- "version_minor": 0,
- "model_id": "6010fc8daa7a44d5aec4b830ec2ebaa1"
- }
+ "clipNear": 0,
+ "clipRadius": 0,
+ "colorMode": "hcl",
+ "colorReverse": false,
+ "colorScale": "",
+ "colorScheme": "element",
+ "colorValue": 9474192,
+ "cylinderOnly": false,
+ "defaultAssembly": "",
+ "depthWrite": true,
+ "diffuse": 16777215,
+ "diffuseInterior": false,
+ "disableImpostor": false,
+ "disablePicking": false,
+ "flatShaded": false,
+ "interiorColor": 2236962,
+ "interiorDarkening": 0,
+ "lazy": false,
+ "lineOnly": false,
+ "linewidth": 2,
+ "matrix": {
+ "elements": [
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1
+ ]
},
- "metadata": {}
- },
- {
- "output_type": "stream",
- "name": "stderr",
- "text": [
- "The config attributes {'type': 'diffusion', 'network': 'dualenc', 'beta_schedule': 'sigmoid', 'beta_start': 1e-07, 'beta_end': 0.002, 'num_diffusion_timesteps': 5000} were passed to MoleculeGNN, but are not expected and will be ignored. Please verify your config.json configuration file.\n",
- "Some weights of the model checkpoint at fusing/gfn-molecule-gen-drugs were not used when initializing MoleculeGNN: ['betas', 'alphas']\n",
- "- This IS expected if you are initializing MoleculeGNN from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
- "- This IS NOT expected if you are initializing MoleculeGNN from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
- ]
+ "metalness": 0,
+ "multipleBond": "off",
+ "opacity": 1,
+ "openEnded": true,
+ "quality": "high",
+ "radialSegments": 20,
+ "radiusData": {},
+ "radiusScale": 2,
+ "radiusSize": 0.15,
+ "radiusType": "size",
+ "roughness": 0.4,
+ "sele": "",
+ "side": "double",
+ "sphereDetail": 2,
+ "useInteriorColor": true,
+ "visible": true,
+ "wireframe": false
+ },
+ "type": "ball+stick"
}
- ],
- "source": [
- "DEVICE = 'cuda'\n",
- "model = MoleculeGNN.from_pretrained(\"fusing/gfn-molecule-gen-drugs\").to(DEVICE)"
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "The warnings above are because the pre-trained model was uploaded before cleaning the code!"
- ],
- "metadata": {
- "id": "HdclRaqoUWUD"
- }
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "PlOkPySoJ1m9"
- },
- "source": [
- "#### Create scheduler\n",
- "Note, other schedulers are used in the paper for slightly improved performance over DDPM."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "nNHnIk9CkAb2"
- },
- "outputs": [],
- "source": [
- "from diffusers import DDPMScheduler"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "RnDJdDBztjFF"
- },
- "outputs": [],
- "source": [
- "num_timesteps = 1000\n",
- "scheduler = DDPMScheduler(num_train_timesteps=num_timesteps,beta_schedule=\"sigmoid\",beta_start=1e-7, beta_end=2e-3, clip_sample=False)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "1vh3fpSAflkL"
- },
- "source": [
- "### Get a dataset"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "B6qzaGjVKFVk"
- },
- "source": [
- "Grab a google tool so we can upload our data directly. Note you need to download the data from ***this [file](https://huggingface.co/datasets/fusing/geodiff-example-data/blob/main/data/molecules.pkl)***\n",
- "\n",
- "(direct downloading from the hub does not yet work for this datatype)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "jbLl3EJdgj3x"
- },
- "outputs": [],
- "source": [
- "# from google.colab import files"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "E591lVuTgxPE"
- },
- "outputs": [],
- "source": [
- "# uploaded = files.upload()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "KUNxfK3ln98Q"
- },
- "source": [
- "Load the dataset with torch."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "7L4iOShTpcQX",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "7f2dcd29-493e-44de-98d1-3ad50f109a4a"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "--2022-10-12 18:32:19-- https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n",
- "Resolving huggingface.co (huggingface.co)... 44.195.102.200, 52.5.54.249, 54.210.225.113, ...\n",
- "Connecting to huggingface.co (huggingface.co)|44.195.102.200|:443... connected.\n",
- "HTTP request sent, awaiting response... 200 OK\n",
- "Length: 127774 (125K) [application/octet-stream]\n",
- "Saving to: ‘molecules.pkl’\n",
- "\n",
- "molecules.pkl 100%[===================>] 124.78K 180KB/s in 0.7s \n",
- "\n",
- "2022-10-12 18:32:20 (180 KB/s) - ‘molecules.pkl’ saved [127774/127774]\n",
- "\n"
- ]
- }
- ],
- "source": [
- "import torch\n",
- "import numpy as np\n",
- "\n",
- "!wget https://huggingface.co/datasets/fusing/geodiff-example-data/resolve/main/data/molecules.pkl\n",
- "dataset = torch.load('/content/molecules.pkl')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "QZcmy1EvKQRk"
- },
- "source": [
- "Print out one entry of the dataset, it contains molecular formulas, atom types, positions, and more."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "JVjz6iH_H6Eh",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "898cb0cf-a0b3-411b-fd4c-bea1fbfd17fe"
- },
- "outputs": [
- {
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- "Data(atom_type=[51], bond_edge_index=[2, 108], edge_index=[2, 598], edge_order=[598], edge_type=[598], idx=[1], is_bond=[598], num_nodes_per_graph=[1], num_pos_ref=[1], nx=, pos=[51, 3], pos_ref=[255, 3], rdmol=, smiles=\"CC1CCCN(C(=O)C2CCN(S(=O)(=O)c3cccc4nonc34)CC2)C1\")"
- ]
+ },
+ "1": {
+ "0": {
+ "params": {
+ "aspectRatio": 1.5,
+ "assembly": "default",
+ "bondScale": 0.3,
+ "bondSpacing": 0.75,
+ "clipCenter": {
+ "x": 0,
+ "y": 0,
+ "z": 0
},
- "metadata": {},
- "execution_count": 20
- }
- ],
- "source": [
- "dataset[0]"
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Run the diffusion process"
- ],
- "metadata": {
- "id": "vHNiZAUxNgoy"
- }
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "jZ1KZrxKqENg"
- },
- "source": [
- "#### Helper Functions"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "s240tYueqKKf"
- },
- "outputs": [],
- "source": [
- "from torch_geometric.data import Data, Batch\n",
- "from torch_scatter import scatter_add, scatter_mean\n",
- "from tqdm import tqdm\n",
- "import copy\n",
- "import os\n",
- "\n",
- "def repeat_data(data: Data, num_repeat) -> Batch:\n",
- " datas = [copy.deepcopy(data) for i in range(num_repeat)]\n",
- " return Batch.from_data_list(datas)\n",
- "\n",
- "def repeat_batch(batch: Batch, num_repeat) -> Batch:\n",
- " datas = batch.to_data_list()\n",
- " new_data = []\n",
- " for i in range(num_repeat):\n",
- " new_data += copy.deepcopy(datas)\n",
- " return Batch.from_data_list(new_data)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "AMnQTk0eqT7Z"
- },
- "source": [
- "#### Constants"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "WYGkzqgzrHmF"
- },
- "outputs": [],
- "source": [
- "num_samples = 1 # solutions per molecule\n",
- "num_molecules = 3\n",
- "\n",
- "DEVICE = 'cuda'\n",
- "sampling_type = 'ddpm_noisy' #'' # paper also uses \"generalize\" and \"ld\"\n",
- "# constants for inference\n",
- "w_global = 0.5 #0,.3 for qm9\n",
- "global_start_sigma = 0.5\n",
- "eta = 1.0\n",
- "clip_local = None\n",
- "clip_pos = None\n",
- "\n",
- "# constands for data handling\n",
- "save_traj = False\n",
- "save_data = False\n",
- "output_dir = '/content/'"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "-xD5bJ3SqM7t"
- },
- "source": [
- "#### Generate samples!\n",
- "Note that the 3d representation of a molecule is referred to as the **conformation**"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "x9xuLUNg26z1",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "236d2a60-09ed-4c4d-97c1-6e3c0f2d26c4"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stderr",
- "text": [
- "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
- " after removing the cwd from sys.path.\n",
- "100%|██████████| 5/5 [00:55<00:00, 11.06s/it]\n"
- ]
- }
- ],
- "source": [
- "results = []\n",
- "\n",
- "# define sigmas\n",
- "sigmas = torch.tensor(1.0 - scheduler.alphas_cumprod).sqrt() / torch.tensor(scheduler.alphas_cumprod).sqrt()\n",
- "sigmas = sigmas.to(DEVICE)\n",
- "\n",
- "for count, data in enumerate(tqdm(dataset)):\n",
- " num_samples = max(data.pos_ref.size(0) // data.num_nodes, 1)\n",
- "\n",
- " data_input = data.clone()\n",
- " data_input['pos_ref'] = None\n",
- " batch = repeat_data(data_input, num_samples).to(DEVICE)\n",
- "\n",
- " # initial configuration\n",
- " pos_init = torch.randn(batch.num_nodes, 3).to(DEVICE)\n",
- "\n",
- " # for logging animation of denoising\n",
- " pos_traj = []\n",
- " with torch.no_grad():\n",
- "\n",
- " # scale initial sample\n",
- " pos = pos_init * sigmas[-1]\n",
- " for t in scheduler.timesteps:\n",
- " batch.pos = pos\n",
- "\n",
- " # generate geometry with model, then filter it\n",
- " epsilon = model.forward(batch, t, sigma=sigmas[t], return_dict=False)[0]\n",
- "\n",
- " # Update\n",
- " reconstructed_pos = scheduler.step(epsilon, t, pos)[\"prev_sample\"].to(DEVICE)\n",
- "\n",
- " pos = reconstructed_pos\n",
- "\n",
- " if torch.isnan(pos).any():\n",
- " print(\"NaN detected. Please restart.\")\n",
- " raise FloatingPointError()\n",
- "\n",
- " # recenter graph of positions for next iteration\n",
- " pos = pos - scatter_mean(pos, batch.batch, dim=0)[batch.batch]\n",
- "\n",
- " # optional clipping\n",
- " if clip_pos is not None:\n",
- " pos = torch.clamp(pos, min=-clip_pos, max=clip_pos)\n",
- " pos_traj.append(pos.clone().cpu())\n",
- "\n",
- " pos_gen = pos.cpu()\n",
- " if save_traj:\n",
- " pos_gen_traj = pos_traj.cpu()\n",
- " data.pos_gen = torch.stack(pos_gen_traj)\n",
- " else:\n",
- " data.pos_gen = pos_gen\n",
- " results.append(data)\n",
- "\n",
- "\n",
- "if save_data:\n",
- " save_path = os.path.join(output_dir, 'samples_all.pkl')\n",
- "\n",
- " with open(save_path, 'wb') as f:\n",
- " pickle.dump(results, f)"
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Render the results!"
- ],
- "metadata": {
- "id": "fSApwSaZNndW"
- }
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "d47Zxo2OKdgZ"
- },
- "source": [
- "This function allows us to render 3d in colab."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "e9Cd0kCAv9b8"
- },
- "outputs": [],
- "source": [
- "from google.colab import output\n",
- "output.enable_custom_widget_manager()"
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Helper functions"
- ],
- "metadata": {
- "id": "RjaVuR15NqzF"
- }
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "28rBYa9NKhlz"
- },
- "source": [
- "Here is a helper function for copying the generated tensors into a format used by RDKit & NGLViewer."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "LKdKdwxcyTQ6"
- },
- "outputs": [],
- "source": [
- "from copy import deepcopy\n",
- "def set_rdmol_positions(rdkit_mol, pos):\n",
- " \"\"\"\n",
- " Args:\n",
- " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n",
- " pos: (N_atoms, 3)\n",
- " \"\"\"\n",
- " mol = deepcopy(rdkit_mol)\n",
- " set_rdmol_positions_(mol, pos)\n",
- " return mol\n",
- "\n",
- "def set_rdmol_positions_(mol, pos):\n",
- " \"\"\"\n",
- " Args:\n",
- " rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.\n",
- " pos: (N_atoms, 3)\n",
- " \"\"\"\n",
- " for i in range(pos.shape[0]):\n",
- " mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist())\n",
- " return mol\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "NuE10hcpKmzK"
- },
- "source": [
- "Process the generated data to make it easy to view."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "KieVE1vc0_Vs",
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "outputId": "6faa185d-b1bc-47e8-be18-30d1e557e7c8"
- },
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "collect 5 generated molecules in `mols`\n"
- ]
- }
- ],
- "source": [
- "# the model can generate multiple conformations per 2d geometry\n",
- "num_gen = results[0]['pos_gen'].shape[0]\n",
- "\n",
- "# init storage objects\n",
- "mols_gen = []\n",
- "mols_orig = []\n",
- "for to_process in results:\n",
- "\n",
- " # store the reference 3d position\n",
- " to_process['pos_ref'] = to_process['pos_ref'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n",
- "\n",
- " # store the generated 3d position\n",
- " to_process['pos_gen'] = to_process['pos_gen'].reshape(-1, to_process['rdmol'].GetNumAtoms(), 3)\n",
- "\n",
- " # copy data to new object\n",
- " new_mol = set_rdmol_positions(to_process.rdmol, to_process['pos_gen'][0])\n",
- "\n",
- " # append results\n",
- " mols_gen.append(new_mol)\n",
- " mols_orig.append(to_process.rdmol)\n",
- "\n",
- "print(f\"collect {len(mols_gen)} generated molecules in `mols`\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "tin89JwMKp4v"
- },
- "source": [
- "Import tools to visualize the 2d chemical diagram of the molecule."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "yqV6gllSZn38"
- },
- "outputs": [],
- "source": [
- "from rdkit.Chem import AllChem\n",
- "from rdkit import Chem\n",
- "from rdkit.Chem.Draw import rdMolDraw2D as MD2\n",
- "from IPython.display import SVG, display"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "TFNKmGddVoOk"
- },
- "source": [
- "Select molecule to visualize"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "KzuwLlrrVaGc"
- },
- "outputs": [],
- "source": [
- "idx = 0\n",
- "assert idx < len(results), \"selected molecule that was not generated\""
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "### Viewing"
- ],
- "metadata": {
- "id": "hkb8w0_SNtU8"
- }
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "I3R4QBQeKttN"
- },
- "source": [
- "This 2D rendering is the equivalent of the **input to the model**!"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "gkQRWjraaKex",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 321
- },
- "outputId": "9c3d1a91-a51d-475d-9e34-2be2459abc47"
- },
- "outputs": [
- {
- "output_type": "display_data",
- "data": {
- "text/plain": [
- ""
- ],
- "image/svg+xml": "\n\n \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n "
+ "clipNear": 0,
+ "clipRadius": 0,
+ "colorMode": "hcl",
+ "colorReverse": false,
+ "colorScale": "",
+ "colorScheme": "element",
+ "colorValue": 9474192,
+ "cylinderOnly": false,
+ "defaultAssembly": "",
+ "depthWrite": true,
+ "diffuse": 16777215,
+ "diffuseInterior": false,
+ "disableImpostor": false,
+ "disablePicking": false,
+ "flatShaded": false,
+ "interiorColor": 2236962,
+ "interiorDarkening": 0,
+ "lazy": false,
+ "lineOnly": false,
+ "linewidth": 2,
+ "matrix": {
+ "elements": [
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1
+ ]
},
- "metadata": {}
+ "metalness": 0,
+ "multipleBond": "off",
+ "opacity": 1,
+ "openEnded": true,
+ "quality": "high",
+ "radialSegments": 20,
+ "radiusData": {},
+ "radiusScale": 2,
+ "radiusSize": 0.15,
+ "radiusType": "size",
+ "roughness": 0.4,
+ "sele": "",
+ "side": "double",
+ "sphereDetail": 2,
+ "useInteriorColor": true,
+ "visible": true,
+ "wireframe": false
+ },
+ "type": "ball+stick"
}
- ],
- "source": [
- "mc = Chem.MolFromSmiles(dataset[0]['smiles'])\n",
- "molSize=(450,300)\n",
- "drawer = MD2.MolDraw2DSVG(molSize[0],molSize[1])\n",
- "drawer.DrawMolecule(mc)\n",
- "drawer.FinishDrawing()\n",
- "svg = drawer.GetDrawingText()\n",
- "display(SVG(svg.replace('svg:','')))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "z4FDMYMxKw2I"
+ }
},
- "source": [
- "Generate the 3d molecule!"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "aT1Bkb8YxJfV",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 17,
- "referenced_widgets": [
- "695ab5bbf30a4ab19df1f9f33469f314",
- "eac6a8dcdc9d4335a2e51031793ead29"
- ]
- },
- "outputId": "b98870ae-049d-4386-b676-166e9526bda2"
- },
- "outputs": [
- {
- "output_type": "display_data",
- "data": {
- "text/plain": [],
- "application/vnd.jupyter.widget-view+json": {
- "version_major": 2,
- "version_minor": 0,
- "model_id": "695ab5bbf30a4ab19df1f9f33469f314"
- }
- },
- "metadata": {
- "application/vnd.jupyter.widget-view+json": {
- "colab": {
- "custom_widget_manager": {
- "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js"
- }
- }
- }
- }
- }
+ "_ngl_serialize": false,
+ "_ngl_version": "",
+ "_ngl_view_id": [
+ "FB989FD1-5B9C-446B-8914-6B58AF85446D"
],
- "source": [
- "from nglview import show_rdkit as show"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "pxtq8I-I18C-",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 337,
- "referenced_widgets": [
- "be446195da2b4ff2aec21ec5ff963a54",
- "c6596896148b4a8a9c57963b67c7782f",
- "2489b5e5648541fbbdceadb05632a050",
- "01e0ba4e5da04914b4652b8d58565d7b",
- "c30e6c2f3e2a44dbbb3d63bd519acaa4",
- "f31c6e40e9b2466a9064a2669933ecd5",
- "19308ccac642498ab8b58462e3f1b0bb",
- "4a081cdc2ec3421ca79dd933b7e2b0c4",
- "e5c0d75eb5e1447abd560c8f2c6017e1",
- "5146907ef6764654ad7d598baebc8b58",
- "144ec959b7604a2cabb5ca46ae5e5379",
- "abce2a80e6304df3899109c6d6cac199",
- "65195cb7a4134f4887e9dd19f3676462"
- ]
- },
- "outputId": "72ed63ac-d2ec-4f5c-a0b1-4e7c1840a4e7"
- },
- "outputs": [
- {
- "output_type": "display_data",
- "data": {
- "text/plain": [
- "NGLWidget()"
- ],
- "application/vnd.jupyter.widget-view+json": {
- "version_major": 2,
- "version_minor": 0,
- "model_id": "be446195da2b4ff2aec21ec5ff963a54"
- }
- },
- "metadata": {
- "application/vnd.jupyter.widget-view+json": {
- "colab": {
- "custom_widget_manager": {
- "url": "https://ssl.gstatic.com/colaboratory-static/widgets/colab-cdn-widget-manager/d2e234f7cc04bf79/manager.min.js"
- }
- }
- }
- }
- }
+ "_player_dict": {},
+ "_scene_position": {},
+ "_scene_rotation": {},
+ "_synced_model_ids": [],
+ "_synced_repr_model_ids": [],
+ "_view_count": null,
+ "_view_height": "",
+ "_view_module": "nglview-js-widgets",
+ "_view_module_version": "3.0.1",
+ "_view_name": "NGLView",
+ "_view_width": "",
+ "background": "white",
+ "frame": 0,
+ "gui_style": null,
+ "layout": "IPY_MODEL_c6596896148b4a8a9c57963b67c7782f",
+ "max_frame": 0,
+ "n_components": 2,
+ "picked": {}
+ }
+ },
+ "c2321b3bff6f490ca12040a20308f555": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_872915dd1bb84f538c44e26badabafdd",
+ "max": 3271865,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_d022575f1fa2446d891650897f187b4d",
+ "value": 3271865
+ }
+ },
+ "c30e6c2f3e2a44dbbb3d63bd519acaa4": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "c6596896148b4a8a9c57963b67c7782f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "d022575f1fa2446d891650897f187b4d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "d90f304e9560472eacfbdd11e46765eb": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_1c6246f15b654f4daa11c9bcf997b78c",
+ "IPY_MODEL_c2321b3bff6f490ca12040a20308f555",
+ "IPY_MODEL_b7feb522161f4cf4b7cc7c1a078ff12d"
],
- "source": [
- "# new molecule\n",
- "show(mols_gen[idx])"
- ]
- },
- {
- "cell_type": "code",
- "source": [],
- "metadata": {
- "id": "KJr4h2mwXeTo"
- },
- "execution_count": null,
- "outputs": []
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "provenance": []
- },
- "gpuClass": "standard",
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- },
- "language_info": {
- "name": "python"
- },
- "widgets": {
- "application/vnd.jupyter.widget-state+json": {
- "d90f304e9560472eacfbdd11e46765eb": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HBoxModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HBoxModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HBoxView",
- "box_style": "",
- "children": [
- "IPY_MODEL_1c6246f15b654f4daa11c9bcf997b78c",
- "IPY_MODEL_c2321b3bff6f490ca12040a20308f555",
- "IPY_MODEL_b7feb522161f4cf4b7cc7c1a078ff12d"
- ],
- "layout": "IPY_MODEL_e2d368556e494ae7ae4e2e992af2cd4f"
- }
- },
- "1c6246f15b654f4daa11c9bcf997b78c": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HTMLModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HTMLModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HTMLView",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_bbef741e76ec41b7ab7187b487a383df",
- "placeholder": "",
- "style": "IPY_MODEL_561f742d418d4721b0670cc8dd62e22c",
- "value": "Downloading: 100%"
- }
- },
- "c2321b3bff6f490ca12040a20308f555": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "FloatProgressModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "FloatProgressModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "ProgressView",
- "bar_style": "success",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_872915dd1bb84f538c44e26badabafdd",
- "max": 3271865,
- "min": 0,
- "orientation": "horizontal",
- "style": "IPY_MODEL_d022575f1fa2446d891650897f187b4d",
- "value": 3271865
- }
- },
- "b7feb522161f4cf4b7cc7c1a078ff12d": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HTMLModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HTMLModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HTMLView",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_fdc393f3468c432aa0ada05e238a5436",
- "placeholder": "",
- "style": "IPY_MODEL_2c9362906e4b40189f16d14aa9a348da",
- "value": " 3.27M/3.27M [00:01<00:00, 3.25MB/s]"
- }
- },
- "e2d368556e494ae7ae4e2e992af2cd4f": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "bbef741e76ec41b7ab7187b487a383df": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "561f742d418d4721b0670cc8dd62e22c": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "DescriptionStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "DescriptionStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "description_width": ""
- }
- },
- "872915dd1bb84f538c44e26badabafdd": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "d022575f1fa2446d891650897f187b4d": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "ProgressStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "ProgressStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "bar_color": null,
- "description_width": ""
- }
- },
- "fdc393f3468c432aa0ada05e238a5436": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "2c9362906e4b40189f16d14aa9a348da": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "DescriptionStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "DescriptionStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "description_width": ""
- }
- },
- "6010fc8daa7a44d5aec4b830ec2ebaa1": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HBoxModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HBoxModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HBoxView",
- "box_style": "",
- "children": [
- "IPY_MODEL_7e0bb1b8d65249d3974200686b193be2",
- "IPY_MODEL_ba98aa6d6a884e4ab8bbb5dfb5e4cf7a",
- "IPY_MODEL_6526646be5ed415c84d1245b040e629b"
- ],
- "layout": "IPY_MODEL_24d31fc3576e43dd9f8301d2ef3a37ab"
- }
- },
- "7e0bb1b8d65249d3974200686b193be2": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HTMLModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HTMLModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HTMLView",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_2918bfaadc8d4b1a9832522c40dfefb8",
- "placeholder": "",
- "style": "IPY_MODEL_a4bfdca35cc54dae8812720f1b276a08",
- "value": "Downloading: 100%"
- }
- },
- "ba98aa6d6a884e4ab8bbb5dfb5e4cf7a": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "FloatProgressModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "FloatProgressModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "ProgressView",
- "bar_style": "success",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_e4901541199b45c6a18824627692fc39",
- "max": 401,
- "min": 0,
- "orientation": "horizontal",
- "style": "IPY_MODEL_f915cf874246446595206221e900b2fe",
- "value": 401
- }
- },
- "6526646be5ed415c84d1245b040e629b": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HTMLModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HTMLModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HTMLView",
- "description": "",
- "description_tooltip": null,
- "layout": "IPY_MODEL_a9e388f22a9742aaaf538e22575c9433",
- "placeholder": "",
- "style": "IPY_MODEL_42f6c3db29d7484ba6b4f73590abd2f4",
- "value": " 401/401 [00:00<00:00, 13.5kB/s]"
- }
- },
- "24d31fc3576e43dd9f8301d2ef3a37ab": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "2918bfaadc8d4b1a9832522c40dfefb8": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "a4bfdca35cc54dae8812720f1b276a08": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "DescriptionStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "DescriptionStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "description_width": ""
- }
- },
- "e4901541199b45c6a18824627692fc39": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "f915cf874246446595206221e900b2fe": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "ProgressStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "ProgressStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "bar_color": null,
- "description_width": ""
- }
- },
- "a9e388f22a9742aaaf538e22575c9433": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "42f6c3db29d7484ba6b4f73590abd2f4": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "DescriptionStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "DescriptionStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "description_width": ""
- }
- },
- "695ab5bbf30a4ab19df1f9f33469f314": {
- "model_module": "nglview-js-widgets",
- "model_name": "ColormakerRegistryModel",
- "model_module_version": "3.0.1",
- "state": {
- "_dom_classes": [],
- "_model_module": "nglview-js-widgets",
- "_model_module_version": "3.0.1",
- "_model_name": "ColormakerRegistryModel",
- "_msg_ar": [],
- "_msg_q": [],
- "_ready": false,
- "_view_count": null,
- "_view_module": "nglview-js-widgets",
- "_view_module_version": "3.0.1",
- "_view_name": "ColormakerRegistryView",
- "layout": "IPY_MODEL_eac6a8dcdc9d4335a2e51031793ead29"
- }
- },
- "eac6a8dcdc9d4335a2e51031793ead29": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "be446195da2b4ff2aec21ec5ff963a54": {
- "model_module": "nglview-js-widgets",
- "model_name": "NGLModel",
- "model_module_version": "3.0.1",
- "state": {
- "_camera_orientation": [
- -15.519693580202304,
- -14.065056548036177,
- -23.53197484807691,
- 0,
- -23.357853515109753,
- 20.94055073042662,
- 2.888695042134944,
- 0,
- 14.352363398292777,
- 18.870825741878015,
- -20.744689572909344,
- 0,
- 0.2724999189376831,
- 0.6940000057220459,
- -0.3734999895095825,
- 1
- ],
- "_camera_str": "orthographic",
- "_dom_classes": [],
- "_gui_theme": null,
- "_ibtn_fullscreen": "IPY_MODEL_2489b5e5648541fbbdceadb05632a050",
- "_igui": null,
- "_iplayer": "IPY_MODEL_01e0ba4e5da04914b4652b8d58565d7b",
- "_model_module": "nglview-js-widgets",
- "_model_module_version": "3.0.1",
- "_model_name": "NGLModel",
- "_ngl_color_dict": {},
- "_ngl_coordinate_resource": {},
- "_ngl_full_stage_parameters": {
- "impostor": true,
- "quality": "medium",
- "workerDefault": true,
- "sampleLevel": 0,
- "backgroundColor": "white",
- "rotateSpeed": 2,
- "zoomSpeed": 1.2,
- "panSpeed": 1,
- "clipNear": 0,
- "clipFar": 100,
- "clipDist": 10,
- "fogNear": 50,
- "fogFar": 100,
- "cameraFov": 40,
- "cameraEyeSep": 0.3,
- "cameraType": "perspective",
- "lightColor": 14540253,
- "lightIntensity": 1,
- "ambientColor": 14540253,
- "ambientIntensity": 0.2,
- "hoverTimeout": 0,
- "tooltip": true,
- "mousePreset": "default"
- },
- "_ngl_msg_archive": [
- {
- "target": "Stage",
- "type": "call_method",
- "methodName": "loadFile",
- "reconstruc_color_scheme": false,
- "args": [
- {
- "type": "blob",
- "data": "HETATM 1 C1 UNL 1 -0.025 3.128 2.316 1.00 0.00 C \nHETATM 2 H1 UNL 1 0.183 3.657 2.823 1.00 0.00 H \nHETATM 3 C2 UNL 1 0.590 3.559 0.963 1.00 0.00 C \nHETATM 4 C3 UNL 1 0.056 4.479 0.406 1.00 0.00 C \nHETATM 5 C4 UNL 1 -0.219 4.802 -1.065 1.00 0.00 C \nHETATM 6 H2 UNL 1 0.686 4.431 -1.575 1.00 0.00 H \nHETATM 7 H3 UNL 1 -0.524 5.217 -1.274 1.00 0.00 H \nHETATM 8 C5 UNL 1 -1.284 3.766 -1.342 1.00 0.00 C \nHETATM 9 N1 UNL 1 -1.073 2.494 -0.580 1.00 0.00 N \nHETATM 10 C6 UNL 1 -1.909 1.494 -0.964 1.00 0.00 C \nHETATM 11 O1 UNL 1 -2.487 1.531 -2.092 1.00 0.00 O \nHETATM 12 C7 UNL 1 -2.232 0.242 -0.130 1.00 0.00 C \nHETATM 13 C8 UNL 1 -2.161 -1.057 -1.037 1.00 0.00 C \nHETATM 14 C9 UNL 1 -0.744 -1.111 -1.610 1.00 0.00 C \nHETATM 15 N2 UNL 1 0.290 -0.917 -0.628 1.00 0.00 N \nHETATM 16 S1 UNL 1 1.717 -1.597 -0.914 1.00 0.00 S \nHETATM 17 O2 UNL 1 1.960 -1.671 -2.338 1.00 0.00 O \nHETATM 18 O3 UNL 1 2.713 -0.968 -0.082 1.00 0.00 O \nHETATM 19 C10 UNL 1 1.425 -3.170 -0.345 1.00 0.00 C \nHETATM 20 C11 UNL 1 1.225 -4.400 -1.271 1.00 0.00 C \nHETATM 21 C12 UNL 1 1.314 -5.913 -0.895 1.00 0.00 C \nHETATM 22 C13 UNL 1 1.823 -6.229 0.386 1.00 0.00 C \nHETATM 23 C14 UNL 1 2.031 -5.110 1.365 1.00 0.00 C \nHETATM 24 N3 UNL 1 1.850 -5.267 2.712 1.00 0.00 N \nHETATM 25 O4 UNL 1 1.382 -4.029 3.126 1.00 0.00 O \nHETATM 26 N4 UNL 1 1.300 -3.023 2.154 1.00 0.00 N \nHETATM 27 C15 UNL 1 1.731 -3.672 1.032 1.00 0.00 C \nHETATM 28 H4 UNL 1 2.380 -6.874 0.436 1.00 0.00 H \nHETATM 29 H5 UNL 1 0.704 -6.526 -1.420 1.00 0.00 H \nHETATM 30 H6 UNL 1 1.144 -4.035 -2.291 1.00 0.00 H \nHETATM 31 C16 UNL 1 0.044 -0.371 0.685 1.00 0.00 C \nHETATM 32 C17 UNL 1 -1.352 -0.045 1.077 1.00 0.00 C \nHETATM 33 H7 UNL 1 -1.395 0.770 1.768 1.00 0.00 H \nHETATM 34 H8 UNL 1 -1.792 -0.941 1.582 1.00 0.00 H \nHETATM 35 H9 UNL 1 0.583 -1.035 1.393 1.00 0.00 H \nHETATM 36 H10 UNL 1 0.664 0.613 0.663 1.00 0.00 H \nHETATM 37 H11 UNL 1 -0.631 -0.267 -2.335 1.00 0.00 H \nHETATM 38 H12 UNL 1 -0.571 -2.046 -2.098 1.00 0.00 H \nHETATM 39 H13 UNL 1 -2.872 -0.992 -1.826 1.00 0.00 H \nHETATM 40 H14 UNL 1 -2.370 -1.924 -0.444 1.00 0.00 H \nHETATM 41 H15 UNL 1 -3.258 0.364 0.197 1.00 0.00 H \nHETATM 42 C18 UNL 1 0.276 2.337 -0.078 1.00 0.00 C \nHETATM 43 H16 UNL 1 0.514 1.371 0.252 1.00 0.00 H \nHETATM 44 H17 UNL 1 0.988 2.413 -0.949 1.00 0.00 H \nHETATM 45 H18 UNL 1 -1.349 3.451 -2.379 1.00 0.00 H \nHETATM 46 H19 UNL 1 -2.224 4.055 -0.958 1.00 0.00 H \nHETATM 47 H20 UNL 1 0.793 5.486 0.669 1.00 0.00 H \nHETATM 48 H21 UNL 1 -0.849 4.974 0.937 1.00 0.00 H \nHETATM 49 H22 UNL 1 1.667 3.431 1.070 1.00 0.00 H \nHETATM 50 H23 UNL 1 0.379 2.143 2.689 1.00 0.00 H \nHETATM 51 H24 UNL 1 -1.094 2.983 2.223 1.00 0.00 H \nCONECT 1 2 3 50 51\nCONECT 3 4 42 49\nCONECT 4 5 47 48\nCONECT 5 6 7 8\nCONECT 8 9 45 46\nCONECT 9 10 42\nCONECT 10 11 11 12\nCONECT 12 13 32 41\nCONECT 13 14 39 40\nCONECT 14 15 37 38\nCONECT 15 16 31\nCONECT 16 17 17 18 18\nCONECT 16 19\nCONECT 19 20 20 27\nCONECT 20 21 30\nCONECT 21 22 22 29\nCONECT 22 23 28\nCONECT 23 24 24 27\nCONECT 24 25\nCONECT 25 26\nCONECT 26 27 27\nCONECT 31 32 35 36\nCONECT 32 33 34\nCONECT 42 43 44\nEND\n",
- "binary": false
- }
- ],
- "kwargs": {
- "defaultRepresentation": true,
- "ext": "pdb"
- }
- }
- ],
- "_ngl_original_stage_parameters": {
- "impostor": true,
- "quality": "medium",
- "workerDefault": true,
- "sampleLevel": 0,
- "backgroundColor": "white",
- "rotateSpeed": 2,
- "zoomSpeed": 1.2,
- "panSpeed": 1,
- "clipNear": 0,
- "clipFar": 100,
- "clipDist": 10,
- "fogNear": 50,
- "fogFar": 100,
- "cameraFov": 40,
- "cameraEyeSep": 0.3,
- "cameraType": "perspective",
- "lightColor": 14540253,
- "lightIntensity": 1,
- "ambientColor": 14540253,
- "ambientIntensity": 0.2,
- "hoverTimeout": 0,
- "tooltip": true,
- "mousePreset": "default"
- },
- "_ngl_repr_dict": {
- "0": {
- "0": {
- "type": "ball+stick",
- "params": {
- "lazy": false,
- "visible": true,
- "quality": "high",
- "sphereDetail": 2,
- "radialSegments": 20,
- "openEnded": true,
- "disableImpostor": false,
- "aspectRatio": 1.5,
- "lineOnly": false,
- "cylinderOnly": false,
- "multipleBond": "off",
- "bondScale": 0.3,
- "bondSpacing": 0.75,
- "linewidth": 2,
- "radiusType": "size",
- "radiusData": {},
- "radiusSize": 0.15,
- "radiusScale": 2,
- "assembly": "default",
- "defaultAssembly": "",
- "clipNear": 0,
- "clipRadius": 0,
- "clipCenter": {
- "x": 0,
- "y": 0,
- "z": 0
- },
- "flatShaded": false,
- "opacity": 1,
- "depthWrite": true,
- "side": "double",
- "wireframe": false,
- "colorScheme": "element",
- "colorScale": "",
- "colorReverse": false,
- "colorValue": 9474192,
- "colorMode": "hcl",
- "roughness": 0.4,
- "metalness": 0,
- "diffuse": 16777215,
- "diffuseInterior": false,
- "useInteriorColor": true,
- "interiorColor": 2236962,
- "interiorDarkening": 0,
- "matrix": {
- "elements": [
- 1,
- 0,
- 0,
- 0,
- 0,
- 1,
- 0,
- 0,
- 0,
- 0,
- 1,
- 0,
- 0,
- 0,
- 0,
- 1
- ]
- },
- "disablePicking": false,
- "sele": ""
- }
- }
- },
- "1": {
- "0": {
- "type": "ball+stick",
- "params": {
- "lazy": false,
- "visible": true,
- "quality": "high",
- "sphereDetail": 2,
- "radialSegments": 20,
- "openEnded": true,
- "disableImpostor": false,
- "aspectRatio": 1.5,
- "lineOnly": false,
- "cylinderOnly": false,
- "multipleBond": "off",
- "bondScale": 0.3,
- "bondSpacing": 0.75,
- "linewidth": 2,
- "radiusType": "size",
- "radiusData": {},
- "radiusSize": 0.15,
- "radiusScale": 2,
- "assembly": "default",
- "defaultAssembly": "",
- "clipNear": 0,
- "clipRadius": 0,
- "clipCenter": {
- "x": 0,
- "y": 0,
- "z": 0
- },
- "flatShaded": false,
- "opacity": 1,
- "depthWrite": true,
- "side": "double",
- "wireframe": false,
- "colorScheme": "element",
- "colorScale": "",
- "colorReverse": false,
- "colorValue": 9474192,
- "colorMode": "hcl",
- "roughness": 0.4,
- "metalness": 0,
- "diffuse": 16777215,
- "diffuseInterior": false,
- "useInteriorColor": true,
- "interiorColor": 2236962,
- "interiorDarkening": 0,
- "matrix": {
- "elements": [
- 1,
- 0,
- 0,
- 0,
- 0,
- 1,
- 0,
- 0,
- 0,
- 0,
- 1,
- 0,
- 0,
- 0,
- 0,
- 1
- ]
- },
- "disablePicking": false,
- "sele": ""
- }
- }
- }
- },
- "_ngl_serialize": false,
- "_ngl_version": "",
- "_ngl_view_id": [
- "FB989FD1-5B9C-446B-8914-6B58AF85446D"
- ],
- "_player_dict": {},
- "_scene_position": {},
- "_scene_rotation": {},
- "_synced_model_ids": [],
- "_synced_repr_model_ids": [],
- "_view_count": null,
- "_view_height": "",
- "_view_module": "nglview-js-widgets",
- "_view_module_version": "3.0.1",
- "_view_name": "NGLView",
- "_view_width": "",
- "background": "white",
- "frame": 0,
- "gui_style": null,
- "layout": "IPY_MODEL_c6596896148b4a8a9c57963b67c7782f",
- "max_frame": 0,
- "n_components": 2,
- "picked": {}
- }
- },
- "c6596896148b4a8a9c57963b67c7782f": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "2489b5e5648541fbbdceadb05632a050": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "ButtonModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "ButtonModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "ButtonView",
- "button_style": "",
- "description": "",
- "disabled": false,
- "icon": "compress",
- "layout": "IPY_MODEL_abce2a80e6304df3899109c6d6cac199",
- "style": "IPY_MODEL_65195cb7a4134f4887e9dd19f3676462",
- "tooltip": ""
- }
- },
- "01e0ba4e5da04914b4652b8d58565d7b": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "HBoxModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "HBoxModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "HBoxView",
- "box_style": "",
- "children": [
- "IPY_MODEL_e5c0d75eb5e1447abd560c8f2c6017e1",
- "IPY_MODEL_5146907ef6764654ad7d598baebc8b58"
- ],
- "layout": "IPY_MODEL_144ec959b7604a2cabb5ca46ae5e5379"
- }
- },
- "c30e6c2f3e2a44dbbb3d63bd519acaa4": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "f31c6e40e9b2466a9064a2669933ecd5": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "DescriptionStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "DescriptionStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "description_width": ""
- }
- },
- "19308ccac642498ab8b58462e3f1b0bb": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "4a081cdc2ec3421ca79dd933b7e2b0c4": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "SliderStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "SliderStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "description_width": "",
- "handle_color": null
- }
- },
- "e5c0d75eb5e1447abd560c8f2c6017e1": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "PlayModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "PlayModel",
- "_playing": false,
- "_repeat": false,
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "PlayView",
- "description": "",
- "description_tooltip": null,
- "disabled": false,
- "interval": 100,
- "layout": "IPY_MODEL_c30e6c2f3e2a44dbbb3d63bd519acaa4",
- "max": 0,
- "min": 0,
- "show_repeat": true,
- "step": 1,
- "style": "IPY_MODEL_f31c6e40e9b2466a9064a2669933ecd5",
- "value": 0
- }
- },
- "5146907ef6764654ad7d598baebc8b58": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "IntSliderModel",
- "model_module_version": "1.5.0",
- "state": {
- "_dom_classes": [],
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "IntSliderModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/controls",
- "_view_module_version": "1.5.0",
- "_view_name": "IntSliderView",
- "continuous_update": true,
- "description": "",
- "description_tooltip": null,
- "disabled": false,
- "layout": "IPY_MODEL_19308ccac642498ab8b58462e3f1b0bb",
- "max": 0,
- "min": 0,
- "orientation": "horizontal",
- "readout": true,
- "readout_format": "d",
- "step": 1,
- "style": "IPY_MODEL_4a081cdc2ec3421ca79dd933b7e2b0c4",
- "value": 0
- }
- },
- "144ec959b7604a2cabb5ca46ae5e5379": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": null
- }
- },
- "abce2a80e6304df3899109c6d6cac199": {
- "model_module": "@jupyter-widgets/base",
- "model_name": "LayoutModel",
- "model_module_version": "1.2.0",
- "state": {
- "_model_module": "@jupyter-widgets/base",
- "_model_module_version": "1.2.0",
- "_model_name": "LayoutModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "LayoutView",
- "align_content": null,
- "align_items": null,
- "align_self": null,
- "border": null,
- "bottom": null,
- "display": null,
- "flex": null,
- "flex_flow": null,
- "grid_area": null,
- "grid_auto_columns": null,
- "grid_auto_flow": null,
- "grid_auto_rows": null,
- "grid_column": null,
- "grid_gap": null,
- "grid_row": null,
- "grid_template_areas": null,
- "grid_template_columns": null,
- "grid_template_rows": null,
- "height": null,
- "justify_content": null,
- "justify_items": null,
- "left": null,
- "margin": null,
- "max_height": null,
- "max_width": null,
- "min_height": null,
- "min_width": null,
- "object_fit": null,
- "object_position": null,
- "order": null,
- "overflow": null,
- "overflow_x": null,
- "overflow_y": null,
- "padding": null,
- "right": null,
- "top": null,
- "visibility": null,
- "width": "34px"
- }
- },
- "65195cb7a4134f4887e9dd19f3676462": {
- "model_module": "@jupyter-widgets/controls",
- "model_name": "ButtonStyleModel",
- "model_module_version": "1.5.0",
- "state": {
- "_model_module": "@jupyter-widgets/controls",
- "_model_module_version": "1.5.0",
- "_model_name": "ButtonStyleModel",
- "_view_count": null,
- "_view_module": "@jupyter-widgets/base",
- "_view_module_version": "1.2.0",
- "_view_name": "StyleView",
- "button_color": null,
- "font_weight": ""
- }
- }
- }
+ "layout": "IPY_MODEL_e2d368556e494ae7ae4e2e992af2cd4f"
+ }
+ },
+ "e2d368556e494ae7ae4e2e992af2cd4f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e4901541199b45c6a18824627692fc39": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e5c0d75eb5e1447abd560c8f2c6017e1": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "PlayModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "PlayModel",
+ "_playing": false,
+ "_repeat": false,
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "PlayView",
+ "description": "",
+ "description_tooltip": null,
+ "disabled": false,
+ "interval": 100,
+ "layout": "IPY_MODEL_c30e6c2f3e2a44dbbb3d63bd519acaa4",
+ "max": 0,
+ "min": 0,
+ "show_repeat": true,
+ "step": 1,
+ "style": "IPY_MODEL_f31c6e40e9b2466a9064a2669933ecd5",
+ "value": 0
+ }
+ },
+ "eac6a8dcdc9d4335a2e51031793ead29": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "f31c6e40e9b2466a9064a2669933ecd5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "f915cf874246446595206221e900b2fe": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "fdc393f3468c432aa0ada05e238a5436": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
}
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}
\ No newline at end of file
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/examples/research_projects/gligen/README.md b/examples/research_projects/gligen/README.md
index fa922617d984..3da23306ce1a 100644
--- a/examples/research_projects/gligen/README.md
+++ b/examples/research_projects/gligen/README.md
@@ -47,11 +47,11 @@ pip install git+https://github.com/xinyu1205/recognize-anything.git --no-deps
Download the pre-trained model:
```bash
-huggingface-cli download --resume-download xinyu1205/recognize_anything_model ram_swin_large_14m.pth
-huggingface-cli download --resume-download IDEA-Research/grounding-dino-base
-huggingface-cli download --resume-download Salesforce/blip2-flan-t5-xxl
-huggingface-cli download --resume-download clip-vit-large-patch14
-huggingface-cli download --resume-download masterful/gligen-1-4-generation-text-box
+hf download --resume-download xinyu1205/recognize_anything_model ram_swin_large_14m.pth
+hf download --resume-download IDEA-Research/grounding-dino-base
+hf download --resume-download Salesforce/blip2-flan-t5-xxl
+hf download --resume-download clip-vit-large-patch14
+hf download --resume-download masterful/gligen-1-4-generation-text-box
```
Make the training data on 8 GPUs:
@@ -66,7 +66,7 @@ torchrun --master_port 17673 --nproc_per_node=8 make_datasets.py \
You can download the COCO training data from
```bash
-huggingface-cli download --resume-download Hzzone/GLIGEN_COCO coco_train2017.pth
+hf download --resume-download Hzzone/GLIGEN_COCO coco_train2017.pth
```
It's in the format of
@@ -125,7 +125,7 @@ Note that although the pre-trained GLIGEN model has been loaded, the parameters
The trained model can be downloaded from
```bash
-huggingface-cli download --resume-download Hzzone/GLIGEN_COCO config.json diffusion_pytorch_model.safetensors
+hf download --resume-download Hzzone/GLIGEN_COCO config.json diffusion_pytorch_model.safetensors
```
You can run `demo.ipynb` to visualize the generated images.
diff --git a/examples/research_projects/gligen/demo.ipynb b/examples/research_projects/gligen/demo.ipynb
index 571f1a0323a2..315aee710594 100644
--- a/examples/research_projects/gligen/demo.ipynb
+++ b/examples/research_projects/gligen/demo.ipynb
@@ -26,8 +26,7 @@
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
- "import torch\n",
- "from diffusers import StableDiffusionGLIGENTextImagePipeline, StableDiffusionGLIGENPipeline"
+ "from diffusers import StableDiffusionGLIGENPipeline"
]
},
{
@@ -36,28 +35,25 @@
"metadata": {},
"outputs": [],
"source": [
- "import os\n",
+ "from transformers import CLIPTextModel, CLIPTokenizer\n",
+ "\n",
"import diffusers\n",
"from diffusers import (\n",
" AutoencoderKL,\n",
" DDPMScheduler,\n",
- " UNet2DConditionModel,\n",
- " UniPCMultistepScheduler,\n",
" EulerDiscreteScheduler,\n",
+ " UNet2DConditionModel,\n",
")\n",
- "from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n",
+ "\n",
+ "\n",
"# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n",
"\n",
- "pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n",
+ "pretrained_model_name_or_path = \"/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83\"\n",
"\n",
"tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder=\"tokenizer\")\n",
"noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder=\"scheduler\")\n",
- "text_encoder = CLIPTextModel.from_pretrained(\n",
- " pretrained_model_name_or_path, subfolder=\"text_encoder\"\n",
- ")\n",
- "vae = AutoencoderKL.from_pretrained(\n",
- " pretrained_model_name_or_path, subfolder=\"vae\"\n",
- ")\n",
+ "text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder=\"text_encoder\")\n",
+ "vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder=\"vae\")\n",
"# unet = UNet2DConditionModel.from_pretrained(\n",
"# pretrained_model_name_or_path, subfolder=\"unet\"\n",
"# )\n",
@@ -71,9 +67,7 @@
"metadata": {},
"outputs": [],
"source": [
- "unet = UNet2DConditionModel.from_pretrained(\n",
- " '/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO'\n",
- ")"
+ "unet = UNet2DConditionModel.from_pretrained(\"/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO\")"
]
},
{
@@ -108,6 +102,9 @@
"metadata": {},
"outputs": [],
"source": [
+ "import numpy as np\n",
+ "\n",
+ "\n",
"# prompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'\n",
"# gen_boxes = [('a green car', [21, 281, 211, 159]), ('a blue truck', [269, 283, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]\n",
"\n",
@@ -117,10 +114,8 @@
"# prompt = 'A realistic scene of three skiers standing in a line on the snow near a palm tree'\n",
"# gen_boxes = [('a skier', [5, 152, 139, 168]), ('a skier', [278, 192, 121, 158]), ('a skier', [148, 173, 124, 155]), ('a palm tree', [404, 105, 103, 251])]\n",
"\n",
- "prompt = 'An oil painting of a pink dolphin jumping on the left of a steam boat on the sea'\n",
- "gen_boxes = [('a steam boat', [232, 225, 257, 149]), ('a jumping pink dolphin', [21, 249, 189, 123])]\n",
- "\n",
- "import numpy as np\n",
+ "prompt = \"An oil painting of a pink dolphin jumping on the left of a steam boat on the sea\"\n",
+ "gen_boxes = [(\"a steam boat\", [232, 225, 257, 149]), (\"a jumping pink dolphin\", [21, 249, 189, 123])]\n",
"\n",
"boxes = np.array([x[1] for x in gen_boxes])\n",
"boxes = boxes / 512\n",
@@ -166,7 +161,7 @@
"metadata": {},
"outputs": [],
"source": [
- "diffusers.utils.make_image_grid(images, 4, len(images)//4)"
+ "diffusers.utils.make_image_grid(images, 4, len(images) // 4)"
]
},
{
@@ -179,7 +174,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "densecaption",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -197,5 +192,5 @@
}
},
"nbformat": 4,
- "nbformat_minor": 2
+ "nbformat_minor": 4
}
diff --git a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
index 070cdad15564..06079fe9ed41 100644
--- a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
+++ b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
@@ -15,8 +15,8 @@
# limitations under the License.
"""
- Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
- Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
+Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
+Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
"""
import argparse
@@ -54,6 +54,7 @@
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel, cast_training_params
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, deprecate, is_wandb_available
+from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -344,7 +345,7 @@ def parse_args():
"--conditioning_dropout_prob",
type=float,
default=None,
- help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://arxiv.org/abs/2211.09800.",
+ help="Conditioning dropout probability. Drops out the conditionings (image and edit prompt) used in training InstructPix2Pix. See section 3.2.1 in the paper: https://huggingface.co/papers/2211.09800.",
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
@@ -475,7 +476,7 @@ def convert_to_np(image, resolution):
def download_image(url):
- image = PIL.Image.open(requests.get(url, stream=True).raw)
+ image = PIL.Image.open(requests.get(url, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw)
image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")
return image
@@ -487,7 +488,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if args.non_ema_revision is not None:
@@ -973,7 +974,7 @@ def collate_fn(examples):
original_image_embeds = vae.encode(batch["original_pixel_values"].to(weight_dtype)).latent_dist.mode()
# Conditioning dropout to support classifier-free guidance during inference. For more details
- # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800.
+ # check out the section 3.2.1 of the original paper https://huggingface.co/papers/2211.09800.
if args.conditioning_dropout_prob is not None:
random_p = torch.rand(bsz, device=latents.device, generator=generator)
# Sample masks for the edit prompts.
diff --git a/examples/research_projects/intel_opts/inference_bf16.py b/examples/research_projects/intel_opts/inference_bf16.py
index 96ec709f433c..13f2731fb713 100644
--- a/examples/research_projects/intel_opts/inference_bf16.py
+++ b/examples/research_projects/intel_opts/inference_bf16.py
@@ -13,7 +13,7 @@
device = "cpu"
-prompt = "a lovely in red dress and hat, in the snowly and brightly night, with many brighly buildings"
+prompt = "a lovely in red dress and hat, in the snowly and brightly night, with many brightly buildings"
model_id = "path-to-your-trained-model"
pipe = StableDiffusionPipeline.from_pretrained(model_id)
diff --git a/examples/research_projects/intel_opts/textual_inversion/README.md b/examples/research_projects/intel_opts/textual_inversion/README.md
index 3339b8e2cb63..8efb14b47f28 100644
--- a/examples/research_projects/intel_opts/textual_inversion/README.md
+++ b/examples/research_projects/intel_opts/textual_inversion/README.md
@@ -1,6 +1,6 @@
## Textual Inversion fine-tuning example
-[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.
+[Textual inversion](https://huggingface.co/papers/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.
The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
## Training with Intel Extension for PyTorch
diff --git a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py
index ea4a0d255b68..740a759420cb 100644
--- a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py
+++ b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py
@@ -366,7 +366,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/intel_opts/textual_inversion_dfq/README.md b/examples/research_projects/intel_opts/textual_inversion_dfq/README.md
index 4a227cdb4d63..844227ed87aa 100644
--- a/examples/research_projects/intel_opts/textual_inversion_dfq/README.md
+++ b/examples/research_projects/intel_opts/textual_inversion_dfq/README.md
@@ -1,6 +1,6 @@
# Distillation for quantization on Textual Inversion models to personalize text2image
-[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like stable diffusion on your own images._By using just 3-5 images new concepts can be taught to Stable Diffusion and the model personalized on your own images_
+[Textual inversion](https://huggingface.co/papers/2208.01618) is a method to personalize text2image models like stable diffusion on your own images._By using just 3-5 images new concepts can be taught to Stable Diffusion and the model personalized on your own images_
The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
We have enabled distillation for quantization in `textual_inversion.py` to do quantization aware training as well as distillation on the model generated by Textual Inversion method.
@@ -80,7 +80,7 @@ export INT8_MODEL_NAME="./int8_model"
python text2images.py \
--pretrained_model_name_or_path=$INT8_MODEL_NAME \
- --caption "a lovely in red dress and hat, in the snowly and brightly night, with many brighly buildings." \
+ --caption "a lovely in red dress and hat, in the snowly and brightly night, with many brightly buildings." \
--images_num 4
```
diff --git a/examples/research_projects/ip_adapter/README.md b/examples/research_projects/ip_adapter/README.md
index 04a6c86e5305..0bead5ae859d 100644
--- a/examples/research_projects/ip_adapter/README.md
+++ b/examples/research_projects/ip_adapter/README.md
@@ -1,6 +1,6 @@
# IP Adapter Training Example
-[IP Adapter](https://arxiv.org/abs/2308.06721) is a novel approach designed to enhance text-to-image models such as Stable Diffusion by enabling them to generate images based on image prompts rather than text prompts alone. Unlike traditional methods that rely solely on complex text prompts, IP Adapter introduces the concept of using image prompts, leveraging the idea that "an image is worth a thousand words." By decoupling cross-attention layers for text and image features, IP Adapter effectively integrates image prompts into the generation process without the need for extensive fine-tuning or large computing resources.
+[IP Adapter](https://huggingface.co/papers/2308.06721) is a novel approach designed to enhance text-to-image models such as Stable Diffusion by enabling them to generate images based on image prompts rather than text prompts alone. Unlike traditional methods that rely solely on complex text prompts, IP Adapter introduces the concept of using image prompts, leveraging the idea that "an image is worth a thousand words." By decoupling cross-attention layers for text and image features, IP Adapter effectively integrates image prompts into the generation process without the need for extensive fine-tuning or large computing resources.
## Training locally with PyTorch
@@ -55,7 +55,7 @@ The Accelerate launch command is used to train a model using multiple GPUs and m
```
accelerate launch --mixed_precision "fp16" \
tutorial_train_ip-adapter.py \
---pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5/" \
+--pretrained_model_name_or_path="stable-diffusion-v1-5/stable-diffusion-v1-5/" \
--image_encoder_path="{image_encoder_path}" \
--data_json_file="{data.json}" \
--data_root_path="{image_path}" \
@@ -73,7 +73,7 @@ tutorial_train_ip-adapter.py \
```
accelerate launch --num_processes 8 --multi_gpu --mixed_precision "fp16" \
tutorial_train_ip-adapter.py \
- --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5/" \
+ --pretrained_model_name_or_path="stable-diffusion-v1-5/stable-diffusion-v1-5/" \
--image_encoder_path="{image_encoder_path}" \
--data_json_file="{data.json}" \
--data_root_path="{image_path}" \
diff --git a/examples/research_projects/lora/README.md b/examples/research_projects/lora/README.md
index 643f664ce1eb..55b870b0bc03 100644
--- a/examples/research_projects/lora/README.md
+++ b/examples/research_projects/lora/README.md
@@ -4,7 +4,7 @@ This is an experimental LoRA extension of [this example](https://github.com/hugg
## Training with LoRA
-Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://huggingface.co/papers/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
@@ -34,7 +34,7 @@ For this example we want to directly store the trained LoRA embeddings on the Hu
we need to be logged in and add the `--push_to_hub` flag.
```bash
-huggingface-cli login
+hf auth login
```
Now we can start training!
diff --git a/examples/research_projects/lora/train_text_to_image_lora.py b/examples/research_projects/lora/train_text_to_image_lora.py
index a734c50d8ee0..a9079c114f9b 100644
--- a/examples/research_projects/lora/train_text_to_image_lora.py
+++ b/examples/research_projects/lora/train_text_to_image_lora.py
@@ -396,7 +396,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/multi_subject_dreambooth/README.md b/examples/research_projects/multi_subject_dreambooth/README.md
index 7c2e6f400935..3415a63c1063 100644
--- a/examples/research_projects/multi_subject_dreambooth/README.md
+++ b/examples/research_projects/multi_subject_dreambooth/README.md
@@ -1,6 +1,6 @@
# Multi Subject DreamBooth training
-[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject.
+[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject.
This `train_multi_subject_dreambooth.py` script shows how to implement the training procedure for one or more subjects and adapt it for stable diffusion. Note that this code is based off of the `examples/dreambooth/train_dreambooth.py` script as of 01/06/2022.
This script was added by @kopsahlong, and is not actively maintained. However, if you come across anything that could use fixing, feel free to open an issue and tag @kopsahlong.
diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
index 0f507b26d6a8..6b0ae5ba97be 100644
--- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
+++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
@@ -684,7 +684,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
@@ -763,9 +763,9 @@ def main(args):
# Parse instance and class inputs, and double check that lengths match
instance_data_dir = args.instance_data_dir.split(",")
instance_prompt = args.instance_prompt.split(",")
- assert all(
- x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]
- ), "Instance data dir and prompt inputs are not of the same length."
+ assert all(x == len(instance_data_dir) for x in [len(instance_data_dir), len(instance_prompt)]), (
+ "Instance data dir and prompt inputs are not of the same length."
+ )
if args.with_prior_preservation:
class_data_dir = args.class_data_dir.split(",")
@@ -788,9 +788,9 @@ def main(args):
negative_validation_prompts.append(None)
args.validation_negative_prompt = negative_validation_prompts
- assert num_of_validation_prompts == len(
- negative_validation_prompts
- ), "The length of negative prompts for validation is greater than the number of validation prompts."
+ assert num_of_validation_prompts == len(negative_validation_prompts), (
+ "The length of negative prompts for validation is greater than the number of validation prompts."
+ )
args.validation_inference_steps = [args.validation_inference_steps] * num_of_validation_prompts
args.validation_guidance_scale = [args.validation_guidance_scale] * num_of_validation_prompts
diff --git a/examples/research_projects/multi_subject_dreambooth_inpainting/README.md b/examples/research_projects/multi_subject_dreambooth_inpainting/README.md
index ffd8e304efce..3412de662f58 100644
--- a/examples/research_projects/multi_subject_dreambooth_inpainting/README.md
+++ b/examples/research_projects/multi_subject_dreambooth_inpainting/README.md
@@ -2,7 +2,7 @@
Please note that this project is not actively maintained. However, you can open an issue and tag @gzguevara.
-[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject. This project consists of **two parts**. Training Stable Diffusion for inpainting requieres prompt-image-mask pairs. The Unet of inpainiting models have 5 additional input channels (4 for the encoded masked-image and 1 for the mask itself).
+[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject. This project consists of **two parts**. Training Stable Diffusion for inpainting requires prompt-image-mask pairs. The Unet of inpainiting models have 5 additional input channels (4 for the encoded masked-image and 1 for the mask itself).
**The first part**, the `multi_inpaint_dataset.ipynb` notebook, demonstrates how make a 🤗 dataset of prompt-image-mask pairs. You can, however, skip the first part and move straight to the second part with the example datasets in this project. ([cat toy dataset masked](https://huggingface.co/datasets/gzguevara/cat_toy_masked), [mr. potato head dataset masked](https://huggingface.co/datasets/gzguevara/mr_potato_head_masked))
@@ -27,7 +27,7 @@ You can build multiple datasets for every subject and upload them to the 🤗 hu
Before launching the training script, make sure to select the inpainting the target model, the output directory and the 🤗 datasets.
```bash
-export MODEL_NAME="runwayml/stable-diffusion-inpainting"
+export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-inpainting"
export OUTPUT_DIR="path-to-save-model"
export DATASET_1="gzguevara/mr_potato_head_masked"
@@ -73,7 +73,7 @@ accelerate launch train_multi_subject_dreambooth_inpaint.py \
## 3. Results
-A [](https://wandb.ai/gzguevara/uncategorized/reports/Multi-Subject-Dreambooth-for-Inpainting--Vmlldzo2MzY5NDQ4?accessToken=y0nya2d7baguhbryxaikbfr1203amvn1jsmyl07vk122mrs7tnph037u1nqgse8t) is provided showing the training progress by every 50 steps. Note, the reported weights & baises run was performed on a A100 GPU with the following stetting:
+A [](https://wandb.ai/gzguevara/uncategorized/reports/Multi-Subject-Dreambooth-for-Inpainting--Vmlldzo2MzY5NDQ4?accessToken=y0nya2d7baguhbryxaikbfr1203amvn1jsmyl07vk122mrs7tnph037u1nqgse8t) is provided showing the training progress by every 50 steps. Note, the reported weights & biases run was performed on a A100 GPU with the following stetting:
```bash
accelerate launch train_multi_subject_dreambooth_inpaint.py \
diff --git a/examples/research_projects/multi_token_textual_inversion/README.md b/examples/research_projects/multi_token_textual_inversion/README.md
index 5e0aaf2c0575..7d80c0beee37 100644
--- a/examples/research_projects/multi_token_textual_inversion/README.md
+++ b/examples/research_projects/multi_token_textual_inversion/README.md
@@ -14,7 +14,7 @@ Feel free to add these options to your training! In practice num_vec_per_token a
## Textual Inversion fine-tuning example
-[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.
+[Textual inversion](https://huggingface.co/papers/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.
The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
## Running on Colab
@@ -60,7 +60,7 @@ You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need
Run the following command to authenticate your token
```bash
-huggingface-cli login
+hf auth login
```
If you have already cloned the repo, then you won't need to go through these steps.
diff --git a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
index 19432142f541..3d000c8c6644 100644
--- a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
+++ b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import logging
@@ -551,7 +552,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
@@ -830,9 +831,9 @@ def main():
# Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = get_mask(tokenizer, accelerator)
with torch.no_grad():
- accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
- index_no_updates
- ] = orig_embeds_params[index_no_updates]
+ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
+ orig_embeds_params[index_no_updates]
+ )
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
diff --git a/examples/research_projects/multi_token_textual_inversion/textual_inversion_flax.py b/examples/research_projects/multi_token_textual_inversion/textual_inversion_flax.py
index ecc89f98298e..a5973e149049 100644
--- a/examples/research_projects/multi_token_textual_inversion/textual_inversion_flax.py
+++ b/examples/research_projects/multi_token_textual_inversion/textual_inversion_flax.py
@@ -153,7 +153,7 @@ def parse_args():
"--use_auth_token",
action="store_true",
help=(
- "Will use the token generated when running `huggingface-cli login` (necessary to use this script with"
+ "Will use the token generated when running `hf auth login` (necessary to use this script with"
" private models)."
),
)
diff --git a/examples/research_projects/onnxruntime/text_to_image/README.md b/examples/research_projects/onnxruntime/text_to_image/README.md
index f1f134c576b2..f398f081663a 100644
--- a/examples/research_projects/onnxruntime/text_to_image/README.md
+++ b/examples/research_projects/onnxruntime/text_to_image/README.md
@@ -41,7 +41,7 @@ You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need
Run the following command to authenticate your token
```bash
-huggingface-cli login
+hf auth login
```
If you have already cloned the repo, then you won't need to go through these steps.
diff --git a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
index a886f9ab27ef..1af05e8b22da 100644
--- a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
+++ b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import logging
@@ -270,7 +271,7 @@ def parse_args():
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
- "More details here: https://arxiv.org/abs/2303.09556.",
+ "More details here: https://huggingface.co/papers/2303.09556.",
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
@@ -415,7 +416,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if args.non_ema_revision is not None:
@@ -854,7 +855,7 @@ def collate_fn(examples):
if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
- # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
diff --git a/examples/research_projects/onnxruntime/textual_inversion/README.md b/examples/research_projects/onnxruntime/textual_inversion/README.md
index 0f6ec7f51186..fa6d95af30de 100644
--- a/examples/research_projects/onnxruntime/textual_inversion/README.md
+++ b/examples/research_projects/onnxruntime/textual_inversion/README.md
@@ -1,6 +1,6 @@
## Textual Inversion fine-tuning example
-[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.
+[Textual inversion](https://huggingface.co/papers/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.
The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
## Running on Colab
@@ -46,7 +46,7 @@ You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need
Run the following command to authenticate your token
```bash
-huggingface-cli login
+hf auth login
```
If you have already cloned the repo, then you won't need to go through these steps.
diff --git a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
index 7f5dc8ece9fc..6044607c14b6 100644
--- a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
+++ b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import logging
@@ -566,7 +567,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
@@ -886,9 +887,9 @@ def main():
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
with torch.no_grad():
- accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
- index_no_updates
- ] = orig_embeds_params[index_no_updates]
+ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
+ orig_embeds_params[index_no_updates]
+ )
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
index 9a00f7cc4a9a..acbb77fe3ab3 100644
--- a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
+++ b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
@@ -280,7 +280,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py
index 4065a854c22d..89228983d4d8 100644
--- a/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py
+++ b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py
@@ -1,4 +1,4 @@
-# Copyright 2024 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -461,7 +461,7 @@ def encode_prompt(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -664,7 +664,7 @@ def _clean_caption(self, caption):
# &
caption = re.sub(r"&", "", caption)
- # ip adresses:
+ # ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -840,9 +840,9 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 4.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
@@ -852,7 +852,7 @@ def __call__(
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
@@ -860,7 +860,7 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -940,7 +940,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
@@ -1067,7 +1067,7 @@ def __call__(
# compute previous image: x_t -> x_t-1
if num_inference_steps == 1:
- # For DMD one step sampling: https://arxiv.org/abs/2311.18828
+ # For DMD one step sampling: https://huggingface.co/papers/2311.18828
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
diff --git a/examples/research_projects/pixart/train_pixart_controlnet_hf.py b/examples/research_projects/pixart/train_pixart_controlnet_hf.py
index 67ec30da0ece..e2f1fa1bc5e9 100644
--- a/examples/research_projects/pixart/train_pixart_controlnet_hf.py
+++ b/examples/research_projects/pixart/train_pixart_controlnet_hf.py
@@ -429,7 +429,7 @@ def parse_args():
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
- "More details here: https://arxiv.org/abs/2303.09556.",
+ "More details here: https://huggingface.co/papers/2303.09556.",
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
@@ -562,7 +562,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
@@ -612,7 +612,7 @@ def main():
# See Section 3.1. of the paper.
max_length = 120
- # For mixed precision training we cast all non-trainable weigths (vae, text_encoder) to half-precision
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
@@ -1047,7 +1047,7 @@ def collate_fn(examples):
if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
- # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
diff --git a/examples/research_projects/promptdiffusion/README.md b/examples/research_projects/promptdiffusion/README.md
index 33ffec312501..3df04eb7a361 100644
--- a/examples/research_projects/promptdiffusion/README.md
+++ b/examples/research_projects/promptdiffusion/README.md
@@ -4,7 +4,7 @@ From the project [page](https://zhendong-wang.github.io/prompt-diffusion.github.
"With a prompt consisting of a task-specific example pair of images and text guidance, and a new query image, Prompt Diffusion can comprehend the desired task and generate the corresponding output image on both seen (trained) and unseen (new) task types."
-For any usage questions, please refer to the [paper](https://arxiv.org/abs/2305.01115).
+For any usage questions, please refer to the [paper](https://huggingface.co/papers/2305.01115).
Prepare models by converting them from the [checkpoint](https://huggingface.co/zhendongw/prompt-diffusion)
diff --git a/examples/research_projects/promptdiffusion/convert_original_promptdiffusion_to_diffusers.py b/examples/research_projects/promptdiffusion/convert_original_promptdiffusion_to_diffusers.py
index 26b56a21e865..c9efcffa5bb8 100644
--- a/examples/research_projects/promptdiffusion/convert_original_promptdiffusion_to_diffusers.py
+++ b/examples/research_projects/promptdiffusion/convert_original_promptdiffusion_to_diffusers.py
@@ -59,6 +59,7 @@
UnCLIPScheduler,
)
from diffusers.utils import is_accelerate_available, logging
+from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
if is_accelerate_available():
@@ -1435,7 +1436,7 @@ def download_from_original_stable_diffusion_ckpt(
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
if config_url is not None:
- original_config_file = BytesIO(requests.get(config_url).content)
+ original_config_file = BytesIO(requests.get(config_url, timeout=DIFFUSERS_REQUEST_TIMEOUT).content)
else:
with open(original_config_file, "r") as f:
original_config_file = f.read()
diff --git a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
index 19c1f30d82da..233df1276563 100644
--- a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
+++ b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-# Based on [In-Context Learning Unlocked for Diffusion Models](https://arxiv.org/abs/2305.01115)
+# Based on [In-Context Learning Unlocked for Diffusion Models](https://huggingface.co/papers/2305.01115)
# Authors: Zhendong Wang, Yifan Jiang, Yadong Lu, Yelong Shen, Pengcheng He, Weizhu Chen, Zhangyang Wang, Mingyuan Zhou
# Project Page: https://zhendong-wang.github.io/prompt-diffusion.github.io/
# Code: https://github.com/Zhendong-Wang/Prompt-Diffusion
@@ -148,7 +148,7 @@ class PromptDiffusionPipeline(
r"""
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
- This pipeline also adds experimental support for [Prompt Diffusion](https://arxiv.org/abs/2305.01115).
+ This pipeline also adds experimental support for [Prompt Diffusion](https://huggingface.co/papers/2305.01115).
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
@@ -177,7 +177,7 @@ class PromptDiffusionPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
@@ -263,6 +263,12 @@ def enable_vae_tiling(self):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
@@ -271,6 +277,12 @@ def disable_vae_tiling(self):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
@@ -544,7 +556,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -663,8 +675,7 @@ def check_inputs(
self.check_image(image, prompt, prompt_embeds)
else:
raise ValueError(
- f"You have passed a list of images of length {len(image_pair)}."
- f"Make sure the list size equals to two."
+ f"You have passed a list of images of length {len(image_pair)}.Make sure the list size equals to two."
)
# Check `controlnet_conditioning_scale`
@@ -814,7 +825,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
- r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+ r"""Enables the FreeU mechanism as in https://huggingface.co/papers/2309.11497.
The suffixes after the scaling factors represent the stages where they are being applied.
@@ -878,7 +889,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -960,7 +971,7 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
diff --git a/examples/research_projects/pytorch_xla/inference/flux/README.md b/examples/research_projects/pytorch_xla/inference/flux/README.md
index 9d482e6805a3..0bbd650bb6b7 100644
--- a/examples/research_projects/pytorch_xla/inference/flux/README.md
+++ b/examples/research_projects/pytorch_xla/inference/flux/README.md
@@ -40,7 +40,7 @@ cd examples/research_projects/pytorch_xla/inference/flux/
As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
```bash
-huggingface-cli login
+hf auth login
```
Then run:
diff --git a/examples/research_projects/pytorch_xla/inference/flux/flux_inference.py b/examples/research_projects/pytorch_xla/inference/flux/flux_inference.py
index 9c98c9b5ff4f..35cb015a6cc7 100644
--- a/examples/research_projects/pytorch_xla/inference/flux/flux_inference.py
+++ b/examples/research_projects/pytorch_xla/inference/flux/flux_inference.py
@@ -120,11 +120,11 @@ def _main(index, args, text_pipe, ckpt_id):
parser.add_argument("--schnell", action="store_true", help="run flux schnell instead of dev")
parser.add_argument("--width", type=int, default=1024, help="width of the image to generate")
parser.add_argument("--height", type=int, default=1024, help="height of the image to generate")
- parser.add_argument("--guidance", type=float, default=3.5, help="gauidance strentgh for dev")
+ parser.add_argument("--guidance", type=float, default=3.5, help="guidance strength for dev")
parser.add_argument("--seed", type=int, default=None, help="seed for inference")
parser.add_argument("--profile", action="store_true", help="enable profiling")
parser.add_argument("--profile-duration", type=int, default=10000, help="duration for profiling in msec.")
- parser.add_argument("--itters", type=int, default=15, help="tiems to run inference and get avg time in sec.")
+ parser.add_argument("--itters", type=int, default=15, help="items to run inference and get avg time in sec.")
args = parser.parse_args()
if args.schnell:
ckpt_id = "black-forest-labs/FLUX.1-schnell"
diff --git a/examples/research_projects/pytorch_xla/training/text_to_image/README.md b/examples/research_projects/pytorch_xla/training/text_to_image/README.md
index 06013b8a61e0..f99ab124864e 100644
--- a/examples/research_projects/pytorch_xla/training/text_to_image/README.md
+++ b/examples/research_projects/pytorch_xla/training/text_to_image/README.md
@@ -80,7 +80,7 @@ pip3 install .'
Run the following command to authenticate your token.
```bash
-huggingface-cli login
+hf auth login
```
This script only trains the unet part of the network. The VAE and text encoder
diff --git a/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_xla.py b/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_xla.py
index 9719585d3dfb..021b732ad438 100644
--- a/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_xla.py
+++ b/examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_xla.py
@@ -173,7 +173,7 @@ def print_loss_closure(step, loss):
if not dataloader_exception:
xm.wait_device_ops()
total_time = time.time() - last_time
- print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}")
+ print(f"Average step time: {total_time / (self.args.max_train_steps - measure_start_step)}")
else:
print("dataloader exception happen, skip result")
return
@@ -214,7 +214,7 @@ def step_fn(
if self.args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
- # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(self.noise_scheduler, timesteps)
@@ -342,7 +342,7 @@ def parse_args():
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
- "More details here: https://arxiv.org/abs/2303.09556.",
+ "More details here: https://huggingface.co/papers/2303.09556.",
)
parser.add_argument(
"--non_ema_revision",
@@ -622,7 +622,7 @@ def collate_fn(examples):
num_devices_per_host = num_devices // num_hosts
if xm.is_master_ordinal():
print("***** Running training *****")
- print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host }")
+ print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host}")
print(
f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}"
)
diff --git a/examples/research_projects/rdm/pipeline_rdm.py b/examples/research_projects/rdm/pipeline_rdm.py
index e84568786f50..9b696874c5d1 100644
--- a/examples/research_projects/rdm/pipeline_rdm.py
+++ b/examples/research_projects/rdm/pipeline_rdm.py
@@ -186,15 +186,15 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
@@ -202,7 +202,7 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -260,7 +260,7 @@ def __call__(
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# get unconditional embeddings for classifier free guidance
@@ -293,7 +293,7 @@ def __call__(
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
diff --git a/examples/research_projects/realfill/README.md b/examples/research_projects/realfill/README.md
index 91821031d2e0..bca5495297a5 100644
--- a/examples/research_projects/realfill/README.md
+++ b/examples/research_projects/realfill/README.md
@@ -1,6 +1,6 @@
# RealFill
-[RealFill](https://arxiv.org/abs/2309.16668) is a method to personalize text2image inpainting models like stable diffusion inpainting given just a few(1~5) images of a scene.
+[RealFill](https://huggingface.co/papers/2309.16668) is a method to personalize text2image inpainting models like stable diffusion inpainting given just a few(1~5) images of a scene.
The `train_realfill.py` script shows how to implement the training procedure for stable diffusion inpainting.
diff --git a/examples/research_projects/realfill/train_realfill.py b/examples/research_projects/realfill/train_realfill.py
index c7cc25df02b9..fd63f71b5fce 100644
--- a/examples/research_projects/realfill/train_realfill.py
+++ b/examples/research_projects/realfill/train_realfill.py
@@ -535,7 +535,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
@@ -759,7 +759,7 @@ def load_model_hook(models, input_dir):
unet, text_encoder, optimizer, train_dataloader
)
- # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
diff --git a/examples/research_projects/sana/README.md b/examples/research_projects/sana/README.md
new file mode 100644
index 000000000000..933f32e3f983
--- /dev/null
+++ b/examples/research_projects/sana/README.md
@@ -0,0 +1,95 @@
+# Training SANA Sprint Diffuser
+
+This README explains how to use the provided bash script commands to download a pre-trained teacher diffuser model and train it on a specific dataset, following the [SANA Sprint methodology](https://huggingface.co/papers/2503.09641).
+
+
+## Setup
+
+### 1. Define the local paths
+
+Set a variable for your desired output directory. This directory will store the downloaded model and the training checkpoints/results.
+
+```bash
+your_local_path='output' # Or any other path you prefer
+mkdir -p $your_local_path # Create the directory if it doesn't exist
+```
+
+### 2. Download the pre-trained model
+
+Download the SANA Sprint teacher model from Hugging Face Hub. The script uses the 1.6B parameter model.
+
+```bash
+hf download Efficient-Large-Model/SANA_Sprint_1.6B_1024px_teacher_diffusers --local-dir $your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers
+```
+
+*(Optional: You can also download the 0.6B model by replacing the model name: `Efficient-Large-Model/Sana_Sprint_0.6B_1024px_teacher_diffusers`)*
+
+### 3. Acquire the dataset shards
+
+The training script in this example uses specific `.parquet` shards from a randomly selected `brivangl/midjourney-v6-llava` dataset instead of downloading the entire dataset automatically via `dataset_name`.
+
+The script specifically uses these three files:
+* `data/train_000.parquet`
+* `data/train_001.parquet`
+* `data/train_002.parquet`
+
+
+
+You can either:
+
+Let the script download the dataset automatically during first run
+
+Or download it manually
+
+**Note:** The full `brivangl/midjourney-v6-llava` dataset is much larger and contains many more shards. This script example explicitly trains *only* on the three specified shards.
+
+## Usage
+
+Once the model is downloaded, you can run the training script.
+
+```bash
+
+your_local_path='output' # Ensure this variable is set
+
+python train_sana_sprint_diffusers.py \
+ --pretrained_model_name_or_path=$your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers \
+ --output_dir=$your_local_path \
+ --mixed_precision=bf16 \
+ --resolution=1024 \
+ --learning_rate=1e-6 \
+ --max_train_steps=30000 \
+ --dataloader_num_workers=8 \
+ --dataset_name='brivangl/midjourney-v6-llava' \
+ --file_path data/train_000.parquet data/train_001.parquet data/train_002.parquet \
+ --checkpointing_steps=500 --checkpoints_total_limit=10 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --seed=453645634 \
+ --train_largest_timestep \
+ --misaligned_pairs_D \
+ --gradient_checkpointing \
+ --resume_from_checkpoint="latest" \
+```
+
+### Explanation of parameters
+
+* `--pretrained_model_name_or_path`: Path to the downloaded pre-trained model directory.
+* `--output_dir`: Directory where training logs, checkpoints, and the final model will be saved.
+* `--mixed_precision`: Use BF16 mixed precision for training, which can save memory and speed up training on compatible hardware.
+* `--resolution`: The image resolution used for training (1024x1024).
+* `--learning_rate`: The learning rate for the optimizer.
+* `--max_train_steps`: The total number of training steps to perform.
+* `--dataloader_num_workers`: Number of worker processes for loading data. Increase for faster data loading if your CPU and disk can handle it.
+* `--dataset_name`: The name of the dataset on Hugging Face Hub (`brivangl/midjourney-v6-llava`).
+* `--file_path`: **Specifies the local paths to the dataset shards to be used for training.** In this case, `data/train_000.parquet`, `data/train_001.parquet`, and `data/train_002.parquet`.
+* `--checkpointing_steps`: Save a training checkpoint every X steps.
+* `--checkpoints_total_limit`: Maximum number of checkpoints to keep. Older checkpoints will be deleted.
+* `--train_batch_size`: The batch size per GPU.
+* `--gradient_accumulation_steps`: Number of steps to accumulate gradients before performing an optimizer step.
+* `--seed`: Random seed for reproducibility.
+* `--train_largest_timestep`: A specific training strategy focusing on larger timesteps.
+* `--misaligned_pairs_D`: Another specific training strategy to add misaligned image-text pairs as fake data for GAN.
+* `--gradient_checkpointing`: Enable gradient checkpointing to save GPU memory.
+* `--resume_from_checkpoint`: Allows resuming training from the latest saved checkpoint in the `--output_dir`.
+
+
diff --git a/examples/research_projects/sana/train_sana_sprint_diffusers.py b/examples/research_projects/sana/train_sana_sprint_diffusers.py
new file mode 100644
index 000000000000..d127fee5fd0d
--- /dev/null
+++ b/examples/research_projects/sana/train_sana_sprint_diffusers.py
@@ -0,0 +1,1782 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 Sana-Sprint team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import io
+import logging
+import math
+import os
+import shutil
+from pathlib import Path
+from typing import Callable, Optional
+
+import accelerate
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import torchvision.transforms as T
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, DistributedType, ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from PIL import Image
+from safetensors.torch import load_file
+from torch.nn.utils.spectral_norm import SpectralNorm
+from torch.utils.data import DataLoader, Dataset
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, Gemma2Model
+
+import diffusers
+from diffusers import (
+ AutoencoderDC,
+ SanaPipeline,
+ SanaSprintPipeline,
+ SanaTransformer2DModel,
+)
+from diffusers.models.attention_processor import Attention
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import (
+ free_memory,
+)
+from diffusers.utils import (
+ check_min_version,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.33.0.dev0")
+
+logger = get_logger(__name__)
+
+if is_torch_npu_available():
+ torch.npu.config.allow_internal_format = False
+
+COMPLEX_HUMAN_INSTRUCTION = [
+ "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
+ "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
+ "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
+ "Here are examples of how to transform or refine prompts:",
+ "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
+ "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
+ "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
+ "User Prompt: ",
+]
+
+
+class SanaVanillaAttnProcessor:
+ r"""
+ Processor for implementing scaled dot-product attention to support JVP calculation during training.
+ """
+
+ def __init__(self):
+ pass
+
+ @staticmethod
+ def scaled_dot_product_attention(
+ query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
+ ) -> torch.Tensor:
+ B, H, L, S = *query.size()[:-1], key.size(-2)
+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+ attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device)
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+ else:
+ attn_bias += attn_mask
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
+ attn_weight += attn_bias
+ attn_weight = torch.softmax(attn_weight, dim=-1)
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+ return attn_weight @ value
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ hidden_states = self.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class Text2ImageDataset(Dataset):
+ """
+ A PyTorch Dataset class for loading text-image pairs from a HuggingFace dataset.
+ This dataset is designed for text-to-image generation tasks.
+ Args:
+ hf_dataset (datasets.Dataset):
+ A HuggingFace dataset containing 'image' (bytes) and 'llava' (text) fields. Note that 'llava' is the field name for text descriptions in this specific dataset - you may need to adjust this key if using a different HuggingFace dataset with a different text field name.
+ resolution (int, optional): Target resolution for image resizing. Defaults to 1024.
+ Returns:
+ dict: A dictionary containing:
+ - 'text': The text description (str)
+ - 'image': The processed image tensor (torch.Tensor) of shape [3, resolution, resolution]
+ """
+
+ def __init__(self, hf_dataset, resolution=1024):
+ self.dataset = hf_dataset
+ self.transform = T.Compose(
+ [
+ T.Lambda(lambda img: img.convert("RGB")),
+ T.Resize(resolution), # Image.BICUBIC
+ T.CenterCrop(resolution),
+ T.ToTensor(),
+ T.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getitem__(self, idx):
+ item = self.dataset[idx]
+ text = item["llava"]
+ image_bytes = item["image"]
+
+ # Convert bytes to PIL Image
+ image = Image.open(io.BytesIO(image_bytes))
+
+ image_tensor = self.transform(image)
+
+ return {"text": text, "image": image_tensor}
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model: str = None,
+ validation_prompt=None,
+ repo_folder=None,
+):
+ widget_dict = []
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ widget_dict.append(
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
+ )
+
+ model_description = f"""
+# Sana Sprint - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} Sana Sprint weights for {base_model}.
+
+The weights were trained using [Sana-Sprint](https://nvlabs.github.io/Sana/Sprint/).
+
+## License
+
+TODO
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="other",
+ base_model=base_model,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ "sana-sprint",
+ "sana-sprint-diffusers",
+ ]
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ is_final_validation=False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ if args.enable_vae_tiling:
+ pipeline.vae.enable_tiling(tile_sample_min_height=1024, tile_sample_stride_width=1024)
+
+ pipeline.text_encoder = pipeline.text_encoder.to(torch.bfloat16)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
+
+ images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ parser.add_argument(
+ "--image_column",
+ type=str,
+ default="image",
+ help="The column of the dataset containing the target image. By "
+ "default, the standard Image Dataset maps out 'file_name' "
+ "to 'image'.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default=None,
+ help="The column of the dataset containing the instance prompt for each image",
+ )
+
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+ parser.add_argument(
+ "--max_sequence_length",
+ type=int,
+ default=300,
+ help="Maximum sequence length to use with with the Gemma model",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sana-dreambooth-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ # ----Image Processing----
+ parser.add_argument("--file_path", nargs="+", required=True, help="List of parquet files (space-separated)")
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--use_fix_crop_and_size",
+ action="store_true",
+ help="Whether or not to use the fixed crop and size for the teacher model.",
+ default=False,
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.2, help="mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.6, help="std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_mean_discriminator", type=float, default=-0.6, help="Logit mean for discriminator timestep sampling"
+ )
+ parser.add_argument(
+ "--logit_std_discriminator", type=float, default=1.0, help="Logit std for discriminator timestep sampling"
+ )
+ parser.add_argument("--ladd_multi_scale", action="store_true", help="Whether to use multi-scale discriminator")
+ parser.add_argument(
+ "--head_block_ids",
+ type=int,
+ nargs="+",
+ default=[2, 8, 14, 19],
+ help="Specify which transformer blocks to use for discriminator heads",
+ )
+ parser.add_argument("--adv_lambda", type=float, default=0.5, help="Weighting coefficient for adversarial loss")
+ parser.add_argument("--scm_lambda", type=float, default=1.0, help="Weighting coefficient for SCM loss")
+ parser.add_argument("--gradient_clip", type=float, default=0.1, help="Threshold for gradient clipping")
+ parser.add_argument(
+ "--sigma_data", type=float, default=0.5, help="Standard deviation of data distribution is supposed to be 0.5"
+ )
+ parser.add_argument(
+ "--tangent_warmup_steps", type=int, default=4000, help="Number of warmup steps for tangent vectors"
+ )
+ parser.add_argument(
+ "--guidance_embeds_scale", type=float, default=0.1, help="Scaling factor for guidance embeddings"
+ )
+ parser.add_argument(
+ "--scm_cfg_scale", type=float, nargs="+", default=[4, 4.5, 5], help="Range for classifier-free guidance scale"
+ )
+ parser.add_argument(
+ "--train_largest_timestep", action="store_true", help="Whether to enable special training for large timesteps"
+ )
+ parser.add_argument("--largest_timestep", type=float, default=1.57080, help="Maximum timestep value")
+ parser.add_argument(
+ "--largest_timestep_prob", type=float, default=0.5, help="Sampling probability for large timesteps"
+ )
+ parser.add_argument(
+ "--misaligned_pairs_D", action="store_true", help="Add misaligned sample pairs for discriminator"
+ )
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--prodigy_beta3",
+ type=float,
+ default=None,
+ help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+
+ parser.add_argument(
+ "--prodigy_use_bias_correction",
+ type=bool,
+ default=True,
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
+ )
+ parser.add_argument(
+ "--prodigy_safeguard_warmup",
+ type=bool,
+ default=True,
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
+ "Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--cache_latents",
+ action="store_true",
+ default=False,
+ help="Cache the VAE latents",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--upcast_before_saving",
+ action="store_true",
+ default=False,
+ help=(
+ "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+ "Defaults to precision dtype used for training to save memory"
+ ),
+ )
+ parser.add_argument(
+ "--offload",
+ action="store_true",
+ help="Whether to offload the VAE and the text encoder to CPU when they are not used.",
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument("--enable_vae_tiling", action="store_true", help="Enabla vae tiling in log validation")
+ parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ return args
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, fn: Callable):
+ super().__init__()
+ self.fn = fn
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return (self.fn(x) + x) / np.sqrt(2)
+
+
+class SpectralConv1d(nn.Conv1d):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ SpectralNorm.apply(self, name="weight", n_power_iterations=1, dim=0, eps=1e-12)
+
+
+class BatchNormLocal(nn.Module):
+ def __init__(self, num_features: int, affine: bool = True, virtual_bs: int = 8, eps: float = 1e-5):
+ super().__init__()
+ self.virtual_bs = virtual_bs
+ self.eps = eps
+ self.affine = affine
+
+ if self.affine:
+ self.weight = nn.Parameter(torch.ones(num_features))
+ self.bias = nn.Parameter(torch.zeros(num_features))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ shape = x.size()
+
+ # Reshape batch into groups.
+ G = np.ceil(x.size(0) / self.virtual_bs).astype(int)
+ x = x.view(G, -1, x.size(-2), x.size(-1))
+
+ # Calculate stats.
+ mean = x.mean([1, 3], keepdim=True)
+ var = x.var([1, 3], keepdim=True, unbiased=False)
+ x = (x - mean) / (torch.sqrt(var + self.eps))
+
+ if self.affine:
+ x = x * self.weight[None, :, None] + self.bias[None, :, None]
+
+ return x.view(shape)
+
+
+def make_block(channels: int, kernel_size: int) -> nn.Module:
+ return nn.Sequential(
+ SpectralConv1d(
+ channels,
+ channels,
+ kernel_size=kernel_size,
+ padding=kernel_size // 2,
+ padding_mode="circular",
+ ),
+ BatchNormLocal(channels),
+ nn.LeakyReLU(0.2, True),
+ )
+
+
+# Adapted from https://github.com/autonomousvision/stylegan-t/blob/main/networks/discriminator.py
+class DiscHead(nn.Module):
+ def __init__(self, channels: int, c_dim: int, cmap_dim: int = 64):
+ super().__init__()
+ self.channels = channels
+ self.c_dim = c_dim
+ self.cmap_dim = cmap_dim
+
+ self.main = nn.Sequential(
+ make_block(channels, kernel_size=1), ResidualBlock(make_block(channels, kernel_size=9))
+ )
+
+ if self.c_dim > 0:
+ self.cmapper = nn.Linear(self.c_dim, cmap_dim)
+ self.cls = SpectralConv1d(channels, cmap_dim, kernel_size=1, padding=0)
+ else:
+ self.cls = SpectralConv1d(channels, 1, kernel_size=1, padding=0)
+
+ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
+ h = self.main(x)
+ out = self.cls(h)
+
+ if self.c_dim > 0:
+ cmap = self.cmapper(c).unsqueeze(-1)
+ out = (out * cmap).sum(1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
+
+ return out
+
+
+class SanaMSCMDiscriminator(nn.Module):
+ def __init__(self, pretrained_model, is_multiscale=False, head_block_ids=None):
+ super().__init__()
+ self.transformer = pretrained_model
+ self.transformer.requires_grad_(False)
+
+ if head_block_ids is None or len(head_block_ids) == 0:
+ self.block_hooks = {2, 8, 14, 20, 27} if is_multiscale else {self.transformer.depth - 1}
+ else:
+ self.block_hooks = head_block_ids
+
+ heads = []
+ for i in range(len(self.block_hooks)):
+ heads.append(DiscHead(self.transformer.hidden_size, 0, 0))
+ self.heads = nn.ModuleList(heads)
+
+ def get_head_inputs(self):
+ return self.head_inputs
+
+ def forward(self, hidden_states, timestep, encoder_hidden_states=None, **kwargs):
+ feat_list = []
+ self.head_inputs = []
+
+ def get_features(module, input, output):
+ feat_list.append(output)
+ return output
+
+ hooks = []
+ for i, block in enumerate(self.transformer.transformer_blocks):
+ if i in self.block_hooks:
+ hooks.append(block.register_forward_hook(get_features))
+
+ self.transformer(
+ hidden_states=hidden_states,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ return_logvar=False,
+ **kwargs,
+ )
+
+ for hook in hooks:
+ hook.remove()
+
+ res_list = []
+ for feat, head in zip(feat_list, self.heads):
+ B, N, C = feat.shape
+ feat = feat.transpose(1, 2) # [B, C, N]
+ self.head_inputs.append(feat)
+ res_list.append(head(feat, None).reshape(feat.shape[0], -1))
+
+ concat_res = torch.cat(res_list, dim=1)
+
+ return concat_res
+
+ @property
+ def model(self):
+ return self.transformer
+
+ def save_pretrained(self, path):
+ torch.save(self.state_dict(), path)
+
+
+class DiscHeadModel:
+ def __init__(self, disc):
+ self.disc = disc
+
+ def state_dict(self):
+ return {name: param for name, param in self.disc.state_dict().items() if not name.startswith("transformer.")}
+
+ def __getattr__(self, name):
+ return getattr(self.disc, name)
+
+
+class SanaTrigFlow(SanaTransformer2DModel):
+ def __init__(self, original_model, guidance=False):
+ self.__dict__ = original_model.__dict__
+ self.hidden_size = self.config.num_attention_heads * self.config.attention_head_dim
+ self.guidance = guidance
+ if self.guidance:
+ hidden_size = self.config.num_attention_heads * self.config.attention_head_dim
+ self.logvar_linear = torch.nn.Linear(hidden_size, 1)
+ torch.nn.init.xavier_uniform_(self.logvar_linear.weight)
+ torch.nn.init.constant_(self.logvar_linear.bias, 0)
+
+ def forward(
+ self, hidden_states, encoder_hidden_states, timestep, guidance=None, jvp=False, return_logvar=False, **kwargs
+ ):
+ batch_size = hidden_states.shape[0]
+ latents = hidden_states
+ prompt_embeds = encoder_hidden_states
+ t = timestep
+
+ # TrigFlow --> Flow Transformation
+ timestep = t.expand(latents.shape[0]).to(prompt_embeds.dtype)
+ latents_model_input = latents
+
+ flow_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep))
+
+ flow_timestep_expanded = flow_timestep.view(-1, 1, 1, 1)
+ latent_model_input = latents_model_input * torch.sqrt(
+ flow_timestep_expanded**2 + (1 - flow_timestep_expanded) ** 2
+ )
+ latent_model_input = latent_model_input.to(prompt_embeds.dtype)
+
+ # forward in original flow
+
+ if jvp and self.gradient_checkpointing:
+ self.gradient_checkpointing = False
+ model_out = super().forward(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=flow_timestep,
+ guidance=guidance,
+ **kwargs,
+ )[0]
+ self.gradient_checkpointing = True
+ else:
+ model_out = super().forward(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=flow_timestep,
+ guidance=guidance,
+ **kwargs,
+ )[0]
+
+ # Flow --> TrigFlow Transformation
+ trigflow_model_out = (
+ (1 - 2 * flow_timestep_expanded) * latent_model_input
+ + (1 - 2 * flow_timestep_expanded + 2 * flow_timestep_expanded**2) * model_out
+ ) / torch.sqrt(flow_timestep_expanded**2 + (1 - flow_timestep_expanded) ** 2)
+
+ if self.guidance and guidance is not None:
+ timestep, embedded_timestep = self.time_embed(
+ timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ timestep, embedded_timestep = self.time_embed(
+ timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+ )
+
+ if return_logvar:
+ logvar = self.logvar_linear(embedded_timestep)
+ return trigflow_model_out, logvar
+
+ return (trigflow_model_out,)
+
+
+def compute_density_for_timestep_sampling_scm(batch_size: int, logit_mean: float = None, logit_std: float = None):
+ """Compute the density for sampling the timesteps when doing Sana-Sprint training."""
+ sigma = torch.randn(batch_size, device="cpu")
+ sigma = (sigma * logit_std + logit_mean).exp()
+ u = torch.atan(sigma / 0.5) # TODO: 0.5 should be a hyper-parameter
+
+ return u
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `hf auth login` to authenticate with the Hub."
+ )
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ ).repo_id
+
+ # Load the tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ )
+
+ # Load scheduler and models
+ text_encoder = Gemma2Model.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ vae = AutoencoderDC.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
+
+ ori_transformer = SanaTransformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="transformer",
+ revision=args.revision,
+ variant=args.variant,
+ guidance_embeds=True,
+ )
+ ori_transformer.set_attn_processor(SanaVanillaAttnProcessor())
+
+ ori_transformer_no_guide = SanaTransformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="transformer",
+ revision=args.revision,
+ variant=args.variant,
+ guidance_embeds=False,
+ )
+
+ original_state_dict = load_file(
+ f"{args.pretrained_model_name_or_path}/transformer/diffusion_pytorch_model.safetensors"
+ )
+
+ param_mapping = {
+ "time_embed.emb.timestep_embedder.linear_1.weight": "time_embed.timestep_embedder.linear_1.weight",
+ "time_embed.emb.timestep_embedder.linear_1.bias": "time_embed.timestep_embedder.linear_1.bias",
+ "time_embed.emb.timestep_embedder.linear_2.weight": "time_embed.timestep_embedder.linear_2.weight",
+ "time_embed.emb.timestep_embedder.linear_2.bias": "time_embed.timestep_embedder.linear_2.bias",
+ }
+
+ for src_key, dst_key in param_mapping.items():
+ if src_key in original_state_dict:
+ ori_transformer.load_state_dict({dst_key: original_state_dict[src_key]}, strict=False, assign=True)
+
+ guidance_embedder_module = ori_transformer.time_embed.guidance_embedder
+
+ zero_state_dict = {}
+
+ target_device = accelerator.device
+ param_w1 = guidance_embedder_module.linear_1.weight
+ zero_state_dict["linear_1.weight"] = torch.zeros(param_w1.shape, device=target_device)
+ param_b1 = guidance_embedder_module.linear_1.bias
+ zero_state_dict["linear_1.bias"] = torch.zeros(param_b1.shape, device=target_device)
+ param_w2 = guidance_embedder_module.linear_2.weight
+ zero_state_dict["linear_2.weight"] = torch.zeros(param_w2.shape, device=target_device)
+ param_b2 = guidance_embedder_module.linear_2.bias
+ zero_state_dict["linear_2.bias"] = torch.zeros(param_b2.shape, device=target_device)
+ guidance_embedder_module.load_state_dict(zero_state_dict, strict=False, assign=True)
+
+ transformer = SanaTrigFlow(ori_transformer, guidance=True).train()
+ pretrained_model = SanaTrigFlow(ori_transformer_no_guide, guidance=False).eval()
+
+ disc = SanaMSCMDiscriminator(
+ pretrained_model,
+ is_multiscale=args.ladd_multi_scale,
+ head_block_ids=args.head_block_ids,
+ ).train()
+
+ transformer.requires_grad_(True)
+ pretrained_model.requires_grad_(False)
+ disc.model.requires_grad_(False)
+ disc.heads.requires_grad_(True)
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ # VAE should always be kept in fp32 for SANA (?)
+ vae.to(accelerator.device, dtype=torch.float32)
+ transformer.to(accelerator.device, dtype=weight_dtype)
+ pretrained_model.to(accelerator.device, dtype=weight_dtype)
+ disc.to(accelerator.device, dtype=weight_dtype)
+ # because Gemma2 is particularly suited for bfloat16.
+ text_encoder.to(dtype=torch.bfloat16)
+
+ if args.enable_npu_flash_attention:
+ if is_torch_npu_available():
+ logger.info("npu flash attention enabled.")
+ for block in transformer.transformer_blocks:
+ block.attn2.set_use_npu_flash_attention(True)
+ for block in pretrained_model.transformer_blocks:
+ block.attn2.set_use_npu_flash_attention(True)
+ else:
+ raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ")
+
+ # Initialize a text encoding pipeline and keep it to CPU for now.
+ text_encoding_pipeline = SanaPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=None,
+ transformer=None,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ torch_dtype=torch.bfloat16,
+ )
+ text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
+
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ for model in models:
+ unwrapped_model = unwrap_model(model)
+ # Handle transformer model
+ if isinstance(unwrapped_model, type(unwrap_model(transformer))):
+ model = unwrapped_model
+ model.save_pretrained(os.path.join(output_dir, "transformer"))
+ # Handle discriminator model (only save heads)
+ elif isinstance(unwrapped_model, type(unwrap_model(disc))):
+ # Save only the heads
+ torch.save(unwrapped_model.heads.state_dict(), os.path.join(output_dir, "disc_heads.pt"))
+ else:
+ raise ValueError(f"unexpected save model: {unwrapped_model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ if weights:
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ transformer_ = None
+ disc_ = None
+
+ if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+ while len(models) > 0:
+ model = models.pop()
+ unwrapped_model = unwrap_model(model)
+
+ if isinstance(unwrapped_model, type(unwrap_model(transformer))):
+ transformer_ = model # noqa: F841
+ elif isinstance(unwrapped_model, type(unwrap_model(disc))):
+ # Load only the heads
+ heads_state_dict = torch.load(os.path.join(input_dir, "disc_heads.pt"))
+ unwrapped_model.heads.load_state_dict(heads_state_dict)
+ disc_ = model # noqa: F841
+ else:
+ raise ValueError(f"unexpected save model: {unwrapped_model.__class__}")
+
+ else:
+ # DeepSpeed case
+ transformer_ = SanaTransformer2DModel.from_pretrained(input_dir, subfolder="transformer") # noqa: F841
+ disc_heads_state_dict = torch.load(os.path.join(input_dir, "disc_heads.pt")) # noqa: F841
+ # You'll need to handle how to load the heads in DeepSpeed case
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32 and torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimization parameters
+ optimizer_G = optimizer_class(
+ transformer.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ optimizer_D = optimizer_class(
+ disc.heads.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ hf_dataset = load_dataset(
+ args.dataset_name,
+ data_files=args.file_path,
+ split="train",
+ )
+
+ train_dataset = Text2ImageDataset(
+ hf_dataset=hf_dataset,
+ resolution=args.resolution,
+ )
+
+ train_dataloader = DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ pin_memory=True,
+ persistent_workers=True,
+ shuffle=True,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer_G,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ transformer, pretrained_model, disc, optimizer_G, optimizer_D, train_dataloader, lr_scheduler = (
+ accelerator.prepare(
+ transformer, pretrained_model, disc, optimizer_G, optimizer_D, train_dataloader, lr_scheduler
+ )
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = "sana-sprint"
+ config = {
+ k: str(v) if not isinstance(v, (int, float, str, bool, torch.Tensor)) else v for k, v in vars(args).items()
+ }
+ accelerator.init_trackers(tracker_name, config=config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the mos recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ phase = "G"
+ vae_config_scaling_factor = vae.config.scaling_factor
+ sigma_data = args.sigma_data
+ negative_prompt = [""] * args.train_batch_size
+ negative_prompt_embeds, negative_prompt_attention_mask, _, _ = text_encoding_pipeline.encode_prompt(
+ prompt=negative_prompt,
+ complex_human_instruction=False,
+ do_classifier_free_guidance=False,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ transformer.train()
+ disc.train()
+
+ for step, batch in enumerate(train_dataloader):
+ # text encoding
+ prompts = batch["text"]
+ with torch.no_grad():
+ prompt_embeds, prompt_attention_mask, _, _ = text_encoding_pipeline.encode_prompt(
+ prompts, complex_human_instruction=COMPLEX_HUMAN_INSTRUCTION, do_classifier_free_guidance=False
+ )
+
+ # Convert images to latent space
+ vae = vae.to(accelerator.device)
+ pixel_values = batch["image"].to(dtype=vae.dtype)
+ model_input = vae.encode(pixel_values).latent
+ model_input = model_input * vae_config_scaling_factor * sigma_data
+ model_input = model_input.to(dtype=weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input) * sigma_data
+ bsz = model_input.shape[0]
+
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling_scm(
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ ).to(accelerator.device)
+
+ # Add noise according to TrigFlow.
+ # zt = cos(t) * x + sin(t) * noise
+ t = u.view(-1, 1, 1, 1)
+ noisy_model_input = torch.cos(t) * model_input + torch.sin(t) * noise
+
+ scm_cfg_scale = torch.tensor(
+ np.random.choice(args.scm_cfg_scale, size=bsz, replace=True),
+ device=accelerator.device,
+ )
+
+ def model_wrapper(scaled_x_t, t):
+ pred, logvar = accelerator.unwrap_model(transformer)(
+ hidden_states=scaled_x_t,
+ timestep=t.flatten(),
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ guidance=(scm_cfg_scale.flatten() * args.guidance_embeds_scale),
+ jvp=True,
+ return_logvar=True,
+ )
+ return pred, logvar
+
+ if phase == "G":
+ transformer.train()
+ disc.eval()
+ models_to_accumulate = [transformer]
+ with accelerator.accumulate(models_to_accumulate):
+ with torch.no_grad():
+ cfg_x_t = torch.cat([noisy_model_input, noisy_model_input], dim=0)
+ cfg_t = torch.cat([t, t], dim=0)
+ cfg_y = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ cfg_y_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ cfg_pretrain_pred = pretrained_model(
+ hidden_states=(cfg_x_t / sigma_data),
+ timestep=cfg_t.flatten(),
+ encoder_hidden_states=cfg_y,
+ encoder_attention_mask=cfg_y_mask,
+ )[0]
+
+ cfg_dxt_dt = sigma_data * cfg_pretrain_pred
+
+ dxt_dt_uncond, dxt_dt = cfg_dxt_dt.chunk(2)
+
+ scm_cfg_scale = scm_cfg_scale.view(-1, 1, 1, 1)
+ dxt_dt = dxt_dt_uncond + scm_cfg_scale * (dxt_dt - dxt_dt_uncond)
+
+ v_x = torch.cos(t) * torch.sin(t) * dxt_dt / sigma_data
+ v_t = torch.cos(t) * torch.sin(t)
+
+ # Adapt from https://github.com/xandergos/sCM-mnist/blob/master/train_consistency.py
+ with torch.no_grad():
+ F_theta, F_theta_grad, logvar = torch.func.jvp(
+ model_wrapper, (noisy_model_input / sigma_data, t), (v_x, v_t), has_aux=True
+ )
+
+ F_theta, logvar = transformer(
+ hidden_states=(noisy_model_input / sigma_data),
+ timestep=t.flatten(),
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ guidance=(scm_cfg_scale.flatten() * args.guidance_embeds_scale),
+ return_logvar=True,
+ )
+
+ logvar = logvar.view(-1, 1, 1, 1)
+ F_theta_grad = F_theta_grad.detach()
+ F_theta_minus = F_theta.detach()
+
+ # Warmup steps
+ r = min(1, global_step / args.tangent_warmup_steps)
+
+ # Calculate gradient g using JVP rearrangement
+ g = -torch.cos(t) * torch.cos(t) * (sigma_data * F_theta_minus - dxt_dt)
+ second_term = -r * (torch.cos(t) * torch.sin(t) * noisy_model_input + sigma_data * F_theta_grad)
+ g = g + second_term
+
+ # Tangent normalization
+ g_norm = torch.linalg.vector_norm(g, dim=(1, 2, 3), keepdim=True)
+ g = g / (g_norm + 0.1) # 0.1 is the constant c, can be modified but 0.1 was used in the paper
+
+ sigma = torch.tan(t) * sigma_data
+ weight = 1 / sigma
+
+ l2_loss = torch.square(F_theta - F_theta_minus - g)
+
+ # Calculate loss with normalization factor
+ loss = (weight / torch.exp(logvar)) * l2_loss + logvar
+
+ loss = loss.mean()
+
+ loss_no_logvar = weight * torch.square(F_theta - F_theta_minus - g)
+ loss_no_logvar = loss_no_logvar.mean()
+ g_norm = g_norm.mean()
+
+ pred_x_0 = torch.cos(t) * noisy_model_input - torch.sin(t) * F_theta * sigma_data
+
+ if args.train_largest_timestep:
+ pred_x_0.detach()
+ u = compute_density_for_timestep_sampling_scm(
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ ).to(accelerator.device)
+ t_new = u.view(-1, 1, 1, 1)
+
+ random_mask = torch.rand_like(t_new) < args.largest_timestep_prob
+
+ t_new = torch.where(random_mask, torch.full_like(t_new, args.largest_timestep), t_new)
+ z_new = torch.randn_like(model_input) * sigma_data
+ x_t_new = torch.cos(t_new) * model_input + torch.sin(t_new) * z_new
+
+ F_theta = transformer(
+ hidden_states=(x_t_new / sigma_data),
+ timestep=t_new.flatten(),
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ guidance=(scm_cfg_scale.flatten() * args.guidance_embeds_scale),
+ return_logvar=False,
+ jvp=False,
+ )[0]
+
+ pred_x_0 = torch.cos(t_new) * x_t_new - torch.sin(t_new) * F_theta * sigma_data
+
+ # Sample timesteps for discriminator
+ timesteps_D = compute_density_for_timestep_sampling_scm(
+ batch_size=bsz,
+ logit_mean=args.logit_mean_discriminator,
+ logit_std=args.logit_std_discriminator,
+ ).to(accelerator.device)
+ t_D = timesteps_D.view(-1, 1, 1, 1)
+
+ # Add noise to predicted x0
+ z_D = torch.randn_like(model_input) * sigma_data
+ noised_predicted_x0 = torch.cos(t_D) * pred_x_0 + torch.sin(t_D) * z_D
+
+ # Calculate adversarial loss
+ pred_fake = disc(
+ hidden_states=(noised_predicted_x0 / sigma_data),
+ timestep=t_D.flatten(),
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ )
+ adv_loss = -torch.mean(pred_fake)
+
+ # Total loss = sCM loss + LADD loss
+
+ total_loss = args.scm_lambda * loss + adv_loss * args.adv_lambda
+
+ total_loss = total_loss / args.gradient_accumulation_steps
+
+ accelerator.backward(total_loss)
+
+ if accelerator.sync_gradients:
+ grad_norm = accelerator.clip_grad_norm_(transformer.parameters(), args.gradient_clip)
+ if torch.logical_or(grad_norm.isnan(), grad_norm.isinf()):
+ optimizer_G.zero_grad(set_to_none=True)
+ optimizer_D.zero_grad(set_to_none=True)
+ logger.warning("NaN or Inf detected in grad_norm, skipping iteration...")
+ continue
+
+ # switch phase to D
+ phase = "D"
+
+ optimizer_G.step()
+ lr_scheduler.step()
+ optimizer_G.zero_grad(set_to_none=True)
+
+ elif phase == "D":
+ transformer.eval()
+ disc.train()
+ models_to_accumulate = [disc]
+ with accelerator.accumulate(models_to_accumulate):
+ with torch.no_grad():
+ scm_cfg_scale = torch.tensor(
+ np.random.choice(args.scm_cfg_scale, size=bsz, replace=True),
+ device=accelerator.device,
+ )
+
+ if args.train_largest_timestep:
+ random_mask = torch.rand_like(t) < args.largest_timestep_prob
+ t = torch.where(random_mask, torch.full_like(t, args.largest_timestep_prob), t)
+
+ z_new = torch.randn_like(model_input) * sigma_data
+ noisy_model_input = torch.cos(t) * model_input + torch.sin(t) * z_new
+ # here
+ F_theta = transformer(
+ hidden_states=(noisy_model_input / sigma_data),
+ timestep=t.flatten(),
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ guidance=(scm_cfg_scale.flatten() * args.guidance_embeds_scale),
+ return_logvar=False,
+ jvp=False,
+ )[0]
+ pred_x_0 = torch.cos(t) * noisy_model_input - torch.sin(t) * F_theta * sigma_data
+
+ # Sample timesteps for fake and real samples
+ timestep_D_fake = compute_density_for_timestep_sampling_scm(
+ batch_size=bsz,
+ logit_mean=args.logit_mean_discriminator,
+ logit_std=args.logit_std_discriminator,
+ ).to(accelerator.device)
+ timesteps_D_real = timestep_D_fake
+
+ t_D_fake = timestep_D_fake.view(-1, 1, 1, 1)
+ t_D_real = timesteps_D_real.view(-1, 1, 1, 1)
+
+ # Add noise to predicted x0 and real x0
+ z_D_fake = torch.randn_like(model_input) * sigma_data
+ z_D_real = torch.randn_like(model_input) * sigma_data
+ noised_predicted_x0 = torch.cos(t_D_fake) * pred_x_0 + torch.sin(t_D_fake) * z_D_fake
+ noised_latents = torch.cos(t_D_real) * model_input + torch.sin(t_D_real) * z_D_real
+
+ # Add misaligned pairs if enabled and batch size > 1
+ if args.misaligned_pairs_D and bsz > 1:
+ # Create shifted pairs
+ shifted_x0 = torch.roll(model_input, 1, 0)
+ timesteps_D_shifted = compute_density_for_timestep_sampling_scm(
+ batch_size=bsz,
+ logit_mean=args.logit_mean_discriminator,
+ logit_std=args.logit_std_discriminator,
+ ).to(accelerator.device)
+ t_D_shifted = timesteps_D_shifted.view(-1, 1, 1, 1)
+
+ # Add noise to shifted pairs
+ z_D_shifted = torch.randn_like(shifted_x0) * sigma_data
+ noised_shifted_x0 = torch.cos(t_D_shifted) * shifted_x0 + torch.sin(t_D_shifted) * z_D_shifted
+
+ # Concatenate with original noised samples
+ noised_predicted_x0 = torch.cat([noised_predicted_x0, noised_shifted_x0], dim=0)
+ t_D_fake = torch.cat([t_D_fake, t_D_shifted], dim=0)
+ prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ # Calculate D loss
+
+ pred_fake = disc(
+ hidden_states=(noised_predicted_x0 / sigma_data),
+ timestep=t_D_fake.flatten(),
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ )
+ pred_true = disc(
+ hidden_states=(noised_latents / sigma_data),
+ timestep=t_D_real.flatten(),
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ )
+
+ # hinge loss
+ loss_real = torch.mean(F.relu(1.0 - pred_true))
+ loss_gen = torch.mean(F.relu(1.0 + pred_fake))
+ loss_D = 0.5 * (loss_real + loss_gen)
+
+ loss_D = loss_D / args.gradient_accumulation_steps
+
+ accelerator.backward(loss_D)
+
+ if accelerator.sync_gradients:
+ grad_norm = accelerator.clip_grad_norm_(disc.parameters(), args.gradient_clip)
+ if torch.logical_or(grad_norm.isnan(), grad_norm.isinf()):
+ optimizer_G.zero_grad(set_to_none=True)
+ optimizer_D.zero_grad(set_to_none=True)
+ logger.warning("NaN or Inf detected in grad_norm, skipping iteration...")
+ continue
+
+ # switch back to phase G and add global step by one.
+ phase = "G"
+
+ optimizer_D.step()
+ optimizer_D.zero_grad(set_to_none=True)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {
+ "scm_loss": loss.detach().item(),
+ "adv_loss": adv_loss.detach().item(),
+ "lr": lr_scheduler.get_last_lr()[0],
+ }
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ # create pipeline
+ pipeline = SanaSprintPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ transformer=accelerator.unwrap_model(transformer),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=torch.float32,
+ )
+ pipeline_args = {
+ "prompt": args.validation_prompt,
+ "complex_human_instruction": COMPLEX_HUMAN_INSTRUCTION,
+ }
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ )
+ free_memory()
+
+ images = None
+ del pipeline
+
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ transformer = unwrap_model(transformer)
+ if args.upcast_before_saving:
+ transformer.to(torch.float32)
+ else:
+ transformer = transformer.to(weight_dtype)
+
+ # Save discriminator heads
+ disc = unwrap_model(disc)
+ disc_heads_state_dict = disc.heads.state_dict()
+
+ # Save transformer model
+ transformer.save_pretrained(os.path.join(args.output_dir, "transformer"))
+
+ # Save discriminator heads
+ torch.save(disc_heads_state_dict, os.path.join(args.output_dir, "disc_heads.pt"))
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = SanaSprintPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ transformer=accelerator.unwrap_model(transformer),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=torch.float32,
+ )
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline_args = {
+ "prompt": args.validation_prompt,
+ "complex_human_instruction": COMPLEX_HUMAN_INSTRUCTION,
+ }
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ is_final_validation=True,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ instance_prompt=args.instance_prompt,
+ validation_prompt=args.validation_prompt,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ images = None
+ del pipeline
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/research_projects/sana/train_sana_sprint_diffusers.sh b/examples/research_projects/sana/train_sana_sprint_diffusers.sh
new file mode 100644
index 000000000000..acd49ad67f5a
--- /dev/null
+++ b/examples/research_projects/sana/train_sana_sprint_diffusers.sh
@@ -0,0 +1,26 @@
+your_local_path='output'
+
+hf download Efficient-Large-Model/SANA_Sprint_1.6B_1024px_teacher_diffusers --local-dir $your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers
+
+# or Sana_Sprint_0.6B_1024px_teacher_diffusers
+
+python train_sana_sprint_diffusers.py \
+ --pretrained_model_name_or_path=$your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers \
+ --output_dir=$your_local_path \
+ --mixed_precision=bf16 \
+ --resolution=1024 \
+ --learning_rate=1e-6 \
+ --max_train_steps=30000 \
+ --dataloader_num_workers=8 \
+ --dataset_name='brivangl/midjourney-v6-llava' \
+ --file_path data/train_000.parquet data/train_001.parquet data/train_002.parquet \
+ --checkpointing_steps=500 --checkpoints_total_limit=10 \
+ --train_batch_size=1 \
+ --gradient_accumulation_steps=1 \
+ --seed=453645634 \
+ --train_largest_timestep \
+ --misaligned_pairs_D \
+ --gradient_checkpointing \
+ --resume_from_checkpoint="latest" \
+
+
diff --git a/examples/research_projects/scheduled_huber_loss_training/README.md b/examples/research_projects/scheduled_huber_loss_training/README.md
index 239f94ba1005..a587b076669a 100644
--- a/examples/research_projects/scheduled_huber_loss_training/README.md
+++ b/examples/research_projects/scheduled_huber_loss_training/README.md
@@ -1,6 +1,6 @@
# Scheduled Pseudo-Huber Loss for Diffusers
-These are the modifications of to include the possibility of training text2image models with Scheduled Pseudo Huber loss, introduced in https://arxiv.org/abs/2403.16728. (https://github.com/kabachuha/SPHL-for-stable-diffusion)
+These are the modifications of to include the possibility of training text2image models with Scheduled Pseudo Huber loss, introduced in https://huggingface.co/papers/2403.16728. (https://github.com/kabachuha/SPHL-for-stable-diffusion)
## Why this might be useful?
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
index 26caba5a42c1..c50405636982 100644
--- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import copy
@@ -536,7 +537,7 @@ def parse_args(input_args=None):
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
- "More details here: https://arxiv.org/abs/2303.09556.",
+ "More details here: https://huggingface.co/papers/2303.09556.",
)
parser.add_argument(
"--pre_compute_text_embeddings",
@@ -854,7 +855,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
@@ -1057,7 +1058,7 @@ def load_model_hook(models, input_dir):
if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
raise ValueError(
- f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
+ f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}. {low_precision_error_string}"
)
# Enable TF32 for faster training on Ampere GPUs,
@@ -1369,7 +1370,7 @@ def compute_text_embeddings(prompt):
model_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
)
else:
- # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
index 410cd74a5b7b..88f6ca0f4db6 100644
--- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import copy
@@ -782,7 +783,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
@@ -1021,7 +1022,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
index c02a59a0077a..64914f5204a4 100644
--- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import contextlib
@@ -118,7 +119,7 @@ def save_model_card(
)
model_description = f"""
-# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
+# {"SDXL" if "playground" not in base_model else "Playground"} LoRA DreamBooth - {repo_id}
@@ -376,7 +377,7 @@ def parse_args(input_args=None):
"--do_edm_style_training",
default=False,
action="store_true",
- help="Flag to conduct training using the EDM formulation as introduced in https://arxiv.org/abs/2206.00364.",
+ help="Flag to conduct training using the EDM formulation as introduced in https://huggingface.co/papers/2206.00364.",
)
parser.add_argument(
"--with_prior_preservation",
@@ -517,7 +518,7 @@ def parse_args(input_args=None):
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
- "More details here: https://arxiv.org/abs/2303.09556.",
+ "More details here: https://huggingface.co/papers/2303.09556.",
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
@@ -661,7 +662,7 @@ def parse_args(input_args=None):
action="store_true",
default=False,
help=(
- "Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
+ "Whether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://huggingface.co/papers/2402.09353. "
"Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
),
)
@@ -1054,7 +1055,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if args.do_edm_style_training and args.snr_gamma is not None:
@@ -1336,7 +1337,7 @@ def load_model_hook(models, input_dir):
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
@@ -1759,7 +1760,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
# For EDM-style training, we first obtain the sigmas based on the continuous timesteps.
# We then precondition the final model inputs based on these sigmas instead of the timesteps.
- # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+ # Follow: Section 5 of https://huggingface.co/papers/2206.00364.
if args.do_edm_style_training:
sigmas = get_sigmas(timesteps, len(noisy_model_input.shape), noisy_model_input.dtype)
if "EDM" in scheduler_type:
@@ -1819,7 +1820,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
if args.do_edm_style_training:
# Similar to the input preconditioning, the model predictions are also preconditioned
# on noised model inputs (before preconditioning) and the sigmas.
- # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+ # Follow: Section 5 of https://huggingface.co/papers/2206.00364.
if "EDM" in scheduler_type:
model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas)
else:
@@ -1873,7 +1874,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
weighting=weighting,
)
else:
- # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
diff --git a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image.py b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image.py
index 2ca555889cf9..c92b0ac0536d 100644
--- a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image.py
+++ b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image.py
@@ -353,7 +353,7 @@ def parse_args():
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
- "More details here: https://arxiv.org/abs/2303.09556.",
+ "More details here: https://huggingface.co/papers/2303.09556.",
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
@@ -547,7 +547,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if args.non_ema_revision is not None:
@@ -1017,7 +1017,7 @@ def unwrap_model(model):
model_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
)
else:
- # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
diff --git a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora.py b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora.py
index 3e6199a09a55..b7aa7b7bbbfd 100644
--- a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora.py
+++ b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora.py
@@ -268,7 +268,7 @@ def parse_args():
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
- "More details here: https://arxiv.org/abs/2303.09556.",
+ "More details here: https://huggingface.co/papers/2303.09556.",
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
@@ -442,7 +442,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
@@ -849,7 +849,7 @@ def collate_fn(examples):
model_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
)
else:
- # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
diff --git a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py
index abc439912664..715852cb728b 100644
--- a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py
@@ -345,7 +345,7 @@ def parse_args(input_args=None):
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
- "More details here: https://arxiv.org/abs/2303.09556.",
+ "More details here: https://huggingface.co/papers/2303.09556.",
)
parser.add_argument(
"--allow_tf32",
@@ -537,7 +537,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
@@ -750,7 +750,7 @@ def load_model_hook(models, input_dir):
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
@@ -1170,7 +1170,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
model_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
)
else:
- # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
diff --git a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_sdxl.py
index 4738e39e832e..5a26fd3074d1 100644
--- a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_sdxl.py
@@ -388,7 +388,7 @@ def parse_args(input_args=None):
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
- "More details here: https://arxiv.org/abs/2303.09556.",
+ "More details here: https://huggingface.co/papers/2303.09556.",
)
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
parser.add_argument(
@@ -630,7 +630,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
@@ -1185,7 +1185,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
model_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
)
else:
- # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
diff --git a/examples/research_projects/sd3_lora_colab/README.md b/examples/research_projects/sd3_lora_colab/README.md
index b7d7eedfb5dc..be1bddf9830f 100644
--- a/examples/research_projects/sd3_lora_colab/README.md
+++ b/examples/research_projects/sd3_lora_colab/README.md
@@ -6,7 +6,7 @@ This is an **EDUCATIONAL** project that provides utilities for DreamBooth LoRA t
> SD3 is gated, so you need to make sure you agree to [share your contact info](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) to access the model before using it with Diffusers. Once you have access, you need to log in so your system knows you’re authorized. Use the command below to log in:
```bash
-huggingface-cli login
+hf auth login
```
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
@@ -17,7 +17,7 @@ For setup, inference code, and details on how to run the code, please follow the
We make use of several techniques to make this possible:
-* Compute the embeddings from the instance prompt and serialize them for later reuse. This is implemented in the [`compute_embeddings.py`](./compute_embeddings.py) script. We use an 8bit (as introduced in [`LLM.int8()`](https://arxiv.org/abs/2208.07339)) T5 to reduce memory requirements to ~10.5GB.
+* Compute the embeddings from the instance prompt and serialize them for later reuse. This is implemented in the [`compute_embeddings.py`](./compute_embeddings.py) script. We use an 8bit (as introduced in [`LLM.int8()`](https://huggingface.co/papers/2208.07339)) T5 to reduce memory requirements to ~10.5GB.
* In the `train_dreambooth_sd3_lora_miniature.py` script, we make use of:
* 8bit Adam for optimization through the `bitsandbytes` library.
* Gradient checkpointing and gradient accumulation.
diff --git a/examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb b/examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb
index 8e8190a59324..79c3169b63c2 100644
--- a/examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb
+++ b/examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb
@@ -60,7 +60,7 @@
},
"outputs": [],
"source": [
- "!huggingface-cli login"
+ "!hf auth login"
]
},
{
@@ -2425,4 +2425,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
-}
+}
\ No newline at end of file
diff --git a/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py b/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py
index f5bee58d4534..d73aab73630a 100644
--- a/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py
+++ b/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py
@@ -623,7 +623,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
@@ -765,7 +765,7 @@ def load_model_hook(models, input_dir):
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
transformer_state_dict = {
- f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
@@ -1001,7 +1001,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
return_dict=False,
)[0]
- # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+ # Follow: Section 5 of https://huggingface.co/papers/2206.00364.
# Preconditioning of the model outputs.
model_pred = model_pred * (-sigmas) + noisy_model_input
diff --git a/examples/research_projects/vae/vae_roundtrip.py b/examples/research_projects/vae/vae_roundtrip.py
index 8388a352b2f2..922cb42615e9 100644
--- a/examples/research_projects/vae/vae_roundtrip.py
+++ b/examples/research_projects/vae/vae_roundtrip.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import typing
@@ -237,7 +238,7 @@ def parse_args() -> argparse.Namespace:
# EXAMPLE USAGE:
#
-# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "runwayml/stable-diffusion-v1-5" --subfolder "vae" --input_image "foo.png"
+# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "stable-diffusion-v1-5/stable-diffusion-v1-5" --subfolder "vae" --input_image "foo.png"
#
# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "madebyollin/taesd" --use_tiny_nn --input_image "foo.png"
#
diff --git a/examples/research_projects/wuerstchen/text_to_image/README.md b/examples/research_projects/wuerstchen/text_to_image/README.md
index a6ec4698b611..8df068a8735a 100644
--- a/examples/research_projects/wuerstchen/text_to_image/README.md
+++ b/examples/research_projects/wuerstchen/text_to_image/README.md
@@ -26,7 +26,7 @@ accelerate config
```
For this example we want to directly store the trained LoRA embeddings on the Hub, so we need to be logged in and add the `--push_to_hub` flag to the training script. To log in, run:
```bash
-huggingface-cli login
+hf auth login
```
## Prior training
@@ -61,7 +61,7 @@ accelerate launch train_text_to_image_prior.py \
## Training with LoRA
-Low-Rank Adaption of Large Language Models (or LoRA) was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+Low-Rank Adaption of Large Language Models (or LoRA) was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://huggingface.co/papers/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
diff --git a/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
index 9e2302f1b1ba..fbf73a070e9f 100644
--- a/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
@@ -10,6 +10,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import logging
@@ -446,7 +447,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_prior.py
index 83647097d28a..737c70665bb0 100644
--- a/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_prior.py
+++ b/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_prior.py
@@ -10,6 +10,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import logging
@@ -444,7 +445,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
diff --git a/examples/server-async/Pipelines.py b/examples/server-async/Pipelines.py
new file mode 100644
index 000000000000..f89cac6a7e4b
--- /dev/null
+++ b/examples/server-async/Pipelines.py
@@ -0,0 +1,91 @@
+import logging
+import os
+from dataclasses import dataclass, field
+from typing import List
+
+import torch
+from pydantic import BaseModel
+
+from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
+
+
+logger = logging.getLogger(__name__)
+
+
+class TextToImageInput(BaseModel):
+ model: str
+ prompt: str
+ size: str | None = None
+ n: int | None = None
+
+
+@dataclass
+class PresetModels:
+ SD3: List[str] = field(default_factory=lambda: ["stabilityai/stable-diffusion-3-medium"])
+ SD3_5: List[str] = field(
+ default_factory=lambda: [
+ "stabilityai/stable-diffusion-3.5-large",
+ "stabilityai/stable-diffusion-3.5-large-turbo",
+ "stabilityai/stable-diffusion-3.5-medium",
+ ]
+ )
+
+
+class TextToImagePipelineSD3:
+ def __init__(self, model_path: str | None = None):
+ self.model_path = model_path or os.getenv("MODEL_PATH")
+ self.pipeline: StableDiffusion3Pipeline | None = None
+ self.device: str | None = None
+
+ def start(self):
+ if torch.cuda.is_available():
+ model_path = self.model_path or "stabilityai/stable-diffusion-3.5-large"
+ logger.info("Loading CUDA")
+ self.device = "cuda"
+ self.pipeline = StableDiffusion3Pipeline.from_pretrained(
+ model_path,
+ torch_dtype=torch.float16,
+ ).to(device=self.device)
+ elif torch.backends.mps.is_available():
+ model_path = self.model_path or "stabilityai/stable-diffusion-3.5-medium"
+ logger.info("Loading MPS for Mac M Series")
+ self.device = "mps"
+ self.pipeline = StableDiffusion3Pipeline.from_pretrained(
+ model_path,
+ torch_dtype=torch.bfloat16,
+ ).to(device=self.device)
+ else:
+ raise Exception("No CUDA or MPS device available")
+
+
+class ModelPipelineInitializer:
+ def __init__(self, model: str = "", type_models: str = "t2im"):
+ self.model = model
+ self.type_models = type_models
+ self.pipeline = None
+ self.device = "cuda" if torch.cuda.is_available() else "mps"
+ self.model_type = None
+
+ def initialize_pipeline(self):
+ if not self.model:
+ raise ValueError("Model name not provided")
+
+ # Check if model exists in PresetModels
+ preset_models = PresetModels()
+
+ # Determine which model type we're dealing with
+ if self.model in preset_models.SD3:
+ self.model_type = "SD3"
+ elif self.model in preset_models.SD3_5:
+ self.model_type = "SD3_5"
+
+ # Create appropriate pipeline based on model type and type_models
+ if self.type_models == "t2im":
+ if self.model_type in ["SD3", "SD3_5"]:
+ self.pipeline = TextToImagePipelineSD3(self.model)
+ else:
+ raise ValueError(f"Model type {self.model_type} not supported for text-to-image")
+ elif self.type_models == "t2v":
+ raise ValueError(f"Unsupported type_models: {self.type_models}")
+
+ return self.pipeline
diff --git a/examples/server-async/README.md b/examples/server-async/README.md
new file mode 100644
index 000000000000..a47ab7c7f224
--- /dev/null
+++ b/examples/server-async/README.md
@@ -0,0 +1,171 @@
+# Asynchronous server and parallel execution of models
+
+> Example/demo server that keeps a single model in memory while safely running parallel inference requests by creating per-request lightweight views and cloning only small, stateful components (schedulers, RNG state, small mutable attrs). Works with StableDiffusion3 pipelines.
+> We recommend running 10 to 50 inferences in parallel for optimal performance, averaging between 25 and 30 seconds to 1 minute and 1 minute and 30 seconds. (This is only recommended if you have a GPU with 35GB of VRAM or more; otherwise, keep it to one or two inferences in parallel to avoid decoding or saving errors due to memory shortages.)
+
+## ⚠️ IMPORTANT
+
+* The example demonstrates how to run pipelines like `StableDiffusion3-3.5` concurrently while keeping a single copy of the heavy model parameters on GPU.
+
+## Necessary components
+
+All the components needed to create the inference server are in the current directory:
+
+```
+server-async/
+├── utils/
+├─────── __init__.py
+├─────── scheduler.py # BaseAsyncScheduler wrapper and async_retrieve_timesteps for secure inferences
+├─────── requestscopedpipeline.py # RequestScoped Pipeline for inference with a single in-memory model
+├─────── utils.py # Image/video saving utilities and service configuration
+├── Pipelines.py # pipeline loader classes (SD3)
+├── serverasync.py # FastAPI app with lifespan management and async inference endpoints
+├── test.py # Client test script for inference requests
+├── requirements.txt # Dependencies
+└── README.md # This documentation
+```
+
+## What `diffusers-async` adds / Why we needed it
+
+Core problem: a naive server that calls `pipe.__call__` concurrently can hit **race conditions** (e.g., `scheduler.set_timesteps` mutates shared state) or explode memory by deep-copying the whole pipeline per-request.
+
+`diffusers-async` / this example addresses that by:
+
+* **Request-scoped views**: `RequestScopedPipeline` creates a shallow copy of the pipeline per request so heavy weights (UNet, VAE, text encoder) remain shared and *are not duplicated*.
+* **Per-request mutable state**: stateful small objects (scheduler, RNG state, small lists/dicts, callbacks) are cloned per request. The system uses `BaseAsyncScheduler.clone_for_request(...)` for scheduler cloning, with fallback to safe `deepcopy` or other heuristics.
+* **Tokenizer concurrency safety**: `RequestScopedPipeline` now manages an internal tokenizer lock with automatic tokenizer detection and wrapping. This ensures that Rust tokenizers are safe to use under concurrency — race condition errors like `Already borrowed` no longer occur.
+* **`async_retrieve_timesteps(..., return_scheduler=True)`**: fully retro-compatible helper that returns `(timesteps, num_inference_steps, scheduler)` without mutating the shared scheduler. For users not using `return_scheduler=True`, the behavior is identical to the original API.
+* **Robust attribute handling**: wrapper avoids writing to read-only properties (e.g., `components`) and auto-detects small mutable attributes to clone while avoiding duplication of large tensors. Configurable tensor size threshold prevents cloning of large tensors.
+* **Enhanced scheduler wrapping**: `BaseAsyncScheduler` automatically wraps schedulers with improved `__getattr__`, `__setattr__`, and debugging methods (`__repr__`, `__str__`).
+
+## How the server works (high-level flow)
+
+1. **Single model instance** is loaded into memory (GPU/MPS) when the server starts.
+2. On each HTTP inference request:
+
+ * The server uses `RequestScopedPipeline.generate(...)` which:
+
+ * automatically wraps the base scheduler in `BaseAsyncScheduler` (if not already wrapped),
+ * obtains a *local scheduler* (via `clone_for_request()` or `deepcopy`),
+ * does `local_pipe = copy.copy(base_pipe)` (shallow copy),
+ * sets `local_pipe.scheduler = local_scheduler` (if possible),
+ * clones only small mutable attributes (callbacks, rng, small latents) with auto-detection,
+ * wraps tokenizers with thread-safe locks to prevent race conditions,
+ * optionally enters a `model_cpu_offload_context()` for memory offload hooks,
+ * calls the pipeline on the local view (`local_pipe(...)`).
+3. **Result**: inference completes, images are moved to CPU & saved (if requested), internal buffers freed (GC + `torch.cuda.empty_cache()`).
+4. Multiple requests can run in parallel while sharing heavy weights and isolating mutable state.
+
+## How to set up and run the server
+
+### 1) Install dependencies
+
+Recommended: create a virtualenv / conda environment.
+
+```bash
+pip install diffusers
+pip install -r requirements.txt
+```
+
+### 2) Start the server
+
+Using the `serverasync.py` file that already has everything you need:
+
+```bash
+python serverasync.py
+```
+
+The server will start on `http://localhost:8500` by default with the following features:
+- FastAPI application with async lifespan management
+- Automatic model loading and pipeline initialization
+- Request counting and active inference tracking
+- Memory cleanup after each inference
+- CORS middleware for cross-origin requests
+
+### 3) Test the server
+
+Use the included test script:
+
+```bash
+python test.py
+```
+
+Or send a manual request:
+
+`POST /api/diffusers/inference` with JSON body:
+
+```json
+{
+ "prompt": "A futuristic cityscape, vibrant colors",
+ "num_inference_steps": 30,
+ "num_images_per_prompt": 1
+}
+```
+
+Response example:
+
+```json
+{
+ "response": ["http://localhost:8500/images/img123.png"]
+}
+```
+
+### 4) Server endpoints
+
+- `GET /` - Welcome message
+- `POST /api/diffusers/inference` - Main inference endpoint
+- `GET /images/{filename}` - Serve generated images
+- `GET /api/status` - Server status and memory info
+
+## Advanced Configuration
+
+### RequestScopedPipeline Parameters
+
+```python
+RequestScopedPipeline(
+ pipeline, # Base pipeline to wrap
+ mutable_attrs=None, # Custom list of attributes to clone
+ auto_detect_mutables=True, # Enable automatic detection of mutable attributes
+ tensor_numel_threshold=1_000_000, # Tensor size threshold for cloning
+ tokenizer_lock=None, # Custom threading lock for tokenizers
+ wrap_scheduler=True # Auto-wrap scheduler in BaseAsyncScheduler
+)
+```
+
+### BaseAsyncScheduler Features
+
+* Transparent proxy to the original scheduler with `__getattr__` and `__setattr__`
+* `clone_for_request()` method for safe per-request scheduler cloning
+* Enhanced debugging with `__repr__` and `__str__` methods
+* Full compatibility with existing scheduler APIs
+
+### Server Configuration
+
+The server configuration can be modified in `serverasync.py` through the `ServerConfigModels` dataclass:
+
+```python
+@dataclass
+class ServerConfigModels:
+ model: str = 'stabilityai/stable-diffusion-3.5-medium'
+ type_models: str = 't2im'
+ host: str = '0.0.0.0'
+ port: int = 8500
+```
+
+## Troubleshooting (quick)
+
+* `Already borrowed` — previously a Rust tokenizer concurrency error.
+ ✅ This is now fixed: `RequestScopedPipeline` automatically detects and wraps tokenizers with thread locks, so race conditions no longer happen.
+
+* `can't set attribute 'components'` — pipeline exposes read-only `components`.
+ ✅ The RequestScopedPipeline now detects read-only properties and skips setting them automatically.
+
+* Scheduler issues:
+ * If the scheduler doesn't implement `clone_for_request` and `deepcopy` fails, we log and fallback — but prefer `async_retrieve_timesteps(..., return_scheduler=True)` to avoid mutating the shared scheduler.
+ ✅ Note: `async_retrieve_timesteps` is fully retro-compatible — if you don't pass `return_scheduler=True`, the behavior is unchanged.
+
+* Memory issues with large tensors:
+ ✅ The system now has configurable `tensor_numel_threshold` to prevent cloning of large tensors while still cloning small mutable ones.
+
+* Automatic tokenizer detection:
+ ✅ The system automatically identifies tokenizer components by checking for tokenizer methods, class names, and attributes, then applies thread-safe wrappers.
\ No newline at end of file
diff --git a/examples/server-async/requirements.txt b/examples/server-async/requirements.txt
new file mode 100644
index 000000000000..aafa93b7023f
--- /dev/null
+++ b/examples/server-async/requirements.txt
@@ -0,0 +1,10 @@
+torch
+torchvision
+transformers
+sentencepiece
+fastapi
+uvicorn
+ftfy
+accelerate
+xformers
+protobuf
\ No newline at end of file
diff --git a/examples/server-async/serverasync.py b/examples/server-async/serverasync.py
new file mode 100644
index 000000000000..b279b36f9a84
--- /dev/null
+++ b/examples/server-async/serverasync.py
@@ -0,0 +1,230 @@
+import asyncio
+import gc
+import logging
+import os
+import random
+import threading
+from contextlib import asynccontextmanager
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Type
+
+import torch
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.concurrency import run_in_threadpool
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse
+from Pipelines import ModelPipelineInitializer
+from pydantic import BaseModel
+
+from utils import RequestScopedPipeline, Utils
+
+
+@dataclass
+class ServerConfigModels:
+ model: str = "stabilityai/stable-diffusion-3.5-medium"
+ type_models: str = "t2im"
+ constructor_pipeline: Optional[Type] = None
+ custom_pipeline: Optional[Type] = None
+ components: Optional[Dict[str, Any]] = None
+ torch_dtype: Optional[torch.dtype] = None
+ host: str = "0.0.0.0"
+ port: int = 8500
+
+
+server_config = ServerConfigModels()
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ logging.basicConfig(level=logging.INFO)
+ app.state.logger = logging.getLogger("diffusers-server")
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
+
+ app.state.total_requests = 0
+ app.state.active_inferences = 0
+ app.state.metrics_lock = asyncio.Lock()
+ app.state.metrics_task = None
+
+ app.state.utils_app = Utils(
+ host=server_config.host,
+ port=server_config.port,
+ )
+
+ async def metrics_loop():
+ try:
+ while True:
+ async with app.state.metrics_lock:
+ total = app.state.total_requests
+ active = app.state.active_inferences
+ app.state.logger.info(f"[METRICS] total_requests={total} active_inferences={active}")
+ await asyncio.sleep(5)
+ except asyncio.CancelledError:
+ app.state.logger.info("Metrics loop cancelled")
+ raise
+
+ app.state.metrics_task = asyncio.create_task(metrics_loop())
+
+ try:
+ yield
+ finally:
+ task = app.state.metrics_task
+ if task:
+ task.cancel()
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+
+ try:
+ stop_fn = getattr(model_pipeline, "stop", None) or getattr(model_pipeline, "close", None)
+ if callable(stop_fn):
+ await run_in_threadpool(stop_fn)
+ except Exception as e:
+ app.state.logger.warning(f"Error during pipeline shutdown: {e}")
+
+ app.state.logger.info("Lifespan shutdown complete")
+
+
+app = FastAPI(lifespan=lifespan)
+
+logger = logging.getLogger("DiffusersServer.Pipelines")
+
+
+initializer = ModelPipelineInitializer(
+ model=server_config.model,
+ type_models=server_config.type_models,
+)
+model_pipeline = initializer.initialize_pipeline()
+model_pipeline.start()
+
+request_pipe = RequestScopedPipeline(model_pipeline.pipeline)
+pipeline_lock = threading.Lock()
+
+logger.info(f"Pipeline initialized and ready to receive requests (model ={server_config.model})")
+
+app.state.MODEL_INITIALIZER = initializer
+app.state.MODEL_PIPELINE = model_pipeline
+app.state.REQUEST_PIPE = request_pipe
+app.state.PIPELINE_LOCK = pipeline_lock
+
+
+class JSONBodyQueryAPI(BaseModel):
+ model: str | None = None
+ prompt: str
+ negative_prompt: str | None = None
+ num_inference_steps: int = 28
+ num_images_per_prompt: int = 1
+
+
+@app.middleware("http")
+async def count_requests_middleware(request: Request, call_next):
+ async with app.state.metrics_lock:
+ app.state.total_requests += 1
+ response = await call_next(request)
+ return response
+
+
+@app.get("/")
+async def root():
+ return {"message": "Welcome to the Diffusers Server"}
+
+
+@app.post("/api/diffusers/inference")
+async def api(json: JSONBodyQueryAPI):
+ prompt = json.prompt
+ negative_prompt = json.negative_prompt or ""
+ num_steps = json.num_inference_steps
+ num_images_per_prompt = json.num_images_per_prompt
+
+ wrapper = app.state.MODEL_PIPELINE
+ initializer = app.state.MODEL_INITIALIZER
+
+ utils_app = app.state.utils_app
+
+ if not wrapper or not wrapper.pipeline:
+ raise HTTPException(500, "Model not initialized correctly")
+ if not prompt.strip():
+ raise HTTPException(400, "No prompt provided")
+
+ def make_generator():
+ g = torch.Generator(device=initializer.device)
+ return g.manual_seed(random.randint(0, 10_000_000))
+
+ req_pipe = app.state.REQUEST_PIPE
+
+ def infer():
+ gen = make_generator()
+ return req_pipe.generate(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ generator=gen,
+ num_inference_steps=num_steps,
+ num_images_per_prompt=num_images_per_prompt,
+ device=initializer.device,
+ output_type="pil",
+ )
+
+ try:
+ async with app.state.metrics_lock:
+ app.state.active_inferences += 1
+
+ output = await run_in_threadpool(infer)
+
+ async with app.state.metrics_lock:
+ app.state.active_inferences = max(0, app.state.active_inferences - 1)
+
+ urls = [utils_app.save_image(img) for img in output.images]
+ return {"response": urls}
+
+ except Exception as e:
+ async with app.state.metrics_lock:
+ app.state.active_inferences = max(0, app.state.active_inferences - 1)
+ logger.error(f"Error during inference: {e}")
+ raise HTTPException(500, f"Error in processing: {e}")
+
+ finally:
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+ torch.cuda.ipc_collect()
+ gc.collect()
+
+
+@app.get("/images/{filename}")
+async def serve_image(filename: str):
+ utils_app = app.state.utils_app
+ file_path = os.path.join(utils_app.image_dir, filename)
+ if not os.path.isfile(file_path):
+ raise HTTPException(status_code=404, detail="Image not found")
+ return FileResponse(file_path, media_type="image/png")
+
+
+@app.get("/api/status")
+async def get_status():
+ memory_info = {}
+ if torch.cuda.is_available():
+ memory_allocated = torch.cuda.memory_allocated() / 1024**3 # GB
+ memory_reserved = torch.cuda.memory_reserved() / 1024**3 # GB
+ memory_info = {
+ "memory_allocated_gb": round(memory_allocated, 2),
+ "memory_reserved_gb": round(memory_reserved, 2),
+ "device": torch.cuda.get_device_name(0),
+ }
+
+ return {"current_model": server_config.model, "type_models": server_config.type_models, "memory": memory_info}
+
+
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
+
+if __name__ == "__main__":
+ import uvicorn
+
+ uvicorn.run(app, host=server_config.host, port=server_config.port)
diff --git a/examples/server-async/test.py b/examples/server-async/test.py
new file mode 100644
index 000000000000..e67317ea8f6b
--- /dev/null
+++ b/examples/server-async/test.py
@@ -0,0 +1,65 @@
+import os
+import time
+import urllib.parse
+
+import requests
+
+
+SERVER_URL = "http://localhost:8500/api/diffusers/inference"
+BASE_URL = "http://localhost:8500"
+DOWNLOAD_FOLDER = "generated_images"
+WAIT_BEFORE_DOWNLOAD = 2 # seconds
+
+os.makedirs(DOWNLOAD_FOLDER, exist_ok=True)
+
+
+def save_from_url(url: str) -> str:
+ """Download the given URL (relative or absolute) and save it locally."""
+ if url.startswith("/"):
+ direct = BASE_URL.rstrip("/") + url
+ else:
+ direct = url
+ resp = requests.get(direct, timeout=60)
+ resp.raise_for_status()
+ filename = os.path.basename(urllib.parse.urlparse(direct).path) or f"img_{int(time.time())}.png"
+ path = os.path.join(DOWNLOAD_FOLDER, filename)
+ with open(path, "wb") as f:
+ f.write(resp.content)
+ return path
+
+
+def main():
+ payload = {
+ "prompt": "The T-800 Terminator Robot Returning From The Future, Anime Style",
+ "num_inference_steps": 30,
+ "num_images_per_prompt": 1,
+ }
+
+ print("Sending request...")
+ try:
+ r = requests.post(SERVER_URL, json=payload, timeout=480)
+ r.raise_for_status()
+ except Exception as e:
+ print(f"Request failed: {e}")
+ return
+
+ body = r.json().get("response", [])
+ # Normalize to a list
+ urls = body if isinstance(body, list) else [body] if body else []
+ if not urls:
+ print("No URLs found in the response. Check the server output.")
+ return
+
+ print(f"Received {len(urls)} URL(s). Waiting {WAIT_BEFORE_DOWNLOAD}s before downloading...")
+ time.sleep(WAIT_BEFORE_DOWNLOAD)
+
+ for u in urls:
+ try:
+ path = save_from_url(u)
+ print(f"Image saved to: {path}")
+ except Exception as e:
+ print(f"Error downloading {u}: {e}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/server-async/utils/__init__.py b/examples/server-async/utils/__init__.py
new file mode 100644
index 000000000000..731cfe491ae5
--- /dev/null
+++ b/examples/server-async/utils/__init__.py
@@ -0,0 +1,2 @@
+from .requestscopedpipeline import RequestScopedPipeline
+from .utils import Utils
diff --git a/examples/server-async/utils/requestscopedpipeline.py b/examples/server-async/utils/requestscopedpipeline.py
new file mode 100644
index 000000000000..57d1e2567169
--- /dev/null
+++ b/examples/server-async/utils/requestscopedpipeline.py
@@ -0,0 +1,296 @@
+import copy
+import threading
+from typing import Any, Iterable, List, Optional
+
+import torch
+
+from diffusers.utils import logging
+
+from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps
+
+
+logger = logging.get_logger(__name__)
+
+
+def safe_tokenize(tokenizer, *args, lock, **kwargs):
+ with lock:
+ return tokenizer(*args, **kwargs)
+
+
+class RequestScopedPipeline:
+ DEFAULT_MUTABLE_ATTRS = [
+ "_all_hooks",
+ "_offload_device",
+ "_progress_bar_config",
+ "_progress_bar",
+ "_rng_state",
+ "_last_seed",
+ "latents",
+ ]
+
+ def __init__(
+ self,
+ pipeline: Any,
+ mutable_attrs: Optional[Iterable[str]] = None,
+ auto_detect_mutables: bool = True,
+ tensor_numel_threshold: int = 1_000_000,
+ tokenizer_lock: Optional[threading.Lock] = None,
+ wrap_scheduler: bool = True,
+ ):
+ self._base = pipeline
+ self.unet = getattr(pipeline, "unet", None)
+ self.vae = getattr(pipeline, "vae", None)
+ self.text_encoder = getattr(pipeline, "text_encoder", None)
+ self.components = getattr(pipeline, "components", None)
+
+ if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
+ if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
+ pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)
+
+ self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
+ self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()
+
+ self._auto_detect_mutables = bool(auto_detect_mutables)
+ self._tensor_numel_threshold = int(tensor_numel_threshold)
+
+ self._auto_detected_attrs: List[str] = []
+
+ def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
+ base_sched = getattr(self._base, "scheduler", None)
+ if base_sched is None:
+ return None
+
+ if not isinstance(base_sched, BaseAsyncScheduler):
+ wrapped_scheduler = BaseAsyncScheduler(base_sched)
+ else:
+ wrapped_scheduler = base_sched
+
+ try:
+ return wrapped_scheduler.clone_for_request(
+ num_inference_steps=num_inference_steps, device=device, **clone_kwargs
+ )
+ except Exception as e:
+ logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
+ try:
+ return copy.deepcopy(wrapped_scheduler)
+ except Exception as e:
+ logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
+ return wrapped_scheduler
+
+ def _autodetect_mutables(self, max_attrs: int = 40):
+ if not self._auto_detect_mutables:
+ return []
+
+ if self._auto_detected_attrs:
+ return self._auto_detected_attrs
+
+ candidates: List[str] = []
+ seen = set()
+ for name in dir(self._base):
+ if name.startswith("__"):
+ continue
+ if name in self._mutable_attrs:
+ continue
+ if name in ("to", "save_pretrained", "from_pretrained"):
+ continue
+ try:
+ val = getattr(self._base, name)
+ except Exception:
+ continue
+
+ import types
+
+ # skip callables and modules
+ if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
+ continue
+
+ # containers -> candidate
+ if isinstance(val, (dict, list, set, tuple, bytearray)):
+ candidates.append(name)
+ seen.add(name)
+ else:
+ # try Tensor detection
+ try:
+ if isinstance(val, torch.Tensor):
+ if val.numel() <= self._tensor_numel_threshold:
+ candidates.append(name)
+ seen.add(name)
+ else:
+ logger.debug(f"Ignoring large tensor attr '{name}', numel={val.numel()}")
+ except Exception:
+ continue
+
+ if len(candidates) >= max_attrs:
+ break
+
+ self._auto_detected_attrs = candidates
+ logger.debug(f"Autodetected mutable attrs to clone: {self._auto_detected_attrs}")
+ return self._auto_detected_attrs
+
+ def _is_readonly_property(self, base_obj, attr_name: str) -> bool:
+ try:
+ cls = type(base_obj)
+ descriptor = getattr(cls, attr_name, None)
+ if isinstance(descriptor, property):
+ return descriptor.fset is None
+ if hasattr(descriptor, "__set__") is False and descriptor is not None:
+ return False
+ except Exception:
+ pass
+ return False
+
+ def _clone_mutable_attrs(self, base, local):
+ attrs_to_clone = list(self._mutable_attrs)
+ attrs_to_clone.extend(self._autodetect_mutables())
+
+ EXCLUDE_ATTRS = {
+ "components",
+ }
+
+ for attr in attrs_to_clone:
+ if attr in EXCLUDE_ATTRS:
+ logger.debug(f"Skipping excluded attr '{attr}'")
+ continue
+ if not hasattr(base, attr):
+ continue
+ if self._is_readonly_property(base, attr):
+ logger.debug(f"Skipping read-only property '{attr}'")
+ continue
+
+ try:
+ val = getattr(base, attr)
+ except Exception as e:
+ logger.debug(f"Could not getattr('{attr}') on base pipeline: {e}")
+ continue
+
+ try:
+ if isinstance(val, dict):
+ setattr(local, attr, dict(val))
+ elif isinstance(val, (list, tuple, set)):
+ setattr(local, attr, list(val))
+ elif isinstance(val, bytearray):
+ setattr(local, attr, bytearray(val))
+ else:
+ # small tensors or atomic values
+ if isinstance(val, torch.Tensor):
+ if val.numel() <= self._tensor_numel_threshold:
+ setattr(local, attr, val.clone())
+ else:
+ # don't clone big tensors, keep reference
+ setattr(local, attr, val)
+ else:
+ try:
+ setattr(local, attr, copy.copy(val))
+ except Exception:
+ setattr(local, attr, val)
+ except (AttributeError, TypeError) as e:
+ logger.debug(f"Skipping cloning attribute '{attr}' because it is not settable: {e}")
+ continue
+ except Exception as e:
+ logger.debug(f"Unexpected error cloning attribute '{attr}': {e}")
+ continue
+
+ def _is_tokenizer_component(self, component) -> bool:
+ if component is None:
+ return False
+
+ tokenizer_methods = ["encode", "decode", "tokenize", "__call__"]
+ has_tokenizer_methods = any(hasattr(component, method) for method in tokenizer_methods)
+
+ class_name = component.__class__.__name__.lower()
+ has_tokenizer_in_name = "tokenizer" in class_name
+
+ tokenizer_attrs = ["vocab_size", "pad_token", "eos_token", "bos_token"]
+ has_tokenizer_attrs = any(hasattr(component, attr) for attr in tokenizer_attrs)
+
+ return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)
+
+ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
+ local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)
+
+ try:
+ local_pipe = copy.copy(self._base)
+ except Exception as e:
+ logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
+ local_pipe = copy.deepcopy(self._base)
+
+ if local_scheduler is not None:
+ try:
+ timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
+ local_scheduler.scheduler,
+ num_inference_steps=num_inference_steps,
+ device=device,
+ return_scheduler=True,
+ **{k: v for k, v in kwargs.items() if k in ["timesteps", "sigmas"]},
+ )
+
+ final_scheduler = BaseAsyncScheduler(configured_scheduler)
+ setattr(local_pipe, "scheduler", final_scheduler)
+ except Exception:
+ logger.warning("Could not set scheduler on local pipe; proceeding without replacing scheduler.")
+
+ self._clone_mutable_attrs(self._base, local_pipe)
+
+ # 4) wrap tokenizers on the local pipe with the lock wrapper
+ tokenizer_wrappers = {} # name -> original_tokenizer
+ try:
+ # a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
+ for name in dir(local_pipe):
+ if "tokenizer" in name and not name.startswith("_"):
+ tok = getattr(local_pipe, name, None)
+ if tok is not None and self._is_tokenizer_component(tok):
+ tokenizer_wrappers[name] = tok
+ setattr(
+ local_pipe,
+ name,
+ lambda *args, tok=tok, **kwargs: safe_tokenize(
+ tok, *args, lock=self._tokenizer_lock, **kwargs
+ ),
+ )
+
+ # b) wrap tokenizers in components dict
+ if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
+ for key, val in local_pipe.components.items():
+ if val is None:
+ continue
+
+ if self._is_tokenizer_component(val):
+ tokenizer_wrappers[f"components[{key}]"] = val
+ local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
+ tokenizer, *args, lock=self._tokenizer_lock, **kwargs
+ )
+
+ except Exception as e:
+ logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
+
+ result = None
+ cm = getattr(local_pipe, "model_cpu_offload_context", None)
+ try:
+ if callable(cm):
+ try:
+ with cm():
+ result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
+ except TypeError:
+ # cm might be a context manager instance rather than callable
+ try:
+ with cm:
+ result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
+ except Exception as e:
+ logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
+ result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
+ else:
+ # no offload context available — call directly
+ result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
+
+ return result
+
+ finally:
+ try:
+ for name, tok in tokenizer_wrappers.items():
+ if name.startswith("components["):
+ key = name[len("components[") : -1]
+ local_pipe.components[key] = tok
+ else:
+ setattr(local_pipe, name, tok)
+ except Exception as e:
+ logger.debug(f"Error restoring wrapped tokenizers: {e}")
diff --git a/examples/server-async/utils/scheduler.py b/examples/server-async/utils/scheduler.py
new file mode 100644
index 000000000000..86d47cac6154
--- /dev/null
+++ b/examples/server-async/utils/scheduler.py
@@ -0,0 +1,141 @@
+import copy
+import inspect
+from typing import Any, List, Optional, Union
+
+import torch
+
+
+class BaseAsyncScheduler:
+ def __init__(self, scheduler: Any):
+ self.scheduler = scheduler
+
+ def __getattr__(self, name: str):
+ if hasattr(self.scheduler, name):
+ return getattr(self.scheduler, name)
+ raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
+
+ def __setattr__(self, name: str, value):
+ if name == "scheduler":
+ super().__setattr__(name, value)
+ else:
+ if hasattr(self, "scheduler") and hasattr(self.scheduler, name):
+ setattr(self.scheduler, name, value)
+ else:
+ super().__setattr__(name, value)
+
+ def clone_for_request(self, num_inference_steps: int, device: Union[str, torch.device, None] = None, **kwargs):
+ local = copy.deepcopy(self.scheduler)
+ local.set_timesteps(num_inference_steps=num_inference_steps, device=device, **kwargs)
+ cloned = self.__class__(local)
+ return cloned
+
+ def __repr__(self):
+ return f"BaseAsyncScheduler({repr(self.scheduler)})"
+
+ def __str__(self):
+ return f"BaseAsyncScheduler wrapping: {str(self.scheduler)}"
+
+
+def async_retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call.
+ Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Backwards compatible: by default the function behaves exactly as before and returns
+ (timesteps_tensor, num_inference_steps)
+
+ If the caller passes `return_scheduler=True` in kwargs, the function will **not** mutate the passed
+ scheduler. Instead it will use a cloned scheduler if available (via `scheduler.clone_for_request`)
+ or a deepcopy fallback, call `set_timesteps` on that cloned scheduler, and return:
+ (timesteps_tensor, num_inference_steps, scheduler_in_use)
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Optional kwargs:
+ return_scheduler (bool, default False): if True, return (timesteps, num_inference_steps, scheduler_in_use)
+ where `scheduler_in_use` is a scheduler instance that already has timesteps set.
+ This mode will prefer `scheduler.clone_for_request(...)` if available, to avoid mutating the original scheduler.
+
+ Returns:
+ `(timesteps_tensor, num_inference_steps)` by default (backwards compatible), or
+ `(timesteps_tensor, num_inference_steps, scheduler_in_use)` if `return_scheduler=True`.
+ """
+ # pop our optional control kwarg (keeps compatibility)
+ return_scheduler = bool(kwargs.pop("return_scheduler", False))
+
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+
+ # choose scheduler to call set_timesteps on
+ scheduler_in_use = scheduler
+ if return_scheduler:
+ # Do not mutate the provided scheduler: prefer to clone if possible
+ if hasattr(scheduler, "clone_for_request"):
+ try:
+ # clone_for_request may accept num_inference_steps or other kwargs; be permissive
+ scheduler_in_use = scheduler.clone_for_request(
+ num_inference_steps=num_inference_steps or 0, device=device
+ )
+ except Exception:
+ scheduler_in_use = copy.deepcopy(scheduler)
+ else:
+ # fallback deepcopy (scheduler tends to be smallish - acceptable)
+ scheduler_in_use = copy.deepcopy(scheduler)
+
+ # helper to test if set_timesteps supports a particular kwarg
+ def _accepts(param_name: str) -> bool:
+ try:
+ return param_name in set(inspect.signature(scheduler_in_use.set_timesteps).parameters.keys())
+ except (ValueError, TypeError):
+ # if signature introspection fails, be permissive and attempt the call later
+ return False
+
+ # now call set_timesteps on the chosen scheduler_in_use (may be original or clone)
+ if timesteps is not None:
+ accepts_timesteps = _accepts("timesteps")
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler_in_use.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps_out = scheduler_in_use.timesteps
+ num_inference_steps = len(timesteps_out)
+ elif sigmas is not None:
+ accept_sigmas = _accepts("sigmas")
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler_in_use.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps_out = scheduler_in_use.timesteps
+ num_inference_steps = len(timesteps_out)
+ else:
+ # default path
+ scheduler_in_use.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps_out = scheduler_in_use.timesteps
+
+ if return_scheduler:
+ return timesteps_out, num_inference_steps, scheduler_in_use
+ return timesteps_out, num_inference_steps
diff --git a/examples/server-async/utils/utils.py b/examples/server-async/utils/utils.py
new file mode 100644
index 000000000000..9f943305126c
--- /dev/null
+++ b/examples/server-async/utils/utils.py
@@ -0,0 +1,48 @@
+import gc
+import logging
+import os
+import tempfile
+import uuid
+
+import torch
+
+
+logger = logging.getLogger(__name__)
+
+
+class Utils:
+ def __init__(self, host: str = "0.0.0.0", port: int = 8500):
+ self.service_url = f"http://{host}:{port}"
+ self.image_dir = os.path.join(tempfile.gettempdir(), "images")
+ if not os.path.exists(self.image_dir):
+ os.makedirs(self.image_dir)
+
+ self.video_dir = os.path.join(tempfile.gettempdir(), "videos")
+ if not os.path.exists(self.video_dir):
+ os.makedirs(self.video_dir)
+
+ def save_image(self, image):
+ if hasattr(image, "to"):
+ try:
+ image = image.to("cpu")
+ except Exception:
+ pass
+
+ if isinstance(image, torch.Tensor):
+ from torchvision import transforms
+
+ to_pil = transforms.ToPILImage()
+ image = to_pil(image.squeeze(0).clamp(0, 1))
+
+ filename = "img" + str(uuid.uuid4()).split("-")[0] + ".png"
+ image_path = os.path.join(self.image_dir, filename)
+ logger.info(f"Saving image to {image_path}")
+
+ image.save(image_path, format="PNG", optimize=True)
+
+ del image
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return os.path.join(self.service_url, "images", filename)
diff --git a/examples/server/README.md b/examples/server/README.md
index 8ad0ed3cbe6a..f8cd58fc1c89 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -9,8 +9,8 @@ This guide will show you how to use the [`StableDiffusion3Pipeline`] in a server
Start by navigating to the `examples/server` folder and installing all of the dependencies.
```py
-pip install .
-pip install -f requirements.txt
+pip install diffusers
+pip install -r requirements.txt
```
Launch the server with the following command.
diff --git a/examples/server/requirements.in b/examples/server/requirements.in
index b49b285a8fc8..f8c35d48cdac 100644
--- a/examples/server/requirements.in
+++ b/examples/server/requirements.in
@@ -1,4 +1,4 @@
-torch~=2.4.0
+torch~=2.7.0
transformers==4.46.1
sentencepiece
aiohttp
@@ -6,4 +6,5 @@ py-consul
prometheus_client >= 0.18.0
prometheus-fastapi-instrumentator >= 7.0.0
fastapi
-uvicorn
\ No newline at end of file
+uvicorn
+accelerate
diff --git a/examples/server/requirements.txt b/examples/server/requirements.txt
index 065a381f0c9b..688a4ee94fd1 100644
--- a/examples/server/requirements.txt
+++ b/examples/server/requirements.txt
@@ -1,15 +1,17 @@
# This file was autogenerated by uv via the following command:
# uv pip compile requirements.in -o requirements.txt
-aiohappyeyeballs==2.4.3
+aiohappyeyeballs==2.6.1
# via aiohttp
-aiohttp==3.10.10
+aiohttp==3.12.14
# via -r requirements.in
-aiosignal==1.3.1
+aiosignal==1.4.0
# via aiohttp
annotated-types==0.7.0
# via pydantic
anyio==4.6.2.post1
# via starlette
+async-timeout==4.0.3
+ # via aiohttp
attrs==24.2.0
# via aiohttp
certifi==2024.8.30
@@ -18,6 +20,8 @@ charset-normalizer==3.4.0
# via requests
click==8.1.7
# via uvicorn
+exceptiongroup==1.3.0
+ # via anyio
fastapi==0.115.3
# via -r requirements.in
filelock==3.16.1
@@ -35,7 +39,7 @@ fsspec==2024.10.0
# torch
h11==0.14.0
# via uvicorn
-huggingface-hub==0.26.1
+huggingface-hub==0.35.0
# via
# tokenizers
# transformers
@@ -54,10 +58,47 @@ multidict==6.1.0
# via
# aiohttp
# yarl
-networkx==3.4.2
+networkx==3.2.1
# via torch
-numpy==2.1.2
+numpy==2.0.2
# via transformers
+nvidia-cublas-cu12==12.6.4.1
+ # via
+ # nvidia-cudnn-cu12
+ # nvidia-cusolver-cu12
+ # torch
+nvidia-cuda-cupti-cu12==12.6.80
+ # via torch
+nvidia-cuda-nvrtc-cu12==12.6.77
+ # via torch
+nvidia-cuda-runtime-cu12==12.6.77
+ # via torch
+nvidia-cudnn-cu12==9.5.1.17
+ # via torch
+nvidia-cufft-cu12==11.3.0.4
+ # via torch
+nvidia-cufile-cu12==1.11.1.6
+ # via torch
+nvidia-curand-cu12==10.3.7.77
+ # via torch
+nvidia-cusolver-cu12==11.7.1.2
+ # via torch
+nvidia-cusparse-cu12==12.5.4.2
+ # via
+ # nvidia-cusolver-cu12
+ # torch
+nvidia-cusparselt-cu12==0.6.3
+ # via torch
+nvidia-nccl-cu12==2.26.2
+ # via torch
+nvidia-nvjitlink-cu12==12.6.85
+ # via
+ # nvidia-cufft-cu12
+ # nvidia-cusolver-cu12
+ # nvidia-cusparse-cu12
+ # torch
+nvidia-nvtx-cu12==12.6.77
+ # via torch
packaging==24.1
# via
# huggingface-hub
@@ -69,7 +110,9 @@ prometheus-client==0.21.0
prometheus-fastapi-instrumentator==7.0.0
# via -r requirements.in
propcache==0.2.0
- # via yarl
+ # via
+ # aiohttp
+ # yarl
py-consul==1.5.3
# via -r requirements.in
pydantic==2.9.2
@@ -101,7 +144,7 @@ sympy==1.13.3
# via torch
tokenizers==0.20.1
# via transformers
-torch==2.4.1
+torch==2.7.0
# via -r requirements.in
tqdm==4.66.5
# via
@@ -109,16 +152,24 @@ tqdm==4.66.5
# transformers
transformers==4.46.1
# via -r requirements.in
+triton==3.3.0
+ # via torch
typing-extensions==4.12.2
# via
+ # aiosignal
+ # anyio
+ # exceptiongroup
# fastapi
# huggingface-hub
+ # multidict
# pydantic
# pydantic-core
+ # starlette
# torch
-urllib3==2.2.3
+ # uvicorn
+urllib3==2.5.0
# via requests
uvicorn==0.32.0
# via -r requirements.in
-yarl==1.16.0
+yarl==1.18.3
# via aiohttp
diff --git a/examples/t2i_adapter/README_sdxl.md b/examples/t2i_adapter/README_sdxl.md
index 1e5a19fedad1..0a3b5e33d46a 100644
--- a/examples/t2i_adapter/README_sdxl.md
+++ b/examples/t2i_adapter/README_sdxl.md
@@ -58,7 +58,7 @@ wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/ma
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
```
-Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to be able to push the trained T2IAdapter parameters to Hugging Face Hub.
+Then run `hf auth login` to log into your Hugging Face account. This is needed to be able to push the trained T2IAdapter parameters to Hugging Face Hub.
```bash
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
diff --git a/examples/t2i_adapter/test_t2i_adapter.py b/examples/t2i_adapter/test_t2i_adapter.py
index cdf124cdd932..be7331d024ca 100644
--- a/examples/t2i_adapter/test_t2i_adapter.py
+++ b/examples/t2i_adapter/test_t2i_adapter.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/t2i_adapter/train_t2i_adapter_sdxl.py b/examples/t2i_adapter/train_t2i_adapter_sdxl.py
index a34ecf17eb30..989ac6e0c45e 100644
--- a/examples/t2i_adapter/train_t2i_adapter_sdxl.py
+++ b/examples/t2i_adapter/train_t2i_adapter_sdxl.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import functools
@@ -60,7 +61,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.33.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
@@ -783,7 +784,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
@@ -1190,7 +1191,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
bsz = latents.shape[0]
# Cubic sampling to sample a random timestep for each image.
- # For more details about why cubic sampling is used, refer to section 3.4 of https://arxiv.org/abs/2302.08453
+ # For more details about why cubic sampling is used, refer to section 3.4 of https://huggingface.co/papers/2302.08453
timesteps = torch.rand((bsz,), device=latents.device)
timesteps = (1 - timesteps**3) * noise_scheduler.config.num_train_timesteps
timesteps = timesteps.long().to(noise_scheduler.timesteps.dtype)
diff --git a/examples/test_examples_utils.py b/examples/test_examples_utils.py
index b57a35e1f16e..f3f3d7541cb2 100644
--- a/examples/test_examples_utils.py
+++ b/examples/test_examples_utils.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md
index 0bdf02f804bd..ebbf0a96becc 100644
--- a/examples/text_to_image/README.md
+++ b/examples/text_to_image/README.md
@@ -43,7 +43,7 @@ You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need
Run the following command to authenticate your token
```bash
-huggingface-cli login
+hf auth login
```
If you have already cloned the repo, then you won't need to go through these steps.
@@ -156,7 +156,7 @@ accelerate launch --mixed_precision="fp16" --multi_gpu train_text_to_image.py \
#### Training with Min-SNR weighting
-We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence
+We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://huggingface.co/papers/2303.09556) which helps to achieve faster convergence
by rebalancing the loss. In order to use it, one needs to set the `--snr_gamma` argument. The recommended
value when using it is 5.0.
@@ -179,13 +179,13 @@ EMA weights require an additional full-precision copy of the model parameters to
#### Training with DREAM
-We support training epsilon (noise) prediction models using the [DREAM (Diffusion Rectification and Estimation-Adaptive Models) strategy](https://arxiv.org/abs/2312.00210). DREAM claims to increase model fidelity for the performance cost of an extra grad-less unet `forward` step in the training loop. You can turn on DREAM training by using the `--dream_training` argument. The `--dream_detail_preservation` argument controls the detail preservation variable p and is the default of 1 from the paper.
+We support training epsilon (noise) prediction models using the [DREAM (Diffusion Rectification and Estimation-Adaptive Models) strategy](https://huggingface.co/papers/2312.00210). DREAM claims to increase model fidelity for the performance cost of an extra grad-less unet `forward` step in the training loop. You can turn on DREAM training by using the `--dream_training` argument. The `--dream_detail_preservation` argument controls the detail preservation variable p and is the default of 1 from the paper.
## Training with LoRA
-Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://huggingface.co/papers/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
@@ -215,7 +215,7 @@ For this example we want to directly store the trained LoRA embeddings on the Hu
we need to be logged in and add the `--push_to_hub` flag.
```bash
-huggingface-cli login
+hf auth login
```
Now we can start training!
diff --git a/examples/text_to_image/README_sdxl.md b/examples/text_to_image/README_sdxl.md
index 08d82ac133f4..6fb10ec9e1b3 100644
--- a/examples/text_to_image/README_sdxl.md
+++ b/examples/text_to_image/README_sdxl.md
@@ -127,7 +127,7 @@ boost.
## LoRA training example for Stable Diffusion XL (SDXL)
-Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://huggingface.co/papers/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
@@ -156,7 +156,7 @@ For this example we want to directly store the trained LoRA embeddings on the Hu
we need to be logged in and add the `--push_to_hub` flag.
```bash
-huggingface-cli login
+hf auth login
```
Now we can start training!
diff --git a/examples/text_to_image/requirements.txt b/examples/text_to_image/requirements.txt
index c3ffa42f0edc..be05fe3fcdc5 100644
--- a/examples/text_to_image/requirements.txt
+++ b/examples/text_to_image/requirements.txt
@@ -5,4 +5,4 @@ datasets>=2.19.1
ftfy
tensorboard
Jinja2
-peft==0.7.0
+peft>=0.17.0
diff --git a/examples/text_to_image/requirements_sdxl.txt b/examples/text_to_image/requirements_sdxl.txt
index 64cbc9205fd0..4dacc26ce4bb 100644
--- a/examples/text_to_image/requirements_sdxl.txt
+++ b/examples/text_to_image/requirements_sdxl.txt
@@ -5,4 +5,4 @@ ftfy
tensorboard
Jinja2
datasets
-peft==0.7.0
\ No newline at end of file
+peft>=0.17.0
\ No newline at end of file
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index adfb7b74477f..7ebf7b5465a5 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -57,7 +57,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.33.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -359,14 +359,14 @@ def parse_args():
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
- "More details here: https://arxiv.org/abs/2303.09556.",
+ "More details here: https://huggingface.co/papers/2303.09556.",
)
parser.add_argument(
"--dream_training",
action="store_true",
help=(
"Use the DREAM training method, which makes training more efficient and accurate at the "
- "expense of doing an extra forward pass. See: https://arxiv.org/abs/2312.00210"
+ "expense of doing an extra forward pass. See: https://huggingface.co/papers/2312.00210"
),
)
parser.add_argument(
@@ -499,6 +499,15 @@ def parse_args():
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
),
)
+ parser.add_argument(
+ "--image_interpolation_mode",
+ type=str,
+ default="lanczos",
+ choices=[
+ f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
+ ],
+ help="The image interpolation method to use for resizing images.",
+ )
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -522,7 +531,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if args.non_ema_revision is not None:
@@ -787,10 +796,17 @@ def tokenize_captions(examples, is_train=True):
)
return inputs.input_ids
- # Preprocessing the datasets.
+ # Get the specified interpolation method from the args
+ interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
+
+ # Raise an error if the interpolation method is invalid
+ if interpolation is None:
+ raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.")
+
+ # Data preprocessing transformations
train_transforms = transforms.Compose(
[
- transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.Resize(args.resolution, interpolation=interpolation), # Use dynamic interpolation method
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
@@ -1006,7 +1022,7 @@ def unwrap_model(model):
if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
- # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py
index 4564c1d16f45..c4f36879f328 100644
--- a/examples/text_to_image/train_text_to_image_flax.py
+++ b/examples/text_to_image/train_text_to_image_flax.py
@@ -49,7 +49,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.33.0.dev0")
+check_min_version("0.36.0.dev0")
logger = logging.getLogger(__name__)
@@ -264,7 +264,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging.basicConfig(
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 82c395c685f8..1fd48dcd159d 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -37,7 +37,7 @@
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig
-from peft.utils import get_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
@@ -46,7 +46,12 @@
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params, compute_snr
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.utils import (
+ check_min_version,
+ convert_state_dict_to_diffusers,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -56,7 +61,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.33.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -314,7 +319,7 @@ def parse_args():
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
- "More details here: https://arxiv.org/abs/2303.09556.",
+ "More details here: https://huggingface.co/papers/2303.09556.",
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
@@ -418,6 +423,15 @@ def parse_args():
default=4,
help=("The dimension of the LoRA update matrices."),
)
+ parser.add_argument(
+ "--image_interpolation_mode",
+ type=str,
+ default="lanczos",
+ choices=[
+ f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
+ ],
+ help="The image interpolation method to use for resizing images.",
+ )
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -441,7 +455,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
@@ -649,10 +663,17 @@ def tokenize_captions(examples, is_train=True):
)
return inputs.input_ids
- # Preprocessing the datasets.
+ # Get the specified interpolation method from the args
+ interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
+
+ # Raise an error if the interpolation method is invalid
+ if interpolation is None:
+ raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.")
+
+ # Data preprocessing transformations
train_transforms = transforms.Compose(
[
- transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.Resize(args.resolution, interpolation=interpolation), # Use dynamic interpolation method
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
@@ -692,6 +713,56 @@ def collate_fn(examples):
num_workers=args.dataloader_num_workers,
)
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ unet_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(unwrap_model(unet))):
+ unet_lora_layers_to_save = get_peft_model_state_dict(model)
+ else:
+ raise ValueError(f"Unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ StableDiffusionPipeline.save_lora_weights(
+ save_directory=output_dir,
+ unet_lora_layers=unet_lora_layers_to_save,
+ safe_serialization=True,
+ )
+
+ def load_model_hook(models, input_dir):
+ unet_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+ if isinstance(model, type(unwrap_model(unet))):
+ unet_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # returns a tuple of state dictionary and network alphas
+ lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
+
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+ incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ # throw warning if some unexpected keys are found and continue loading
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ # Make sure the trainable params are in float32
+ if args.mixed_precision in ["fp16"]:
+ cast_training_params([unet_], dtype=torch.float32)
+
# Scheduler and math around the number of training steps.
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
@@ -716,6 +787,10 @@ def collate_fn(examples):
unet, optimizer, train_dataloader, lr_scheduler
)
+ # Register the hooks for efficient saving and loading of LoRA weights
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
@@ -829,7 +904,7 @@ def collate_fn(examples):
if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
- # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
@@ -890,17 +965,6 @@ def collate_fn(examples):
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
- unwrapped_unet = unwrap_model(unet)
- unet_lora_state_dict = convert_state_dict_to_diffusers(
- get_peft_model_state_dict(unwrapped_unet)
- )
-
- StableDiffusionPipeline.save_lora_weights(
- save_directory=save_path,
- unet_lora_layers=unet_lora_state_dict,
- safe_serialization=True,
- )
-
logger.info(f"Saved state to {save_path}")
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index 2061f0c6775b..5fb1825f37d3 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -68,7 +68,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.33.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -392,7 +392,7 @@ def parse_args(input_args=None):
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
- "More details here: https://arxiv.org/abs/2303.09556.",
+ "More details here: https://huggingface.co/papers/2303.09556.",
)
parser.add_argument(
"--allow_tf32",
@@ -480,6 +480,15 @@ def parse_args(input_args=None):
action="store_true",
help="debug loss for each image, if filenames are available in the dataset",
)
+ parser.add_argument(
+ "--image_interpolation_mode",
+ type=str,
+ default="lanczos",
+ choices=[
+ f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
+ ],
+ help="The image interpolation method to use for resizing images.",
+ )
if input_args is not None:
args = parser.parse_args(input_args)
@@ -546,7 +555,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
@@ -767,7 +776,7 @@ def load_model_hook(models, input_dir):
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, _ = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
- unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
@@ -913,8 +922,14 @@ def tokenize_captions(examples, is_train=True):
tokens_two = tokenize_prompt(tokenizer_two, captions)
return tokens_one, tokens_two
+ # Get the specified interpolation method from the args
+ interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
+
+ # Raise an error if the interpolation method is invalid
+ if interpolation is None:
+ raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.")
# Preprocessing the datasets.
- train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_resize = transforms.Resize(args.resolution, interpolation=interpolation) # Use dynamic interpolation method
train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose(
@@ -1163,7 +1178,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
- # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py
index 29da1f2efbaa..c26cb4484125 100644
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -55,7 +55,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.33.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -392,7 +392,7 @@ def parse_args(input_args=None):
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
- "More details here: https://arxiv.org/abs/2303.09556.",
+ "More details here: https://huggingface.co/papers/2303.09556.",
)
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
parser.add_argument(
@@ -470,6 +470,15 @@ def parse_args(input_args=None):
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+ parser.add_argument(
+ "--image_interpolation_mode",
+ type=str,
+ default="lanczos",
+ choices=[
+ f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
+ ],
+ help="The image interpolation method to use for resizing images.",
+ )
if input_args is not None:
args = parser.parse_args(input_args)
@@ -592,7 +601,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
@@ -861,7 +870,10 @@ def load_model_hook(models, input_dir):
)
# Preprocessing the datasets.
- train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+ interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
+ if interpolation is None:
+ raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
+ train_resize = transforms.Resize(args.resolution, interpolation=interpolation)
train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
@@ -1136,7 +1148,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
- # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Compute loss-weights as per Section 3.4 of https://huggingface.co/papers/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
diff --git a/examples/textual_inversion/README.md b/examples/textual_inversion/README.md
index e869bb38d252..06e22dbcd804 100644
--- a/examples/textual_inversion/README.md
+++ b/examples/textual_inversion/README.md
@@ -1,6 +1,6 @@
## Textual Inversion fine-tuning example
-[Textual inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.
+[Textual inversion](https://huggingface.co/papers/2208.01618) is a method to personalize text2image models like stable diffusion on your own images using just 3-5 examples.
The `textual_inversion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
## Running on Colab
@@ -41,7 +41,7 @@ accelerate config
First, let's login so that we can upload the checkpoint to the Hub during training:
```bash
-huggingface-cli login
+hf auth login
```
Now let's get our dataset. For this example we will use some cat images: https://huggingface.co/datasets/diffusers/cat_toy_example .
@@ -86,7 +86,7 @@ accelerate launch textual_inversion.py \
A full training run takes ~1 hour on one V100 GPU.
-**Note**: As described in [the official paper](https://arxiv.org/abs/2208.01618)
+**Note**: As described in [the official paper](https://huggingface.co/papers/2208.01618)
only one embedding vector is used for the placeholder token, *e.g.* `""`.
However, one can also add multiple embedding vectors for the placeholder token
to increase the number of fine-tuneable parameters. This can help the model to learn
diff --git a/examples/textual_inversion/test_textual_inversion.py b/examples/textual_inversion/test_textual_inversion.py
index fa0b2c2bcf09..baf8692f2275 100644
--- a/examples/textual_inversion/test_textual_inversion.py
+++ b/examples/textual_inversion/test_textual_inversion.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/textual_inversion/test_textual_inversion_sdxl.py b/examples/textual_inversion/test_textual_inversion_sdxl.py
index a861708a70f9..3af75b44ee5f 100644
--- a/examples/textual_inversion/test_textual_inversion_sdxl.py
+++ b/examples/textual_inversion/test_textual_inversion_sdxl.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 757a12045f10..caa77e4bbaf5 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import logging
@@ -81,7 +82,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.33.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
@@ -594,7 +595,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
@@ -789,7 +790,7 @@ def main():
text_encoder, optimizer, train_dataloader, lr_scheduler
)
- # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
@@ -910,9 +911,9 @@ def main():
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
with torch.no_grad():
- accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
- index_no_updates
- ] = orig_embeds_params[index_no_updates]
+ accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
+ orig_embeds_params[index_no_updates]
+ )
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py
index 3ee675e76bbb..4a03d9bf6ba9 100644
--- a/examples/textual_inversion/textual_inversion_flax.py
+++ b/examples/textual_inversion/textual_inversion_flax.py
@@ -56,7 +56,7 @@
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.33.0.dev0")
+check_min_version("0.36.0.dev0")
logger = logging.getLogger(__name__)
@@ -166,7 +166,7 @@ def parse_args():
"--use_auth_token",
action="store_true",
help=(
- "Will use the token generated when running `huggingface-cli login` (necessary to use this script with"
+ "Will use the token generated when running `hf auth login` (necessary to use this script with"
" private models)."
),
)
diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py
index 11463943c448..51de29a71a47 100644
--- a/examples/textual_inversion/textual_inversion_sdxl.py
+++ b/examples/textual_inversion/textual_inversion_sdxl.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import logging
@@ -76,7 +77,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.33.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
@@ -593,7 +594,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
@@ -793,17 +794,22 @@ def main():
)
# Scheduler and math around the number of training steps.
- overrode_max_train_steps = False
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- overrode_max_train_steps = True
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
- num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
- num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_warmup_steps=num_warmup_steps_for_scheduler,
+ num_training_steps=num_training_steps_for_scheduler,
num_cycles=args.lr_num_cycles,
)
@@ -814,7 +820,7 @@ def main():
text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler
)
- # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
@@ -829,8 +835,14 @@ def main():
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- if overrode_max_train_steps:
+ if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -965,12 +977,12 @@ def main():
index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False
with torch.no_grad():
- accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[
- index_no_updates
- ] = orig_embeds_params[index_no_updates]
- accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[
- index_no_updates_2
- ] = orig_embeds_params_2[index_no_updates_2]
+ accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[index_no_updates] = (
+ orig_embeds_params[index_no_updates]
+ )
+ accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[index_no_updates_2] = (
+ orig_embeds_params_2[index_no_updates_2]
+ )
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
diff --git a/examples/unconditional_image_generation/README.md b/examples/unconditional_image_generation/README.md
index 2990b3abf3f5..6f8276a632f7 100644
--- a/examples/unconditional_image_generation/README.md
+++ b/examples/unconditional_image_generation/README.md
@@ -104,6 +104,8 @@ To use your own dataset, there are 2 ways:
- you can either provide your own folder as `--train_data_dir`
- or you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument.
+If your dataset contains 16 or 32-bit channels (for example, medical TIFFs), add the `--preserve_input_precision` flag so the preprocessing keeps the original precision while still training a 3-channel model. Precision still depends on the decoder: Pillow keeps 16-bit grayscale and float inputs, but many 16-bit RGB files are decoded as 8-bit RGB, and the flag cannot recover precision lost at load time.
+
Below, we explain both in more detail.
#### Provide the dataset as a folder
@@ -151,7 +153,7 @@ dataset = load_dataset("imagefolder", data_files={"train": ["path/to/file1", "pa
Next, push it to the hub!
```python
-# assuming you have ran the huggingface-cli login command in a terminal
+# assuming you have ran the hf auth login command in a terminal
dataset.push_to_hub("name_of_your_dataset")
# if you want to push to a private repo, simply pass private=True:
diff --git a/examples/unconditional_image_generation/test_unconditional.py b/examples/unconditional_image_generation/test_unconditional.py
index ef71da0a114c..94ea88881e52 100644
--- a/examples/unconditional_image_generation/test_unconditional.py
+++ b/examples/unconditional_image_generation/test_unconditional.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index 45b674cb5894..0cc96220b932 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -29,7 +29,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.33.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -52,6 +52,24 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
return res.expand(broadcast_shape)
+def _ensure_three_channels(tensor: torch.Tensor) -> torch.Tensor:
+ """
+ Ensure the tensor has exactly three channels (C, H, W) by repeating or truncating channels when needed.
+ """
+ if tensor.ndim == 2:
+ tensor = tensor.unsqueeze(0)
+ channels = tensor.shape[0]
+ if channels == 3:
+ return tensor
+ if channels == 1:
+ return tensor.repeat(3, 1, 1)
+ if channels == 2:
+ return torch.cat([tensor, tensor[:1]], dim=0)
+ if channels > 3:
+ return tensor[:3]
+ raise ValueError(f"Unsupported number of channels: {channels}")
+
+
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
@@ -260,6 +278,11 @@ def parse_args():
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
+ parser.add_argument(
+ "--preserve_input_precision",
+ action="store_true",
+ help="Preserve 16/32-bit image precision by avoiding 8-bit RGB conversion while still producing 3-channel tensors.",
+ )
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -453,19 +476,41 @@ def load_model_hook(models, input_dir):
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Preprocessing the datasets and DataLoaders creation.
+ spatial_augmentations = [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ ]
+
augmentations = transforms.Compose(
- [
- transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
- transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
- transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ spatial_augmentations
+ + [
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
+ precision_augmentations = transforms.Compose(
+ [
+ transforms.PILToTensor(),
+ transforms.Lambda(_ensure_three_channels),
+ transforms.ConvertImageDtype(torch.float32),
+ ]
+ + spatial_augmentations
+ + [transforms.Normalize([0.5], [0.5])]
+ )
+
def transform_images(examples):
- images = [augmentations(image.convert("RGB")) for image in examples["image"]]
- return {"input": images}
+ processed = []
+ for image in examples["image"]:
+ if not args.preserve_input_precision:
+ processed.append(augmentations(image.convert("RGB")))
+ else:
+ precise_image = image
+ if precise_image.mode == "P":
+ precise_image = precise_image.convert("RGB")
+ processed.append(precision_augmentations(precise_image))
+ return {"input": processed}
logger.info(f"Dataset size: {len(dataset)}")
diff --git a/examples/vqgan/README.md b/examples/vqgan/README.md
index 056bbcaf6748..65bb044ea548 100644
--- a/examples/vqgan/README.md
+++ b/examples/vqgan/README.md
@@ -1,5 +1,5 @@
## Training an VQGAN VAE
-VQVAEs were first introduced in [Neural Discrete Representation Learning](https://arxiv.org/abs/1711.00937) and was combined with a GAN in the paper [Taming Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2012.09841). The basic idea of a VQVAE is it's a type of a variational auto encoder with tokens as the latent space similar to tokens for LLMs. This script was adapted from a [pr to huggingface's open-muse project](https://github.com/huggingface/open-muse/pull/52) with general code following [lucidrian's implementation of the vqgan training script](https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/trainers.py) but both of these implementation follow from the [taming transformer repo](https://github.com/CompVis/taming-transformers?tab=readme-ov-file).
+VQVAEs were first introduced in [Neural Discrete Representation Learning](https://huggingface.co/papers/1711.00937) and was combined with a GAN in the paper [Taming Transformers for High-Resolution Image Synthesis](https://huggingface.co/papers/2012.09841). The basic idea of a VQVAE is it's a type of a variational auto encoder with tokens as the latent space similar to tokens for LLMs. This script was adapted from a [pr to huggingface's open-muse project](https://github.com/huggingface/open-muse/pull/52) with general code following [lucidrian's implementation of the vqgan training script](https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/trainers.py) but both of these implementation follow from the [taming transformer repo](https://github.com/CompVis/taming-transformers?tab=readme-ov-file).
Creating a training image set is [described in a different document](https://huggingface.co/docs/datasets/image_process#image-datasets).
diff --git a/examples/vqgan/test_vqgan.py b/examples/vqgan/test_vqgan.py
index aa5d4c67b642..a3c8ee1e84b1 100644
--- a/examples/vqgan/test_vqgan.py
+++ b/examples/vqgan/test_vqgan.py
@@ -24,12 +24,18 @@
import torch
from diffusers import VQModel
-from diffusers.utils.testing_utils import require_timm
+# Add parent directories to path to import from tests
sys.path.append("..")
+repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
+if repo_root not in sys.path:
+ sys.path.insert(0, repo_root)
+
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+from tests.testing_utils import require_timm # noqa
+
logging.basicConfig(level=logging.DEBUG)
@@ -177,7 +183,7 @@ def test_vqmodel_checkpointing(self):
--model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path}
--checkpointing_steps=1
- --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
+ --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--output_dir {tmpdir}
--seed=0
""".split()
@@ -262,7 +268,7 @@ def test_vqmodel_checkpointing_use_ema(self):
--model_config_name_or_path {vqmodel_config_path}
--discriminator_config_name_or_path {discriminator_config_path}
--checkpointing_steps=1
- --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
+ --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--output_dir {tmpdir}
--use_ema
--seed=0
@@ -377,7 +383,7 @@ def test_vqmodel_checkpointing_checkpoints_total_limit_removes_multiple_checkpoi
--discriminator_config_name_or_path {discriminator_config_path}
--output_dir {tmpdir}
--checkpointing_steps=2
- --resume_from_checkpoint={os.path.join(tmpdir, 'checkpoint-4')}
+ --resume_from_checkpoint={os.path.join(tmpdir, "checkpoint-4")}
--checkpoints_total_limit=2
--seed=0
""".split()
diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py
index 992722fa7a78..eeb592a3f7d9 100644
--- a/examples/vqgan/train_vqgan.py
+++ b/examples/vqgan/train_vqgan.py
@@ -50,7 +50,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.33.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -653,15 +653,15 @@ def main():
try:
# Gets the resolution of the timm transformation after centercrop
timm_centercrop_transform = timm_transform.transforms[1]
- assert isinstance(
- timm_centercrop_transform, transforms.CenterCrop
- ), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
+ assert isinstance(timm_centercrop_transform, transforms.CenterCrop), (
+ f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
+ )
timm_model_resolution = timm_centercrop_transform.size[0]
# Gets final normalization
timm_model_normalization = timm_transform.transforms[-1]
- assert isinstance(
- timm_model_normalization, transforms.Normalize
- ), f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
+ assert isinstance(timm_model_normalization, transforms.Normalize), (
+ f"Timm model {timm_model} is currently incompatible with this script. Try vgg19."
+ )
except AssertionError as e:
raise NotImplementedError(e)
# Enable flash attention if asked
diff --git a/pyproject.toml b/pyproject.toml
index 299865a1225d..fdda8a6977be 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,9 +1,12 @@
[tool.ruff]
line-length = 119
+extend-exclude = [
+ "src/diffusers/pipelines/flux2/system_messages.py",
+]
[tool.ruff.lint]
# Never enforce `E501` (line length violations).
-ignore = ["C901", "E501", "E741", "F402", "F823"]
+ignore = ["C901", "E501", "E721", "E741", "F402", "F823"]
select = ["C", "E", "F", "I", "W"]
# Ignore import violations in all `__init__.py` files.
diff --git a/scripts/convert_amused.py b/scripts/convert_amused.py
index 21be29dfdb99..ddd1bf508b6d 100644
--- a/scripts/convert_amused.py
+++ b/scripts/convert_amused.py
@@ -468,7 +468,7 @@ def make_vqvae(old_vae):
# assert (old_output == new_output).all()
print("skipping full vae equivalence check")
- print(f"vae full diff { (old_output - new_output).float().abs().sum()}")
+ print(f"vae full diff {(old_output - new_output).float().abs().sum()}")
return new_vae
diff --git a/scripts/convert_consistency_decoder.py b/scripts/convert_consistency_decoder.py
index 629c784c095a..9e289457752b 100644
--- a/scripts/convert_consistency_decoder.py
+++ b/scripts/convert_consistency_decoder.py
@@ -24,7 +24,8 @@
def _extract_into_tensor(arr, timesteps, broadcast_shape):
- # from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L895 """
+ # from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L895
+ # """
res = arr[timesteps].float()
dims_to_append = len(broadcast_shape) - len(res.shape)
return res[(...,) + (None,) * dims_to_append]
@@ -507,7 +508,9 @@ def rename_state_dict(sd, embedding):
# encode with stable diffusion vae
-pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+pipe = StableDiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
+)
pipe.vae.cuda()
# construct original decoder with jitted model
@@ -1090,7 +1093,7 @@ def new_constructor(self, **kwargs):
Encoder.__init__ = new_constructor
-vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
+vae = AutoencoderKL.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="vae")
consistency_vae = ConsistencyDecoderVAE(
encoder_args=vae.encoder.constructor_arguments,
decoder_args=unet.config,
@@ -1117,7 +1120,7 @@ def new_constructor(self, **kwargs):
print("running with diffusers pipeline")
pipe = DiffusionPipeline.from_pretrained(
- "runwayml/stable-diffusion-v1-5", vae=consistency_vae, torch_dtype=torch.float16
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", vae=consistency_vae, torch_dtype=torch.float16
)
pipe.to("cuda")
diff --git a/scripts/convert_consistency_to_diffusers.py b/scripts/convert_consistency_to_diffusers.py
index 0f8b4ddca8ef..2b918280ca05 100644
--- a/scripts/convert_consistency_to_diffusers.py
+++ b/scripts/convert_consistency_to_diffusers.py
@@ -239,7 +239,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config):
if i != len(up_block_types) - 1:
new_prefix = f"up_blocks.{i}.upsamplers.0"
- old_prefix = f"output_blocks.{current_layer-1}.1"
+ old_prefix = f"output_blocks.{current_layer - 1}.1"
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
elif layer_type == "AttnUpBlock2D":
for j in range(layers_per_block + 1):
@@ -255,7 +255,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config):
if i != len(up_block_types) - 1:
new_prefix = f"up_blocks.{i}.upsamplers.0"
- old_prefix = f"output_blocks.{current_layer-1}.2"
+ old_prefix = f"output_blocks.{current_layer - 1}.2"
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
diff --git a/scripts/convert_cosmos_to_diffusers.py b/scripts/convert_cosmos_to_diffusers.py
new file mode 100644
index 000000000000..6f6563ad641b
--- /dev/null
+++ b/scripts/convert_cosmos_to_diffusers.py
@@ -0,0 +1,506 @@
+import argparse
+import pathlib
+from typing import Any, Dict
+
+import torch
+from accelerate import init_empty_weights
+from huggingface_hub import snapshot_download
+from transformers import T5EncoderModel, T5TokenizerFast
+
+from diffusers import (
+ AutoencoderKLCosmos,
+ AutoencoderKLWan,
+ Cosmos2TextToImagePipeline,
+ Cosmos2VideoToWorldPipeline,
+ CosmosTextToWorldPipeline,
+ CosmosTransformer3DModel,
+ CosmosVideoToWorldPipeline,
+ EDMEulerScheduler,
+ FlowMatchEulerDiscreteScheduler,
+)
+
+
+def remove_keys_(key: str, state_dict: Dict[str, Any]):
+ state_dict.pop(key)
+
+
+def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
+ state_dict[new_key] = state_dict.pop(old_key)
+
+
+def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
+ block_index = int(key.split(".")[1].removeprefix("block"))
+ new_key = key
+
+ old_prefix = f"blocks.block{block_index}"
+ new_prefix = f"transformer_blocks.{block_index}"
+ new_key = new_prefix + new_key.removeprefix(old_prefix)
+
+ state_dict[new_key] = state_dict.pop(key)
+
+
+TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
+ "t_embedder.1": "time_embed.t_embedder",
+ "affline_norm": "time_embed.norm",
+ ".blocks.0.block.attn": ".attn1",
+ ".blocks.1.block.attn": ".attn2",
+ ".blocks.2.block": ".ff",
+ ".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
+ ".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
+ ".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
+ ".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
+ ".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
+ ".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
+ "to_q.0": "to_q",
+ "to_q.1": "norm_q",
+ "to_k.0": "to_k",
+ "to_k.1": "norm_k",
+ "to_v.0": "to_v",
+ "layer1": "net.0.proj",
+ "layer2": "net.2",
+ "proj.1": "proj",
+ "x_embedder": "patch_embed",
+ "extra_pos_embedder": "learnable_pos_embed",
+ "final_layer.adaLN_modulation.1": "norm_out.linear_1",
+ "final_layer.adaLN_modulation.2": "norm_out.linear_2",
+ "final_layer.linear": "proj_out",
+}
+
+TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
+ "blocks.block": rename_transformer_blocks_,
+ "logvar.0.freqs": remove_keys_,
+ "logvar.0.phases": remove_keys_,
+ "logvar.1.weight": remove_keys_,
+ "pos_embedder.seq": remove_keys_,
+}
+
+TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
+ "t_embedder.1": "time_embed.t_embedder",
+ "t_embedding_norm": "time_embed.norm",
+ "blocks": "transformer_blocks",
+ "adaln_modulation_self_attn.1": "norm1.linear_1",
+ "adaln_modulation_self_attn.2": "norm1.linear_2",
+ "adaln_modulation_cross_attn.1": "norm2.linear_1",
+ "adaln_modulation_cross_attn.2": "norm2.linear_2",
+ "adaln_modulation_mlp.1": "norm3.linear_1",
+ "adaln_modulation_mlp.2": "norm3.linear_2",
+ "self_attn": "attn1",
+ "cross_attn": "attn2",
+ "q_proj": "to_q",
+ "k_proj": "to_k",
+ "v_proj": "to_v",
+ "output_proj": "to_out.0",
+ "q_norm": "norm_q",
+ "k_norm": "norm_k",
+ "mlp.layer1": "ff.net.0.proj",
+ "mlp.layer2": "ff.net.2",
+ "x_embedder.proj.1": "patch_embed.proj",
+ "final_layer.adaln_modulation.1": "norm_out.linear_1",
+ "final_layer.adaln_modulation.2": "norm_out.linear_2",
+ "final_layer.linear": "proj_out",
+}
+
+TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
+ "accum_video_sample_counter": remove_keys_,
+ "accum_image_sample_counter": remove_keys_,
+ "accum_iteration": remove_keys_,
+ "accum_train_in_hours": remove_keys_,
+ "pos_embedder.seq": remove_keys_,
+ "pos_embedder.dim_spatial_range": remove_keys_,
+ "pos_embedder.dim_temporal_range": remove_keys_,
+ "_extra_state": remove_keys_,
+}
+
+
+TRANSFORMER_CONFIGS = {
+ "Cosmos-1.0-Diffusion-7B-Text2World": {
+ "in_channels": 16,
+ "out_channels": 16,
+ "num_attention_heads": 32,
+ "attention_head_dim": 128,
+ "num_layers": 28,
+ "mlp_ratio": 4.0,
+ "text_embed_dim": 1024,
+ "adaln_lora_dim": 256,
+ "max_size": (128, 240, 240),
+ "patch_size": (1, 2, 2),
+ "rope_scale": (2.0, 1.0, 1.0),
+ "concat_padding_mask": True,
+ "extra_pos_embed_type": "learnable",
+ },
+ "Cosmos-1.0-Diffusion-7B-Video2World": {
+ "in_channels": 16 + 1,
+ "out_channels": 16,
+ "num_attention_heads": 32,
+ "attention_head_dim": 128,
+ "num_layers": 28,
+ "mlp_ratio": 4.0,
+ "text_embed_dim": 1024,
+ "adaln_lora_dim": 256,
+ "max_size": (128, 240, 240),
+ "patch_size": (1, 2, 2),
+ "rope_scale": (2.0, 1.0, 1.0),
+ "concat_padding_mask": True,
+ "extra_pos_embed_type": "learnable",
+ },
+ "Cosmos-1.0-Diffusion-14B-Text2World": {
+ "in_channels": 16,
+ "out_channels": 16,
+ "num_attention_heads": 40,
+ "attention_head_dim": 128,
+ "num_layers": 36,
+ "mlp_ratio": 4.0,
+ "text_embed_dim": 1024,
+ "adaln_lora_dim": 256,
+ "max_size": (128, 240, 240),
+ "patch_size": (1, 2, 2),
+ "rope_scale": (2.0, 2.0, 2.0),
+ "concat_padding_mask": True,
+ "extra_pos_embed_type": "learnable",
+ },
+ "Cosmos-1.0-Diffusion-14B-Video2World": {
+ "in_channels": 16 + 1,
+ "out_channels": 16,
+ "num_attention_heads": 40,
+ "attention_head_dim": 128,
+ "num_layers": 36,
+ "mlp_ratio": 4.0,
+ "text_embed_dim": 1024,
+ "adaln_lora_dim": 256,
+ "max_size": (128, 240, 240),
+ "patch_size": (1, 2, 2),
+ "rope_scale": (2.0, 2.0, 2.0),
+ "concat_padding_mask": True,
+ "extra_pos_embed_type": "learnable",
+ },
+ "Cosmos-2.0-Diffusion-2B-Text2Image": {
+ "in_channels": 16,
+ "out_channels": 16,
+ "num_attention_heads": 16,
+ "attention_head_dim": 128,
+ "num_layers": 28,
+ "mlp_ratio": 4.0,
+ "text_embed_dim": 1024,
+ "adaln_lora_dim": 256,
+ "max_size": (128, 240, 240),
+ "patch_size": (1, 2, 2),
+ "rope_scale": (1.0, 4.0, 4.0),
+ "concat_padding_mask": True,
+ "extra_pos_embed_type": None,
+ },
+ "Cosmos-2.0-Diffusion-14B-Text2Image": {
+ "in_channels": 16,
+ "out_channels": 16,
+ "num_attention_heads": 40,
+ "attention_head_dim": 128,
+ "num_layers": 36,
+ "mlp_ratio": 4.0,
+ "text_embed_dim": 1024,
+ "adaln_lora_dim": 256,
+ "max_size": (128, 240, 240),
+ "patch_size": (1, 2, 2),
+ "rope_scale": (1.0, 4.0, 4.0),
+ "concat_padding_mask": True,
+ "extra_pos_embed_type": None,
+ },
+ "Cosmos-2.0-Diffusion-2B-Video2World": {
+ "in_channels": 16 + 1,
+ "out_channels": 16,
+ "num_attention_heads": 16,
+ "attention_head_dim": 128,
+ "num_layers": 28,
+ "mlp_ratio": 4.0,
+ "text_embed_dim": 1024,
+ "adaln_lora_dim": 256,
+ "max_size": (128, 240, 240),
+ "patch_size": (1, 2, 2),
+ "rope_scale": (1.0, 3.0, 3.0),
+ "concat_padding_mask": True,
+ "extra_pos_embed_type": None,
+ },
+ "Cosmos-2.0-Diffusion-14B-Video2World": {
+ "in_channels": 16 + 1,
+ "out_channels": 16,
+ "num_attention_heads": 40,
+ "attention_head_dim": 128,
+ "num_layers": 36,
+ "mlp_ratio": 4.0,
+ "text_embed_dim": 1024,
+ "adaln_lora_dim": 256,
+ "max_size": (128, 240, 240),
+ "patch_size": (1, 2, 2),
+ "rope_scale": (20 / 24, 2.0, 2.0),
+ "concat_padding_mask": True,
+ "extra_pos_embed_type": None,
+ },
+}
+
+VAE_KEYS_RENAME_DICT = {
+ "down.0": "down_blocks.0",
+ "down.1": "down_blocks.1",
+ "down.2": "down_blocks.2",
+ "up.0": "up_blocks.2",
+ "up.1": "up_blocks.1",
+ "up.2": "up_blocks.0",
+ ".block.": ".resnets.",
+ "downsample": "downsamplers.0",
+ "upsample": "upsamplers.0",
+ "mid.block_1": "mid_block.resnets.0",
+ "mid.attn_1.0": "mid_block.attentions.0",
+ "mid.attn_1.1": "mid_block.temp_attentions.0",
+ "mid.block_2": "mid_block.resnets.1",
+ ".q.conv3d": ".to_q",
+ ".k.conv3d": ".to_k",
+ ".v.conv3d": ".to_v",
+ ".proj_out.conv3d": ".to_out.0",
+ ".0.conv3d": ".conv_s",
+ ".1.conv3d": ".conv_t",
+ "conv1.conv3d": "conv1",
+ "conv2.conv3d": "conv2",
+ "conv3.conv3d": "conv3",
+ "nin_shortcut.conv3d": "conv_shortcut",
+ "quant_conv.conv3d": "quant_conv",
+ "post_quant_conv.conv3d": "post_quant_conv",
+}
+
+VAE_SPECIAL_KEYS_REMAP = {
+ "wavelets": remove_keys_,
+ "_arange": remove_keys_,
+ "patch_size_buffer": remove_keys_,
+}
+
+VAE_CONFIGS = {
+ "CV8x8x8-0.1": {
+ "name": "nvidia/Cosmos-0.1-Tokenizer-CV8x8x8",
+ "diffusers_config": {
+ "in_channels": 3,
+ "out_channels": 3,
+ "latent_channels": 16,
+ "encoder_block_out_channels": (128, 256, 512, 512),
+ "decode_block_out_channels": (256, 512, 512, 512),
+ "attention_resolutions": (32,),
+ "resolution": 1024,
+ "num_layers": 2,
+ "patch_size": 4,
+ "patch_type": "haar",
+ "scaling_factor": 1.0,
+ "spatial_compression_ratio": 8,
+ "temporal_compression_ratio": 8,
+ "latents_mean": None,
+ "latents_std": None,
+ },
+ },
+ "CV8x8x8-1.0": {
+ "name": "nvidia/Cosmos-1.0-Tokenizer-CV8x8x8",
+ "diffusers_config": {
+ "in_channels": 3,
+ "out_channels": 3,
+ "latent_channels": 16,
+ "encoder_block_out_channels": (128, 256, 512, 512),
+ "decode_block_out_channels": (256, 512, 512, 512),
+ "attention_resolutions": (32,),
+ "resolution": 1024,
+ "num_layers": 2,
+ "patch_size": 4,
+ "patch_type": "haar",
+ "scaling_factor": 1.0,
+ "spatial_compression_ratio": 8,
+ "temporal_compression_ratio": 8,
+ "latents_mean": None,
+ "latents_std": None,
+ },
+ },
+}
+
+
+def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
+ state_dict = saved_dict
+ if "model" in saved_dict.keys():
+ state_dict = state_dict["model"]
+ if "module" in saved_dict.keys():
+ state_dict = state_dict["module"]
+ if "state_dict" in saved_dict.keys():
+ state_dict = state_dict["state_dict"]
+ return state_dict
+
+
+def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: bool = True):
+ PREFIX_KEY = "net."
+ original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=weights_only))
+
+ if "Cosmos-1.0" in transformer_type:
+ TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
+ TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
+ elif "Cosmos-2.0" in transformer_type:
+ TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
+ TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
+ else:
+ assert False
+
+ with init_empty_weights():
+ config = TRANSFORMER_CONFIGS[transformer_type]
+ transformer = CosmosTransformer3DModel(**config)
+
+ for key in list(original_state_dict.keys()):
+ new_key = key[:]
+ if new_key.startswith(PREFIX_KEY):
+ new_key = new_key.removeprefix(PREFIX_KEY)
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_(original_state_dict, key, new_key)
+
+ for key in list(original_state_dict.keys()):
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, original_state_dict)
+
+ transformer.load_state_dict(original_state_dict, strict=True, assign=True)
+ return transformer
+
+
+def convert_vae(vae_type: str):
+ model_name = VAE_CONFIGS[vae_type]["name"]
+ snapshot_directory = snapshot_download(model_name, repo_type="model")
+ directory = pathlib.Path(snapshot_directory)
+
+ autoencoder_file = directory / "autoencoder.jit"
+ mean_std_file = directory / "mean_std.pt"
+
+ original_state_dict = torch.jit.load(autoencoder_file.as_posix()).state_dict()
+ if mean_std_file.exists():
+ mean_std = torch.load(mean_std_file, map_location="cpu", weights_only=True)
+ else:
+ mean_std = (None, None)
+
+ config = VAE_CONFIGS[vae_type]["diffusers_config"]
+ config.update(
+ {
+ "latents_mean": mean_std[0].detach().cpu().numpy().tolist(),
+ "latents_std": mean_std[1].detach().cpu().numpy().tolist(),
+ }
+ )
+ vae = AutoencoderKLCosmos(**config)
+
+ for key in list(original_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_(original_state_dict, key, new_key)
+
+ for key in list(original_state_dict.keys()):
+ for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, original_state_dict)
+
+ vae.load_state_dict(original_state_dict, strict=True, assign=True)
+ return vae
+
+
+def save_pipeline_cosmos_1_0(args, transformer, vae):
+ text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16)
+ tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
+ # The original code initializes EDM config with sigma_min=0.0002, but does not make use of it anywhere directly.
+ # So, the sigma_min values that is used is the default value of 0.002.
+ scheduler = EDMEulerScheduler(
+ sigma_min=0.002,
+ sigma_max=80,
+ sigma_data=0.5,
+ sigma_schedule="karras",
+ num_train_timesteps=1000,
+ prediction_type="epsilon",
+ rho=7.0,
+ final_sigmas_type="sigma_min",
+ )
+
+ pipe_cls = CosmosTextToWorldPipeline if "Text2World" in args.transformer_type else CosmosVideoToWorldPipeline
+ pipe = pipe_cls(
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ vae=vae,
+ scheduler=scheduler,
+ safety_checker=lambda *args, **kwargs: None,
+ )
+ pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+
+
+def save_pipeline_cosmos_2_0(args, transformer, vae):
+ text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.bfloat16)
+ tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
+
+ scheduler = FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True)
+
+ pipe_cls = Cosmos2TextToImagePipeline if "Text2Image" in args.transformer_type else Cosmos2VideoToWorldPipeline
+ pipe = pipe_cls(
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ vae=vae,
+ scheduler=scheduler,
+ safety_checker=lambda *args, **kwargs: None,
+ )
+ pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
+ parser.add_argument(
+ "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
+ )
+ parser.add_argument(
+ "--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
+ )
+ parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
+ parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
+ parser.add_argument("--save_pipeline", action="store_true")
+ parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
+ parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
+ return parser.parse_args()
+
+
+DTYPE_MAPPING = {
+ "fp32": torch.float32,
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+}
+
+
+if __name__ == "__main__":
+ args = get_args()
+
+ transformer = None
+ dtype = DTYPE_MAPPING[args.dtype]
+
+ if args.save_pipeline:
+ assert args.transformer_ckpt_path is not None
+ assert args.vae_type is not None
+ assert args.text_encoder_path is not None
+ assert args.tokenizer_path is not None
+
+ if args.transformer_ckpt_path is not None:
+ weights_only = "Cosmos-1.0" in args.transformer_type
+ transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path, weights_only)
+ transformer = transformer.to(dtype=dtype)
+ if not args.save_pipeline:
+ transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+
+ if args.vae_type is not None:
+ if "Cosmos-1.0" in args.transformer_type:
+ vae = convert_vae(args.vae_type)
+ else:
+ vae = AutoencoderKLWan.from_pretrained(
+ "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
+ )
+ if not args.save_pipeline:
+ vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+
+ if args.save_pipeline:
+ if "Cosmos-1.0" in args.transformer_type:
+ save_pipeline_cosmos_1_0(args, transformer, vae)
+ elif "Cosmos-2.0" in args.transformer_type:
+ save_pipeline_cosmos_2_0(args, transformer, vae)
+ else:
+ assert False
diff --git a/scripts/convert_dance_diffusion_to_diffusers.py b/scripts/convert_dance_diffusion_to_diffusers.py
index ce69bfe2bfc8..e269a49070cc 100755
--- a/scripts/convert_dance_diffusion_to_diffusers.py
+++ b/scripts/convert_dance_diffusion_to_diffusers.py
@@ -11,6 +11,7 @@
from torch import nn
from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
+from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
MODELS_MAP = {
@@ -74,7 +75,7 @@ def __init__(self, global_args):
def download(model_name):
url = MODELS_MAP[model_name]["url"]
- r = requests.get(url, stream=True)
+ r = requests.get(url, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT)
local_filename = f"./{model_name}.ckpt"
with open(local_filename, "wb") as fp:
@@ -260,9 +261,9 @@ def main(args):
model_name = args.model_path.split("/")[-1].split(".")[0]
if not os.path.isfile(args.model_path):
- assert (
- model_name == args.model_path
- ), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
+ assert model_name == args.model_path, (
+ f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
+ )
args.model_path = download(model_name)
sample_rate = MODELS_MAP[model_name]["sample_rate"]
@@ -289,9 +290,9 @@ def main(args):
assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}"
for key, value in renamed_state_dict.items():
- assert (
- diffusers_state_dict[key].squeeze().shape == value.squeeze().shape
- ), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
+ assert diffusers_state_dict[key].squeeze().shape == value.squeeze().shape, (
+ f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
+ )
if key == "time_proj.weight":
value = value.squeeze()
diff --git a/scripts/convert_diffusers_to_original_sdxl.py b/scripts/convert_diffusers_to_original_sdxl.py
index 648d0376f72e..1aa792b3f06a 100644
--- a/scripts/convert_diffusers_to_original_sdxl.py
+++ b/scripts/convert_diffusers_to_original_sdxl.py
@@ -52,18 +52,18 @@
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
- sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+ sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i > 0:
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
- sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+ sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(4):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
- sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+ sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i < 2:
@@ -75,12 +75,12 @@
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
- sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+ sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
+ sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv."))
@@ -89,7 +89,7 @@
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
- sd_mid_res_prefix = f"middle_block.{2*j}."
+ sd_mid_res_prefix = f"middle_block.{2 * j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
@@ -137,20 +137,20 @@ def convert_unet_state_dict(unet_state_dict):
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"up.{3-i}.upsample."
+ sd_upsample_prefix = f"up.{3 - i}.upsample."
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
# up_blocks have three resnets
# also, up blocks in hf are numbered in reverse from sd
for j in range(3):
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
- sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
+ sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{i}."
- sd_mid_res_prefix = f"mid.block_{i+1}."
+ sd_mid_res_prefix = f"mid.block_{i + 1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py
index d1b7df070c43..049dda7d42a7 100644
--- a/scripts/convert_diffusers_to_original_stable_diffusion.py
+++ b/scripts/convert_diffusers_to_original_stable_diffusion.py
@@ -47,36 +47,36 @@
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
- sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+ sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
- sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+ sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
- sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+ sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
if i > 0:
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
- sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+ sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
- sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+ sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
+ sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
@@ -85,7 +85,7 @@
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
- sd_mid_res_prefix = f"middle_block.{2*j}."
+ sd_mid_res_prefix = f"middle_block.{2 * j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
@@ -133,20 +133,20 @@ def convert_unet_state_dict(unet_state_dict):
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"up.{3-i}.upsample."
+ sd_upsample_prefix = f"up.{3 - i}.upsample."
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
# up_blocks have three resnets
# also, up blocks in hf are numbered in reverse from sd
for j in range(3):
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
- sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
+ sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{i}."
- sd_mid_res_prefix = f"mid.block_{i+1}."
+ sd_mid_res_prefix = f"mid.block_{i + 1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
diff --git a/scripts/convert_flux2_to_diffusers.py b/scripts/convert_flux2_to_diffusers.py
new file mode 100644
index 000000000000..2973913fa215
--- /dev/null
+++ b/scripts/convert_flux2_to_diffusers.py
@@ -0,0 +1,475 @@
+import argparse
+from contextlib import nullcontext
+from typing import Any, Dict, Tuple
+
+import safetensors.torch
+import torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download
+from transformers import AutoProcessor, GenerationConfig, Mistral3ForConditionalGeneration
+
+from diffusers import AutoencoderKLFlux2, FlowMatchEulerDiscreteScheduler, Flux2Pipeline, Flux2Transformer2DModel
+from diffusers.utils.import_utils import is_accelerate_available
+
+
+"""
+# VAE
+
+python scripts/convert_flux2_to_diffusers.py \
+--original_state_dict_repo_id "diffusers-internal-dev/new-model-image" \
+--vae_filename "flux2-vae.sft" \
+--output_path "/raid/yiyi/dummy-flux2-diffusers" \
+--vae
+
+# DiT
+
+python scripts/convert_flux2_to_diffusers.py \
+ --original_state_dict_repo_id diffusers-internal-dev/new-model-image \
+ --dit_filename flux-dev-dummy.sft \
+ --dit \
+ --output_path .
+
+# Full pipe
+
+python scripts/convert_flux2_to_diffusers.py \
+ --original_state_dict_repo_id diffusers-internal-dev/new-model-image \
+ --dit_filename flux-dev-dummy.sft \
+ --vae_filename "flux2-vae.sft" \
+ --dit --vae --full_pipe \
+ --output_path .
+"""
+
+CTX = init_empty_weights if is_accelerate_available() else nullcontext
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
+parser.add_argument("--vae_filename", default="flux2-vae.sft", type=str)
+parser.add_argument("--dit_filename", default="flux-dev-dummy.sft", type=str)
+parser.add_argument("--vae", action="store_true")
+parser.add_argument("--dit", action="store_true")
+parser.add_argument("--vae_dtype", type=str, default="fp32")
+parser.add_argument("--dit_dtype", type=str, default="bf16")
+parser.add_argument("--checkpoint_path", default=None, type=str)
+parser.add_argument("--full_pipe", action="store_true")
+parser.add_argument("--output_path", type=str)
+
+args = parser.parse_args()
+
+
+def load_original_checkpoint(args, filename):
+ if args.original_state_dict_repo_id is not None:
+ ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
+ elif args.checkpoint_path is not None:
+ ckpt_path = args.checkpoint_path
+ else:
+ raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
+
+ original_state_dict = safetensors.torch.load_file(ckpt_path)
+ return original_state_dict
+
+
+DIFFUSERS_VAE_TO_FLUX2_MAPPING = {
+ "encoder.conv_in.weight": "encoder.conv_in.weight",
+ "encoder.conv_in.bias": "encoder.conv_in.bias",
+ "encoder.conv_out.weight": "encoder.conv_out.weight",
+ "encoder.conv_out.bias": "encoder.conv_out.bias",
+ "encoder.conv_norm_out.weight": "encoder.norm_out.weight",
+ "encoder.conv_norm_out.bias": "encoder.norm_out.bias",
+ "decoder.conv_in.weight": "decoder.conv_in.weight",
+ "decoder.conv_in.bias": "decoder.conv_in.bias",
+ "decoder.conv_out.weight": "decoder.conv_out.weight",
+ "decoder.conv_out.bias": "decoder.conv_out.bias",
+ "decoder.conv_norm_out.weight": "decoder.norm_out.weight",
+ "decoder.conv_norm_out.bias": "decoder.norm_out.bias",
+ "quant_conv.weight": "encoder.quant_conv.weight",
+ "quant_conv.bias": "encoder.quant_conv.bias",
+ "post_quant_conv.weight": "decoder.post_quant_conv.weight",
+ "post_quant_conv.bias": "decoder.post_quant_conv.bias",
+ "bn.running_mean": "bn.running_mean",
+ "bn.running_var": "bn.running_var",
+}
+
+
+# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
+def conv_attn_to_linear(checkpoint):
+ keys = list(checkpoint.keys())
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
+ for key in keys:
+ if ".".join(key.split(".")[-2:]) in attn_keys:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
+ elif "proj_attn.weight" in key:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0]
+
+
+def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
+ for ldm_key in keys:
+ diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
+ new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
+
+
+def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
+ for ldm_key in keys:
+ diffusers_key = (
+ ldm_key.replace(mapping["old"], mapping["new"])
+ .replace("norm.weight", "group_norm.weight")
+ .replace("norm.bias", "group_norm.bias")
+ .replace("q.weight", "to_q.weight")
+ .replace("q.bias", "to_q.bias")
+ .replace("k.weight", "to_k.weight")
+ .replace("k.bias", "to_k.bias")
+ .replace("v.weight", "to_v.weight")
+ .replace("v.bias", "to_v.bias")
+ .replace("proj_out.weight", "to_out.0.weight")
+ .replace("proj_out.bias", "to_out.0.bias")
+ )
+ new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
+
+ # proj_attn.weight has to be converted from conv 1D to linear
+ shape = new_checkpoint[diffusers_key].shape
+
+ if len(shape) == 3:
+ new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0]
+ elif len(shape) == 4:
+ new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0]
+
+
+def convert_flux2_vae_checkpoint_to_diffusers(vae_state_dict, config):
+ new_checkpoint = {}
+ for diffusers_key, ldm_key in DIFFUSERS_VAE_TO_FLUX2_MAPPING.items():
+ if ldm_key not in vae_state_dict:
+ continue
+ new_checkpoint[diffusers_key] = vae_state_dict[ldm_key]
+
+ # Retrieves the keys for the encoder down blocks only
+ num_down_blocks = len(config["down_block_types"])
+ down_blocks = {
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+ }
+
+ for i in range(num_down_blocks):
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
+ update_vae_resnet_ldm_to_diffusers(
+ resnets,
+ new_checkpoint,
+ vae_state_dict,
+ mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"},
+ )
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.get(
+ f"encoder.down.{i}.downsample.conv.weight"
+ )
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.get(
+ f"encoder.down.{i}.downsample.conv.bias"
+ )
+
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
+ update_vae_resnet_ldm_to_diffusers(
+ resnets,
+ new_checkpoint,
+ vae_state_dict,
+ mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
+ )
+
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
+ update_vae_attentions_ldm_to_diffusers(
+ mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ )
+
+ # Retrieves the keys for the decoder up blocks only
+ num_up_blocks = len(config["up_block_types"])
+ up_blocks = {
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
+ }
+
+ for i in range(num_up_blocks):
+ block_id = num_up_blocks - 1 - i
+ resnets = [
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
+ ]
+ update_vae_resnet_ldm_to_diffusers(
+ resnets,
+ new_checkpoint,
+ vae_state_dict,
+ mapping={"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"},
+ )
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.weight"
+ ]
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.bias"
+ ]
+
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
+ update_vae_resnet_ldm_to_diffusers(
+ resnets,
+ new_checkpoint,
+ vae_state_dict,
+ mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
+ )
+
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
+ update_vae_attentions_ldm_to_diffusers(
+ mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ )
+ conv_attn_to_linear(new_checkpoint)
+
+ return new_checkpoint
+
+
+FLUX2_TRANSFORMER_KEYS_RENAME_DICT = {
+ # Image and text input projections
+ "img_in": "x_embedder",
+ "txt_in": "context_embedder",
+ # Timestep and guidance embeddings
+ "time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
+ "time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
+ "guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1",
+ "guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2",
+ # Modulation parameters
+ "double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
+ "double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
+ "single_stream_modulation.lin": "single_stream_modulation.linear",
+ # Final output layer
+ # "final_layer.adaLN_modulation.1": "norm_out.linear", # Handle separately since we need to swap mod params
+ "final_layer.linear": "proj_out",
+}
+
+
+FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP = {
+ "final_layer.adaLN_modulation.1": "norm_out.linear",
+}
+
+
+FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = {
+ # Handle fused QKV projections separately as we need to break into Q, K, V projections
+ "img_attn.norm.query_norm": "attn.norm_q",
+ "img_attn.norm.key_norm": "attn.norm_k",
+ "img_attn.proj": "attn.to_out.0",
+ "img_mlp.0": "ff.linear_in",
+ "img_mlp.2": "ff.linear_out",
+ "txt_attn.norm.query_norm": "attn.norm_added_q",
+ "txt_attn.norm.key_norm": "attn.norm_added_k",
+ "txt_attn.proj": "attn.to_add_out",
+ "txt_mlp.0": "ff_context.linear_in",
+ "txt_mlp.2": "ff_context.linear_out",
+}
+
+
+FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = {
+ "linear1": "attn.to_qkv_mlp_proj",
+ "linear2": "attn.to_out",
+ "norm.query_norm": "attn.norm_q",
+ "norm.key_norm": "attn.norm_k",
+}
+
+
+# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
+# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use
+# diffusers implementation
+def swap_scale_shift(weight):
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ return new_weight
+
+
+def convert_ada_layer_norm_weights(key: str, state_dict: Dict[str, Any]) -> None:
+ # Skip if not a weight
+ if ".weight" not in key:
+ return
+
+ # If adaLN_modulation is in the key, swap scale and shift parameters
+ # Original implementation is (shift, scale); diffusers implementation is (scale, shift)
+ if "adaLN_modulation" in key:
+ key_without_param_type, param_type = key.rsplit(".", maxsplit=1)
+ # Assume all such keys are in the AdaLayerNorm key map
+ new_key_without_param_type = FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP[key_without_param_type]
+ new_key = ".".join([new_key_without_param_type, param_type])
+
+ swapped_weight = swap_scale_shift(state_dict.pop(key))
+ state_dict[new_key] = swapped_weight
+ return
+
+
+def convert_flux2_double_stream_blocks(key: str, state_dict: Dict[str, Any]) -> None:
+ # Skip if not a weight, bias, or scale
+ if ".weight" not in key and ".bias" not in key and ".scale" not in key:
+ return
+
+ new_prefix = "transformer_blocks"
+ if "double_blocks." in key:
+ parts = key.split(".")
+ block_idx = parts[1]
+ modality_block_name = parts[2] # img_attn, img_mlp, txt_attn, txt_mlp
+ within_block_name = ".".join(parts[2:-1])
+ param_type = parts[-1]
+
+ if param_type == "scale":
+ param_type = "weight"
+
+ if "qkv" in within_block_name:
+ fused_qkv_weight = state_dict.pop(key)
+ to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
+ if "img" in modality_block_name:
+ # double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v}
+ to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
+ new_q_name = "attn.to_q"
+ new_k_name = "attn.to_k"
+ new_v_name = "attn.to_v"
+ elif "txt" in modality_block_name:
+ # double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj}
+ to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
+ new_q_name = "attn.add_q_proj"
+ new_k_name = "attn.add_k_proj"
+ new_v_name = "attn.add_v_proj"
+ new_q_key = ".".join([new_prefix, block_idx, new_q_name, param_type])
+ new_k_key = ".".join([new_prefix, block_idx, new_k_name, param_type])
+ new_v_key = ".".join([new_prefix, block_idx, new_v_name, param_type])
+ state_dict[new_q_key] = to_q_weight
+ state_dict[new_k_key] = to_k_weight
+ state_dict[new_v_key] = to_v_weight
+ else:
+ new_within_block_name = FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP[within_block_name]
+ new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
+
+ param = state_dict.pop(key)
+ state_dict[new_key] = param
+ return
+
+
+def convert_flux2_single_stream_blocks(key: str, state_dict: Dict[str, Any]) -> None:
+ # Skip if not a weight, bias, or scale
+ if ".weight" not in key and ".bias" not in key and ".scale" not in key:
+ return
+
+ # Mapping:
+ # - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj
+ # - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out
+ # - single_blocks.{N}.norm.query_norm.scale --> single_transformer_blocks.{N}.attn.norm_q.weight
+ # - single_blocks.{N}.norm.key_norm.scale --> single_transformer_blocks.{N}.attn.norm_k.weight
+ new_prefix = "single_transformer_blocks"
+ if "single_blocks." in key:
+ parts = key.split(".")
+ block_idx = parts[1]
+ within_block_name = ".".join(parts[2:-1])
+ param_type = parts[-1]
+
+ if param_type == "scale":
+ param_type = "weight"
+
+ new_within_block_name = FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP[within_block_name]
+ new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
+
+ param = state_dict.pop(key)
+ state_dict[new_key] = param
+ return
+
+
+TRANSFORMER_SPECIAL_KEYS_REMAP = {
+ "adaLN_modulation": convert_ada_layer_norm_weights,
+ "double_blocks": convert_flux2_double_stream_blocks,
+ "single_blocks": convert_flux2_single_stream_blocks,
+}
+
+
+def update_state_dict(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
+ state_dict[new_key] = state_dict.pop(old_key)
+
+
+def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
+ if model_type == "test" or model_type == "dummy-flux2":
+ config = {
+ "model_id": "diffusers-internal-dev/dummy-flux2",
+ "diffusers_config": {
+ "patch_size": 1,
+ "in_channels": 128,
+ "num_layers": 8,
+ "num_single_layers": 48,
+ "attention_head_dim": 128,
+ "num_attention_heads": 48,
+ "joint_attention_dim": 15360,
+ "timestep_guidance_channels": 256,
+ "mlp_ratio": 3.0,
+ "axes_dims_rope": (32, 32, 32, 32),
+ "rope_theta": 2000,
+ "eps": 1e-6,
+ },
+ }
+ rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT
+ special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP
+ return config, rename_dict, special_keys_remap
+
+
+def convert_flux2_transformer_to_diffusers(original_state_dict: Dict[str, torch.Tensor], model_type: str):
+ config, rename_dict, special_keys_remap = get_flux2_transformer_config(model_type)
+
+ diffusers_config = config["diffusers_config"]
+
+ with init_empty_weights():
+ transformer = Flux2Transformer2DModel.from_config(diffusers_config)
+
+ # Handle official code --> diffusers key remapping via the remap dict
+ for key in list(original_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in rename_dict.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict(original_state_dict, key, new_key)
+
+ # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
+ # special_keys_remap
+ for key in list(original_state_dict.keys()):
+ for special_key, handler_fn_inplace in special_keys_remap.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, original_state_dict)
+
+ transformer.load_state_dict(original_state_dict, strict=True, assign=True)
+ return transformer
+
+
+def main(args):
+ if args.vae:
+ original_vae_ckpt = load_original_checkpoint(args, filename=args.vae_filename)
+ vae = AutoencoderKLFlux2()
+ converted_vae_state_dict = convert_flux2_vae_checkpoint_to_diffusers(original_vae_ckpt, vae.config)
+ vae.load_state_dict(converted_vae_state_dict, strict=True)
+ if not args.full_pipe:
+ vae_dtype = torch.bfloat16 if args.vae_dtype == "bf16" else torch.float32
+ vae.to(vae_dtype).save_pretrained(f"{args.output_path}/vae")
+
+ if args.dit:
+ original_dit_ckpt = load_original_checkpoint(args, filename=args.dit_filename)
+ transformer = convert_flux2_transformer_to_diffusers(original_dit_ckpt, "test")
+ if not args.full_pipe:
+ dit_dtype = torch.bfloat16 if args.dit_dtype == "bf16" else torch.float32
+ transformer.to(dit_dtype).save_pretrained(f"{args.output_path}/transformer")
+
+ if args.full_pipe:
+ tokenizer_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
+ text_encoder_id = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
+ generate_config = GenerationConfig.from_pretrained(text_encoder_id)
+ generate_config.do_sample = True
+ text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
+ text_encoder_id, generation_config=generate_config, torch_dtype=torch.bfloat16
+ )
+ tokenizer = AutoProcessor.from_pretrained(tokenizer_id)
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ "black-forest-labs/FLUX.1-dev", subfolder="scheduler"
+ )
+
+ pipe = Flux2Pipeline(
+ vae=vae, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
+ )
+ pipe.save_pretrained(args.output_path)
+
+
+if __name__ == "__main__":
+ main(args)
diff --git a/scripts/convert_flux_to_diffusers.py b/scripts/convert_flux_to_diffusers.py
index fccac70dd855..ec31d842d4db 100644
--- a/scripts/convert_flux_to_diffusers.py
+++ b/scripts/convert_flux_to_diffusers.py
@@ -220,7 +220,7 @@ def convert_flux_transformer_checkpoint_to_diffusers(
f"double_blocks.{i}.txt_attn.proj.bias"
)
- # single transfomer blocks
+ # single transformer blocks
for i in range(num_single_layers):
block_prefix = f"single_transformer_blocks.{i}."
# norm.linear <- single_blocks.0.modulation.lin
diff --git a/scripts/convert_hunyuan_image_to_diffusers.py b/scripts/convert_hunyuan_image_to_diffusers.py
new file mode 100644
index 000000000000..c41e934cc3d4
--- /dev/null
+++ b/scripts/convert_hunyuan_image_to_diffusers.py
@@ -0,0 +1,1044 @@
+import argparse
+import logging
+
+import torch
+from safetensors import safe_open
+
+from diffusers import AutoencoderKLHunyuanImage, AutoencoderKLHunyuanImageRefiner, HunyuanImageTransformer2DModel
+
+
+logger = logging.getLogger(__name__) # pylint: disable=invalid-name
+
+
+"""
+Usage examples
+==============
+
+python scripts/convert_hunyuan_image_to_diffusers.py \
+ --model_type hunyuanimage2.1 \
+ --transformer_checkpoint_path "/raid/yiyi/HunyuanImage-2.1/ckpts/dit/hunyuanimage2.1.safetensors" \
+ --vae_checkpoint_path "HunyuanImage-2.1/ckpts/vae/vae_2_1/pytorch_model.ckpt" \
+ --output_path "/raid/yiyi/test-hy21-diffusers" \
+ --dtype fp32
+
+python scripts/convert_hunyuan_image_to_diffusers.py \
+ --model_type hunyuanimage2.1-distilled \
+ --transformer_checkpoint_path "/raid/yiyi/HunyuanImage-2.1/ckpts/dit/hunyuanimage2.1-distilled.safetensors" \
+ --vae_checkpoint_path "/raid/yiyi/HunyuanImage-2.1/ckpts/vae/vae_2_1/pytorch_model.ckpt" \
+ --output_path "/raid/yiyi/test-hy21-distilled-diffusers" \
+ --dtype fp32
+
+
+python scripts/convert_hunyuan_image_to_diffusers.py \
+ --model_type hunyuanimage-refiner \
+ --transformer_checkpoint_path "/raid/yiyi/HunyuanImage-2.1/ckpts/dit/hunyuanimage-refiner.safetensors" \
+ --vae_checkpoint_path "/raid/yiyi/HunyuanImage-2.1/ckpts/vae/vae_refiner/pytorch_model.pt" \
+ --output_path "/raid/yiyi/test-hy2-refiner-diffusers" \
+ --dtype fp32
+"""
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ "--model_type", type=str, default=None
+) # hunyuanimage2.1, hunyuanimage2.1-distilled, hunyuanimage-refiner
+parser.add_argument("--transformer_checkpoint_path", default=None, type=str) # ckpts/dit/hunyuanimage2.1.safetensors
+parser.add_argument("--vae_checkpoint_path", default=None, type=str) # ckpts/vae/vae_2_1/pytorch_model.ckpt
+parser.add_argument("--output_path", type=str)
+parser.add_argument("--dtype", type=str, default="fp32")
+
+args = parser.parse_args()
+dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32
+
+
+# copied from https://github.com/Tencent-Hunyuan/HunyuanImage-2.1/hyimage/models/hunyuan/modules/hunyuanimage_dit.py#L21
+def convert_hunyuan_dict_for_tensor_parallel(state_dict):
+ """
+ Convert a Hunyuan model state dict to be compatible with tensor parallel architectures.
+
+ Args:
+ state_dict: Original state dict
+
+ Returns:
+ new_dict: Converted state dict
+ """
+ new_dict = {}
+ for k, w in state_dict.items():
+ if k.startswith("double_blocks") and "attn_qkv.weight" in k:
+ hidden_size = w.shape[1]
+ k1 = k.replace("attn_qkv.weight", "attn_q.weight")
+ w1 = w[:hidden_size, :]
+ new_dict[k1] = w1
+ k2 = k.replace("attn_qkv.weight", "attn_k.weight")
+ w2 = w[hidden_size : 2 * hidden_size, :]
+ new_dict[k2] = w2
+ k3 = k.replace("attn_qkv.weight", "attn_v.weight")
+ w3 = w[-hidden_size:, :]
+ new_dict[k3] = w3
+ elif k.startswith("double_blocks") and "attn_qkv.bias" in k:
+ hidden_size = w.shape[0] // 3
+ k1 = k.replace("attn_qkv.bias", "attn_q.bias")
+ w1 = w[:hidden_size]
+ new_dict[k1] = w1
+ k2 = k.replace("attn_qkv.bias", "attn_k.bias")
+ w2 = w[hidden_size : 2 * hidden_size]
+ new_dict[k2] = w2
+ k3 = k.replace("attn_qkv.bias", "attn_v.bias")
+ w3 = w[-hidden_size:]
+ new_dict[k3] = w3
+ elif k.startswith("single_blocks") and "linear1" in k:
+ hidden_size = state_dict[k.replace("linear1", "linear2")].shape[0]
+ k1 = k.replace("linear1", "linear1_q")
+ w1 = w[:hidden_size]
+ new_dict[k1] = w1
+ k2 = k.replace("linear1", "linear1_k")
+ w2 = w[hidden_size : 2 * hidden_size]
+ new_dict[k2] = w2
+ k3 = k.replace("linear1", "linear1_v")
+ w3 = w[2 * hidden_size : 3 * hidden_size]
+ new_dict[k3] = w3
+ k4 = k.replace("linear1", "linear1_mlp")
+ w4 = w[3 * hidden_size :]
+ new_dict[k4] = w4
+ elif k.startswith("single_blocks") and "linear2" in k:
+ k1 = k.replace("linear2", "linear2.fc")
+ new_dict[k1] = w
+ else:
+ new_dict[k] = w
+ return new_dict
+
+
+def load_original_vae_checkpoint(args):
+ # "ckpts/vae/vae_2_1/pytorch_model.ckpt"
+ state_dict = torch.load(args.vae_checkpoint_path)
+
+ if "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+ vae_state_dict = {}
+ for k, v in state_dict.items():
+ if k.startswith("vae."):
+ vae_state_dict[k.replace("vae.", "")] = v
+
+ for k, v in vae_state_dict.items():
+ if "weight" in k:
+ if len(v.shape) == 5 and v.shape[2] == 1:
+ vae_state_dict[k] = v.squeeze(2)
+ else:
+ vae_state_dict[k] = v
+ else:
+ vae_state_dict[k] = v
+ return vae_state_dict
+
+
+def load_original_refiner_vae_checkpoint(args):
+ # "ckpts/vae/vae_refiner/pytorch_model.pt"
+ state_dict = torch.load(args.vae_checkpoint_path)
+
+ if "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+ vae_state_dict = {}
+ for k, v in state_dict.items():
+ if k.startswith("vae."):
+ vae_state_dict[k.replace("vae.", "")] = v
+ return vae_state_dict
+
+
+def load_original_transformer_checkpoint(args):
+ # ckpts/dit/hunyuanimage-refiner.safetensors"
+ # ckpts/dit/hunyuanimage2.1.safetensors"
+ state_dict = {}
+ with safe_open(args.transformer_checkpoint_path, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ state_dict[key] = f.get_tensor(key)
+ if args.model_type == "hunyuanimage-2.1":
+ state_dict = convert_hunyuan_dict_for_tensor_parallel(state_dict)
+ return state_dict
+
+
+def convert_hunyuan_image_transformer_checkpoint_to_diffusers(
+ original_state_dict, use_byt5=True, guidance_distilled=False, use_meanflow=False
+):
+ converted_state_dict = {}
+
+ # 1. byt5_in -> context_embedder_2
+ if use_byt5:
+ converted_state_dict["context_embedder_2.norm.weight"] = original_state_dict.pop("byt5_in.layernorm.weight")
+ converted_state_dict["context_embedder_2.norm.bias"] = original_state_dict.pop("byt5_in.layernorm.bias")
+ converted_state_dict["context_embedder_2.linear_1.weight"] = original_state_dict.pop("byt5_in.fc1.weight")
+ converted_state_dict["context_embedder_2.linear_1.bias"] = original_state_dict.pop("byt5_in.fc1.bias")
+ converted_state_dict["context_embedder_2.linear_2.weight"] = original_state_dict.pop("byt5_in.fc2.weight")
+ converted_state_dict["context_embedder_2.linear_2.bias"] = original_state_dict.pop("byt5_in.fc2.bias")
+ converted_state_dict["context_embedder_2.linear_3.weight"] = original_state_dict.pop("byt5_in.fc3.weight")
+ converted_state_dict["context_embedder_2.linear_3.bias"] = original_state_dict.pop("byt5_in.fc3.bias")
+
+ # 2. img_in -> x_embedder
+ converted_state_dict["x_embedder.proj.weight"] = original_state_dict.pop("img_in.proj.weight")
+ converted_state_dict["x_embedder.proj.bias"] = original_state_dict.pop("img_in.proj.bias")
+
+ # 3. txt_in -> context_embedder (complex mapping)
+ # txt_in.input_embedder -> context_embedder.proj_in
+ converted_state_dict["context_embedder.proj_in.weight"] = original_state_dict.pop("txt_in.input_embedder.weight")
+ converted_state_dict["context_embedder.proj_in.bias"] = original_state_dict.pop("txt_in.input_embedder.bias")
+
+ # txt_in.t_embedder -> context_embedder.time_text_embed.timestep_embedder
+ converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.weight"] = (
+ original_state_dict.pop("txt_in.t_embedder.mlp.0.weight")
+ )
+ converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
+ "txt_in.t_embedder.mlp.0.bias"
+ )
+ converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.weight"] = (
+ original_state_dict.pop("txt_in.t_embedder.mlp.2.weight")
+ )
+ converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
+ "txt_in.t_embedder.mlp.2.bias"
+ )
+
+ # txt_in.c_embedder -> context_embedder.time_text_embed.text_embedder
+ converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop(
+ "txt_in.c_embedder.linear_1.weight"
+ )
+ converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop(
+ "txt_in.c_embedder.linear_1.bias"
+ )
+ converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop(
+ "txt_in.c_embedder.linear_2.weight"
+ )
+ converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop(
+ "txt_in.c_embedder.linear_2.bias"
+ )
+
+ # txt_in.individual_token_refiner -> context_embedder.token_refiner
+ for i in range(2): # 2 refiner blocks
+ block_prefix = f"context_embedder.token_refiner.refiner_blocks.{i}."
+ # norm1
+ converted_state_dict[f"{block_prefix}norm1.weight"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.norm1.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm1.bias"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.norm1.bias"
+ )
+ # norm2
+ converted_state_dict[f"{block_prefix}norm2.weight"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.norm2.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm2.bias"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.norm2.bias"
+ )
+
+ # Split QKV
+ qkv_weight = original_state_dict.pop(f"txt_in.individual_token_refiner.blocks.{i}.self_attn_qkv.weight")
+ qkv_bias = original_state_dict.pop(f"txt_in.individual_token_refiner.blocks.{i}.self_attn_qkv.bias")
+ q_weight, k_weight, v_weight = torch.chunk(qkv_weight, 3, dim=0)
+ q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0)
+
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = q_weight
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = q_bias
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = k_weight
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = k_bias
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = v_weight
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = v_bias
+
+ # attn projection
+ converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.self_attn_proj.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.self_attn_proj.bias"
+ )
+
+ # MLP
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.mlp.fc1.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.mlp.fc1.bias"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.mlp.fc2.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.mlp.fc2.bias"
+ )
+
+ # norm_out
+ converted_state_dict[f"{block_prefix}norm_out.linear.weight"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.adaLN_modulation.1.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm_out.linear.bias"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.adaLN_modulation.1.bias"
+ )
+
+ # 4. time_in -> time_text_embed.timestep_embedder
+ converted_state_dict["time_guidance_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
+ "time_in.mlp.0.weight"
+ )
+ converted_state_dict["time_guidance_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
+ "time_in.mlp.0.bias"
+ )
+ converted_state_dict["time_guidance_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
+ "time_in.mlp.2.weight"
+ )
+ converted_state_dict["time_guidance_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
+ "time_in.mlp.2.bias"
+ )
+
+ # time_r_in -> time_guidance_embed.timestep_r_embedder
+ if use_meanflow:
+ converted_state_dict["time_guidance_embed.timestep_embedder_r.linear_1.weight"] = original_state_dict.pop(
+ "time_r_in.mlp.0.weight"
+ )
+ converted_state_dict["time_guidance_embed.timestep_embedder_r.linear_1.bias"] = original_state_dict.pop(
+ "time_r_in.mlp.0.bias"
+ )
+ converted_state_dict["time_guidance_embed.timestep_embedder_r.linear_2.weight"] = original_state_dict.pop(
+ "time_r_in.mlp.2.weight"
+ )
+ converted_state_dict["time_guidance_embed.timestep_embedder_r.linear_2.bias"] = original_state_dict.pop(
+ "time_r_in.mlp.2.bias"
+ )
+
+ # guidance_in -> time_guidance_embed.guidance_embedder
+ if guidance_distilled:
+ converted_state_dict["time_guidance_embed.guidance_embedder.linear_1.weight"] = original_state_dict.pop(
+ "guidance_in.mlp.0.weight"
+ )
+ converted_state_dict["time_guidance_embed.guidance_embedder.linear_1.bias"] = original_state_dict.pop(
+ "guidance_in.mlp.0.bias"
+ )
+ converted_state_dict["time_guidance_embed.guidance_embedder.linear_2.weight"] = original_state_dict.pop(
+ "guidance_in.mlp.2.weight"
+ )
+ converted_state_dict["time_guidance_embed.guidance_embedder.linear_2.bias"] = original_state_dict.pop(
+ "guidance_in.mlp.2.bias"
+ )
+
+ # 5. double_blocks -> transformer_blocks
+ for i in range(20): # 20 double blocks
+ block_prefix = f"transformer_blocks.{i}."
+
+ # norm1 (img_mod)
+ converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mod.linear.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mod.linear.bias"
+ )
+
+ # norm1_context (txt_mod)
+ converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mod.linear.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mod.linear.bias"
+ )
+
+ # img attention
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn_q.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn_q.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn_k.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn_k.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn_v.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn_v.bias"
+ )
+
+ # img attention norms
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn_q_norm.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn_k_norm.weight"
+ )
+
+ # img attention projection
+ converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn_proj.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn_proj.bias"
+ )
+
+ # img MLP
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.fc1.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.fc1.bias"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.fc2.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.fc2.bias"
+ )
+
+ # txt attention (additional projections)
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn_q.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn_q.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn_k.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn_k.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn_v.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn_v.bias"
+ )
+
+ # txt attention norms
+ converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn_q_norm.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn_k_norm.weight"
+ )
+
+ # txt attention projection
+ converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn_proj.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn_proj.bias"
+ )
+
+ # txt MLP (ff_context)
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.fc1.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.fc1.bias"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.fc2.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.fc2.bias"
+ )
+
+ # 6. single_blocks -> single_transformer_blocks
+ for i in range(40): # 40 single blocks
+ block_prefix = f"single_transformer_blocks.{i}."
+
+ # norm
+ converted_state_dict[f"{block_prefix}norm.linear.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.modulation.linear.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm.linear.bias"] = original_state_dict.pop(
+ f"single_blocks.{i}.modulation.linear.bias"
+ )
+
+ # attention Q, K, V
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear1_q.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear1_q.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear1_k.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear1_k.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear1_v.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear1_v.bias"
+ )
+
+ # attention norms
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.q_norm.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.k_norm.weight"
+ )
+
+ # MLP projection
+ converted_state_dict[f"{block_prefix}proj_mlp.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear1_mlp.weight"
+ )
+ converted_state_dict[f"{block_prefix}proj_mlp.bias"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear1_mlp.bias"
+ )
+
+ # output projection
+ converted_state_dict[f"{block_prefix}proj_out.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear2.fc.weight"
+ )
+ converted_state_dict[f"{block_prefix}proj_out.bias"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear2.fc.bias"
+ )
+
+ # 7. final_layer -> norm_out + proj_out
+ converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
+ converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
+ shift_w, scale_w = original_state_dict.pop("final_layer.adaLN_modulation.1.weight").chunk(2, dim=0)
+ shift_b, scale_b = original_state_dict.pop("final_layer.adaLN_modulation.1.bias").chunk(2, dim=0)
+ converted_state_dict["norm_out.linear.weight"] = torch.cat([scale_w, shift_w], dim=0)
+ converted_state_dict["norm_out.linear.bias"] = torch.cat([scale_b, shift_b], dim=0)
+
+ return converted_state_dict, original_state_dict
+
+
+def convert_hunyuan_image_vae_checkpoint_to_diffusers(
+ original_state_dict, block_out_channels=[128, 256, 512, 512, 1024, 1024], layers_per_block=2
+):
+ """Convert original VAE state dict to Diffusers format."""
+ converted = {}
+
+ # 1. Encoder
+ # 1.1 conv_in
+ converted["encoder.conv_in.weight"] = original_state_dict.pop("encoder.conv_in.weight")
+ converted["encoder.conv_in.bias"] = original_state_dict.pop("encoder.conv_in.bias")
+
+ # 1.2 down blocks
+ diffusers_block_idx = 0
+
+ for block_index in range(len(block_out_channels)):
+ for resnet_block_index in range(layers_per_block):
+ orig_prefix = f"encoder.down.{block_index}.block.{resnet_block_index}"
+ diff_prefix = f"encoder.down_blocks.{diffusers_block_idx}"
+
+ # resnet blocks
+ converted[f"{diff_prefix}.norm1.weight"] = original_state_dict.pop(f"{orig_prefix}.norm1.weight")
+ converted[f"{diff_prefix}.norm1.bias"] = original_state_dict.pop(f"{orig_prefix}.norm1.bias")
+ converted[f"{diff_prefix}.conv1.weight"] = original_state_dict.pop(f"{orig_prefix}.conv1.weight")
+ converted[f"{diff_prefix}.conv1.bias"] = original_state_dict.pop(f"{orig_prefix}.conv1.bias")
+ converted[f"{diff_prefix}.norm2.weight"] = original_state_dict.pop(f"{orig_prefix}.norm2.weight")
+ converted[f"{diff_prefix}.norm2.bias"] = original_state_dict.pop(f"{orig_prefix}.norm2.bias")
+ converted[f"{diff_prefix}.conv2.weight"] = original_state_dict.pop(f"{orig_prefix}.conv2.weight")
+ converted[f"{diff_prefix}.conv2.bias"] = original_state_dict.pop(f"{orig_prefix}.conv2.bias")
+
+ diffusers_block_idx += 1
+
+ # downsample blocks
+ if f"encoder.down.{block_index}.downsample.conv.weight" in original_state_dict:
+ converted[f"encoder.down_blocks.{diffusers_block_idx}.conv.weight"] = original_state_dict.pop(
+ f"encoder.down.{block_index}.downsample.conv.weight"
+ )
+ converted[f"encoder.down_blocks.{diffusers_block_idx}.conv.bias"] = original_state_dict.pop(
+ f"encoder.down.{block_index}.downsample.conv.bias"
+ )
+ diffusers_block_idx += 1
+
+ # 1.3 mid block
+ converted["encoder.mid_block.resnets.0.norm1.weight"] = original_state_dict.pop("encoder.mid.block_1.norm1.weight")
+ converted["encoder.mid_block.resnets.0.norm1.bias"] = original_state_dict.pop("encoder.mid.block_1.norm1.bias")
+ converted["encoder.mid_block.resnets.0.conv1.weight"] = original_state_dict.pop("encoder.mid.block_1.conv1.weight")
+ converted["encoder.mid_block.resnets.0.conv1.bias"] = original_state_dict.pop("encoder.mid.block_1.conv1.bias")
+ converted["encoder.mid_block.resnets.0.norm2.weight"] = original_state_dict.pop("encoder.mid.block_1.norm2.weight")
+ converted["encoder.mid_block.resnets.0.norm2.bias"] = original_state_dict.pop("encoder.mid.block_1.norm2.bias")
+ converted["encoder.mid_block.resnets.0.conv2.weight"] = original_state_dict.pop("encoder.mid.block_1.conv2.weight")
+ converted["encoder.mid_block.resnets.0.conv2.bias"] = original_state_dict.pop("encoder.mid.block_1.conv2.bias")
+
+ converted["encoder.mid_block.resnets.1.norm1.weight"] = original_state_dict.pop("encoder.mid.block_2.norm1.weight")
+ converted["encoder.mid_block.resnets.1.norm1.bias"] = original_state_dict.pop("encoder.mid.block_2.norm1.bias")
+ converted["encoder.mid_block.resnets.1.conv1.weight"] = original_state_dict.pop("encoder.mid.block_2.conv1.weight")
+ converted["encoder.mid_block.resnets.1.conv1.bias"] = original_state_dict.pop("encoder.mid.block_2.conv1.bias")
+ converted["encoder.mid_block.resnets.1.norm2.weight"] = original_state_dict.pop("encoder.mid.block_2.norm2.weight")
+ converted["encoder.mid_block.resnets.1.norm2.bias"] = original_state_dict.pop("encoder.mid.block_2.norm2.bias")
+ converted["encoder.mid_block.resnets.1.conv2.weight"] = original_state_dict.pop("encoder.mid.block_2.conv2.weight")
+ converted["encoder.mid_block.resnets.1.conv2.bias"] = original_state_dict.pop("encoder.mid.block_2.conv2.bias")
+
+ converted["encoder.mid_block.attentions.0.norm.weight"] = original_state_dict.pop("encoder.mid.attn_1.norm.weight")
+ converted["encoder.mid_block.attentions.0.norm.bias"] = original_state_dict.pop("encoder.mid.attn_1.norm.bias")
+ converted["encoder.mid_block.attentions.0.to_q.weight"] = original_state_dict.pop("encoder.mid.attn_1.q.weight")
+ converted["encoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop("encoder.mid.attn_1.q.bias")
+ converted["encoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop("encoder.mid.attn_1.k.weight")
+ converted["encoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop("encoder.mid.attn_1.k.bias")
+ converted["encoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop("encoder.mid.attn_1.v.weight")
+ converted["encoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop("encoder.mid.attn_1.v.bias")
+ converted["encoder.mid_block.attentions.0.proj.weight"] = original_state_dict.pop(
+ "encoder.mid.attn_1.proj_out.weight"
+ )
+ converted["encoder.mid_block.attentions.0.proj.bias"] = original_state_dict.pop("encoder.mid.attn_1.proj_out.bias")
+
+ # 1.4 encoder output
+ converted["encoder.norm_out.weight"] = original_state_dict.pop("encoder.norm_out.weight")
+ converted["encoder.norm_out.bias"] = original_state_dict.pop("encoder.norm_out.bias")
+ converted["encoder.conv_out.weight"] = original_state_dict.pop("encoder.conv_out.weight")
+ converted["encoder.conv_out.bias"] = original_state_dict.pop("encoder.conv_out.bias")
+
+ # 2. Decoder
+ # 2.1 conv_in
+ converted["decoder.conv_in.weight"] = original_state_dict.pop("decoder.conv_in.weight")
+ converted["decoder.conv_in.bias"] = original_state_dict.pop("decoder.conv_in.bias")
+
+ # 2.2 mid block
+ converted["decoder.mid_block.resnets.0.norm1.weight"] = original_state_dict.pop("decoder.mid.block_1.norm1.weight")
+ converted["decoder.mid_block.resnets.0.norm1.bias"] = original_state_dict.pop("decoder.mid.block_1.norm1.bias")
+ converted["decoder.mid_block.resnets.0.conv1.weight"] = original_state_dict.pop("decoder.mid.block_1.conv1.weight")
+ converted["decoder.mid_block.resnets.0.conv1.bias"] = original_state_dict.pop("decoder.mid.block_1.conv1.bias")
+ converted["decoder.mid_block.resnets.0.norm2.weight"] = original_state_dict.pop("decoder.mid.block_1.norm2.weight")
+ converted["decoder.mid_block.resnets.0.norm2.bias"] = original_state_dict.pop("decoder.mid.block_1.norm2.bias")
+ converted["decoder.mid_block.resnets.0.conv2.weight"] = original_state_dict.pop("decoder.mid.block_1.conv2.weight")
+ converted["decoder.mid_block.resnets.0.conv2.bias"] = original_state_dict.pop("decoder.mid.block_1.conv2.bias")
+
+ converted["decoder.mid_block.resnets.1.norm1.weight"] = original_state_dict.pop("decoder.mid.block_2.norm1.weight")
+ converted["decoder.mid_block.resnets.1.norm1.bias"] = original_state_dict.pop("decoder.mid.block_2.norm1.bias")
+ converted["decoder.mid_block.resnets.1.conv1.weight"] = original_state_dict.pop("decoder.mid.block_2.conv1.weight")
+ converted["decoder.mid_block.resnets.1.conv1.bias"] = original_state_dict.pop("decoder.mid.block_2.conv1.bias")
+ converted["decoder.mid_block.resnets.1.norm2.weight"] = original_state_dict.pop("decoder.mid.block_2.norm2.weight")
+ converted["decoder.mid_block.resnets.1.norm2.bias"] = original_state_dict.pop("decoder.mid.block_2.norm2.bias")
+ converted["decoder.mid_block.resnets.1.conv2.weight"] = original_state_dict.pop("decoder.mid.block_2.conv2.weight")
+ converted["decoder.mid_block.resnets.1.conv2.bias"] = original_state_dict.pop("decoder.mid.block_2.conv2.bias")
+
+ converted["decoder.mid_block.attentions.0.norm.weight"] = original_state_dict.pop("decoder.mid.attn_1.norm.weight")
+ converted["decoder.mid_block.attentions.0.norm.bias"] = original_state_dict.pop("decoder.mid.attn_1.norm.bias")
+ converted["decoder.mid_block.attentions.0.to_q.weight"] = original_state_dict.pop("decoder.mid.attn_1.q.weight")
+ converted["decoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop("decoder.mid.attn_1.q.bias")
+ converted["decoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop("decoder.mid.attn_1.k.weight")
+ converted["decoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop("decoder.mid.attn_1.k.bias")
+ converted["decoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop("decoder.mid.attn_1.v.weight")
+ converted["decoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop("decoder.mid.attn_1.v.bias")
+ converted["decoder.mid_block.attentions.0.proj.weight"] = original_state_dict.pop(
+ "decoder.mid.attn_1.proj_out.weight"
+ )
+ converted["decoder.mid_block.attentions.0.proj.bias"] = original_state_dict.pop("decoder.mid.attn_1.proj_out.bias")
+
+ # 2.3 up blocks
+ diffusers_block_idx = 0
+ for up_block_index in range(len(block_out_channels)):
+ # resnet blocks
+ for resnet_block_index in range(layers_per_block + 1):
+ orig_prefix = f"decoder.up.{up_block_index}.block.{resnet_block_index}"
+ diff_prefix = f"decoder.up_blocks.{diffusers_block_idx}"
+
+ converted[f"{diff_prefix}.norm1.weight"] = original_state_dict.pop(f"{orig_prefix}.norm1.weight")
+ converted[f"{diff_prefix}.norm1.bias"] = original_state_dict.pop(f"{orig_prefix}.norm1.bias")
+ converted[f"{diff_prefix}.conv1.weight"] = original_state_dict.pop(f"{orig_prefix}.conv1.weight")
+ converted[f"{diff_prefix}.conv1.bias"] = original_state_dict.pop(f"{orig_prefix}.conv1.bias")
+ converted[f"{diff_prefix}.norm2.weight"] = original_state_dict.pop(f"{orig_prefix}.norm2.weight")
+ converted[f"{diff_prefix}.norm2.bias"] = original_state_dict.pop(f"{orig_prefix}.norm2.bias")
+ converted[f"{diff_prefix}.conv2.weight"] = original_state_dict.pop(f"{orig_prefix}.conv2.weight")
+ converted[f"{diff_prefix}.conv2.bias"] = original_state_dict.pop(f"{orig_prefix}.conv2.bias")
+
+ diffusers_block_idx += 1
+
+ # upsample blocks
+ if f"decoder.up.{up_block_index}.upsample.conv.weight" in original_state_dict:
+ converted[f"decoder.up_blocks.{diffusers_block_idx}.conv.weight"] = original_state_dict.pop(
+ f"decoder.up.{up_block_index}.upsample.conv.weight"
+ )
+ converted[f"decoder.up_blocks.{diffusers_block_idx}.conv.bias"] = original_state_dict.pop(
+ f"decoder.up.{up_block_index}.upsample.conv.bias"
+ )
+ diffusers_block_idx += 1
+
+ # 2.4 decoder output
+ converted["decoder.norm_out.weight"] = original_state_dict.pop("decoder.norm_out.weight")
+ converted["decoder.norm_out.bias"] = original_state_dict.pop("decoder.norm_out.bias")
+ converted["decoder.conv_out.weight"] = original_state_dict.pop("decoder.conv_out.weight")
+ converted["decoder.conv_out.bias"] = original_state_dict.pop("decoder.conv_out.bias")
+
+ return converted, original_state_dict
+
+
+def convert_hunyuan_image_refiner_vae_checkpoint_to_diffusers(
+ original_state_dict, block_out_channels=[128, 256, 512, 1024, 1024], layers_per_block=2
+):
+ converted = {}
+
+ # 1. Encoder
+ # 1.1 conv_in
+ converted["encoder.conv_in.conv.weight"] = original_state_dict.pop("encoder.conv_in.conv.weight")
+ converted["encoder.conv_in.conv.bias"] = original_state_dict.pop("encoder.conv_in.conv.bias")
+
+ # 1.2 Down blocks
+ for down_block_index in range(len(block_out_channels)): # 0 to 4
+ # ResNet blocks
+ for resnet_block_index in range(layers_per_block): # 0 to 1
+ converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.norm1.gamma"] = (
+ original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.norm1.gamma")
+ )
+ converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv1.conv.weight"] = (
+ original_state_dict.pop(
+ f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv1.conv.weight"
+ )
+ )
+ converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv1.conv.bias"] = (
+ original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv1.conv.bias")
+ )
+ converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.norm2.gamma"] = (
+ original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.norm2.gamma")
+ )
+ converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv2.conv.weight"] = (
+ original_state_dict.pop(
+ f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv2.conv.weight"
+ )
+ )
+ converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv2.conv.bias"] = (
+ original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv2.conv.bias")
+ )
+
+ # Downsample (if exists)
+ if f"encoder.down.{down_block_index}.downsample.conv.conv.weight" in original_state_dict:
+ converted[f"encoder.down_blocks.{down_block_index}.downsamplers.0.conv.conv.weight"] = (
+ original_state_dict.pop(f"encoder.down.{down_block_index}.downsample.conv.conv.weight")
+ )
+ converted[f"encoder.down_blocks.{down_block_index}.downsamplers.0.conv.conv.bias"] = (
+ original_state_dict.pop(f"encoder.down.{down_block_index}.downsample.conv.conv.bias")
+ )
+
+ # 1.3 Mid block
+ converted["encoder.mid_block.resnets.0.norm1.gamma"] = original_state_dict.pop("encoder.mid.block_1.norm1.gamma")
+ converted["encoder.mid_block.resnets.0.conv1.conv.weight"] = original_state_dict.pop(
+ "encoder.mid.block_1.conv1.conv.weight"
+ )
+ converted["encoder.mid_block.resnets.0.conv1.conv.bias"] = original_state_dict.pop(
+ "encoder.mid.block_1.conv1.conv.bias"
+ )
+ converted["encoder.mid_block.resnets.0.norm2.gamma"] = original_state_dict.pop("encoder.mid.block_1.norm2.gamma")
+ converted["encoder.mid_block.resnets.0.conv2.conv.weight"] = original_state_dict.pop(
+ "encoder.mid.block_1.conv2.conv.weight"
+ )
+ converted["encoder.mid_block.resnets.0.conv2.conv.bias"] = original_state_dict.pop(
+ "encoder.mid.block_1.conv2.conv.bias"
+ )
+
+ converted["encoder.mid_block.resnets.1.norm1.gamma"] = original_state_dict.pop("encoder.mid.block_2.norm1.gamma")
+ converted["encoder.mid_block.resnets.1.conv1.conv.weight"] = original_state_dict.pop(
+ "encoder.mid.block_2.conv1.conv.weight"
+ )
+ converted["encoder.mid_block.resnets.1.conv1.conv.bias"] = original_state_dict.pop(
+ "encoder.mid.block_2.conv1.conv.bias"
+ )
+ converted["encoder.mid_block.resnets.1.norm2.gamma"] = original_state_dict.pop("encoder.mid.block_2.norm2.gamma")
+ converted["encoder.mid_block.resnets.1.conv2.conv.weight"] = original_state_dict.pop(
+ "encoder.mid.block_2.conv2.conv.weight"
+ )
+ converted["encoder.mid_block.resnets.1.conv2.conv.bias"] = original_state_dict.pop(
+ "encoder.mid.block_2.conv2.conv.bias"
+ )
+
+ # Attention block
+ converted["encoder.mid_block.attentions.0.norm.gamma"] = original_state_dict.pop("encoder.mid.attn_1.norm.gamma")
+ converted["encoder.mid_block.attentions.0.to_q.weight"] = original_state_dict.pop("encoder.mid.attn_1.q.weight")
+ converted["encoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop("encoder.mid.attn_1.q.bias")
+ converted["encoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop("encoder.mid.attn_1.k.weight")
+ converted["encoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop("encoder.mid.attn_1.k.bias")
+ converted["encoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop("encoder.mid.attn_1.v.weight")
+ converted["encoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop("encoder.mid.attn_1.v.bias")
+ converted["encoder.mid_block.attentions.0.proj_out.weight"] = original_state_dict.pop(
+ "encoder.mid.attn_1.proj_out.weight"
+ )
+ converted["encoder.mid_block.attentions.0.proj_out.bias"] = original_state_dict.pop(
+ "encoder.mid.attn_1.proj_out.bias"
+ )
+
+ # 1.4 Encoder output
+ converted["encoder.norm_out.gamma"] = original_state_dict.pop("encoder.norm_out.gamma")
+ converted["encoder.conv_out.conv.weight"] = original_state_dict.pop("encoder.conv_out.conv.weight")
+ converted["encoder.conv_out.conv.bias"] = original_state_dict.pop("encoder.conv_out.conv.bias")
+
+ # 2. Decoder
+ # 2.1 conv_in
+ converted["decoder.conv_in.conv.weight"] = original_state_dict.pop("decoder.conv_in.conv.weight")
+ converted["decoder.conv_in.conv.bias"] = original_state_dict.pop("decoder.conv_in.conv.bias")
+
+ # 2.2 Mid block
+ converted["decoder.mid_block.resnets.0.norm1.gamma"] = original_state_dict.pop("decoder.mid.block_1.norm1.gamma")
+ converted["decoder.mid_block.resnets.0.conv1.conv.weight"] = original_state_dict.pop(
+ "decoder.mid.block_1.conv1.conv.weight"
+ )
+ converted["decoder.mid_block.resnets.0.conv1.conv.bias"] = original_state_dict.pop(
+ "decoder.mid.block_1.conv1.conv.bias"
+ )
+ converted["decoder.mid_block.resnets.0.norm2.gamma"] = original_state_dict.pop("decoder.mid.block_1.norm2.gamma")
+ converted["decoder.mid_block.resnets.0.conv2.conv.weight"] = original_state_dict.pop(
+ "decoder.mid.block_1.conv2.conv.weight"
+ )
+ converted["decoder.mid_block.resnets.0.conv2.conv.bias"] = original_state_dict.pop(
+ "decoder.mid.block_1.conv2.conv.bias"
+ )
+
+ converted["decoder.mid_block.resnets.1.norm1.gamma"] = original_state_dict.pop("decoder.mid.block_2.norm1.gamma")
+ converted["decoder.mid_block.resnets.1.conv1.conv.weight"] = original_state_dict.pop(
+ "decoder.mid.block_2.conv1.conv.weight"
+ )
+ converted["decoder.mid_block.resnets.1.conv1.conv.bias"] = original_state_dict.pop(
+ "decoder.mid.block_2.conv1.conv.bias"
+ )
+ converted["decoder.mid_block.resnets.1.norm2.gamma"] = original_state_dict.pop("decoder.mid.block_2.norm2.gamma")
+ converted["decoder.mid_block.resnets.1.conv2.conv.weight"] = original_state_dict.pop(
+ "decoder.mid.block_2.conv2.conv.weight"
+ )
+ converted["decoder.mid_block.resnets.1.conv2.conv.bias"] = original_state_dict.pop(
+ "decoder.mid.block_2.conv2.conv.bias"
+ )
+
+ # Decoder attention block
+ converted["decoder.mid_block.attentions.0.norm.gamma"] = original_state_dict.pop("decoder.mid.attn_1.norm.gamma")
+ converted["decoder.mid_block.attentions.0.to_q.weight"] = original_state_dict.pop("decoder.mid.attn_1.q.weight")
+ converted["decoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop("decoder.mid.attn_1.q.bias")
+ converted["decoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop("decoder.mid.attn_1.k.weight")
+ converted["decoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop("decoder.mid.attn_1.k.bias")
+ converted["decoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop("decoder.mid.attn_1.v.weight")
+ converted["decoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop("decoder.mid.attn_1.v.bias")
+ converted["decoder.mid_block.attentions.0.proj_out.weight"] = original_state_dict.pop(
+ "decoder.mid.attn_1.proj_out.weight"
+ )
+ converted["decoder.mid_block.attentions.0.proj_out.bias"] = original_state_dict.pop(
+ "decoder.mid.attn_1.proj_out.bias"
+ )
+
+ # 2.3 Up blocks
+ for up_block_index in range(len(block_out_channels)): # 0 to 5
+ # ResNet blocks
+ for resnet_block_index in range(layers_per_block + 1): # 0 to 2 (decoder has 3 resnets per level)
+ converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.norm1.gamma"] = (
+ original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.norm1.gamma")
+ )
+ converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv1.conv.weight"] = (
+ original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv1.conv.weight")
+ )
+ converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv1.conv.bias"] = (
+ original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv1.conv.bias")
+ )
+ converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.norm2.gamma"] = (
+ original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.norm2.gamma")
+ )
+ converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv2.conv.weight"] = (
+ original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv2.conv.weight")
+ )
+ converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv2.conv.bias"] = (
+ original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv2.conv.bias")
+ )
+
+ # Upsample (if exists)
+ if f"decoder.up.{up_block_index}.upsample.conv.conv.weight" in original_state_dict:
+ converted[f"decoder.up_blocks.{up_block_index}.upsamplers.0.conv.conv.weight"] = original_state_dict.pop(
+ f"decoder.up.{up_block_index}.upsample.conv.conv.weight"
+ )
+ converted[f"decoder.up_blocks.{up_block_index}.upsamplers.0.conv.conv.bias"] = original_state_dict.pop(
+ f"decoder.up.{up_block_index}.upsample.conv.conv.bias"
+ )
+
+ # 2.4 Decoder output
+ converted["decoder.norm_out.gamma"] = original_state_dict.pop("decoder.norm_out.gamma")
+ converted["decoder.conv_out.conv.weight"] = original_state_dict.pop("decoder.conv_out.conv.weight")
+ converted["decoder.conv_out.conv.bias"] = original_state_dict.pop("decoder.conv_out.conv.bias")
+
+ return converted, original_state_dict
+
+
+def main(args):
+ if args.model_type == "hunyuanimage2.1":
+ original_transformer_state_dict = load_original_transformer_checkpoint(args)
+ original_vae_state_dict = load_original_vae_checkpoint(args)
+
+ transformer_config = {
+ "in_channels": 64,
+ "out_channels": 64,
+ "num_attention_heads": 28,
+ "attention_head_dim": 128,
+ "num_layers": 20,
+ "num_single_layers": 40,
+ "num_refiner_layers": 2,
+ "patch_size": (1, 1),
+ "qk_norm": "rms_norm",
+ "guidance_embeds": False,
+ "text_embed_dim": 3584,
+ "text_embed_2_dim": 1472,
+ "rope_theta": 256.0,
+ "rope_axes_dim": (64, 64),
+ }
+
+ converted_transformer_state_dict, original_transformer_state_dict = (
+ convert_hunyuan_image_transformer_checkpoint_to_diffusers(
+ original_transformer_state_dict, use_byt5=True, guidance_distilled=False
+ )
+ )
+
+ if original_transformer_state_dict:
+ logger.warning(
+ f"Unused {len(original_transformer_state_dict)} original keys for transformer: {list(original_transformer_state_dict.keys())}"
+ )
+
+ transformer = HunyuanImageTransformer2DModel(**transformer_config)
+ missing_keys, unexpected_key = transformer.load_state_dict(converted_transformer_state_dict, strict=True)
+
+ if missing_keys:
+ logger.warning(f"Missing keys for transformer: {missing_keys}")
+ if unexpected_key:
+ logger.warning(f"Unexpected keys for transformer: {unexpected_key}")
+
+ transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")
+
+ vae_config_diffusers = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "latent_channels": 64,
+ "block_out_channels": [128, 256, 512, 512, 1024, 1024],
+ "layers_per_block": 2,
+ "spatial_compression_ratio": 32,
+ "sample_size": 384,
+ "scaling_factor": 0.75289,
+ "downsample_match_channel": True,
+ "upsample_match_channel": True,
+ }
+ converted_vae_state_dict, original_vae_state_dict = convert_hunyuan_image_vae_checkpoint_to_diffusers(
+ original_vae_state_dict, block_out_channels=[128, 256, 512, 512, 1024, 1024], layers_per_block=2
+ )
+ if original_vae_state_dict:
+ logger.warning(
+ f"Unused {len(original_vae_state_dict)} original keys for vae: {list(original_vae_state_dict.keys())}"
+ )
+
+ vae = AutoencoderKLHunyuanImage(**vae_config_diffusers)
+ missing_keys, unexpected_key = vae.load_state_dict(converted_vae_state_dict, strict=True)
+
+ if missing_keys:
+ logger.warning(f"Missing keys for vae: {missing_keys}")
+ if unexpected_key:
+ logger.warning(f"Unexpected keys for vae: {unexpected_key}")
+
+ vae.to(dtype).save_pretrained(f"{args.output_path}/vae")
+
+ elif args.model_type == "hunyuanimage2.1-distilled":
+ original_transformer_state_dict = load_original_transformer_checkpoint(args)
+ original_vae_state_dict = load_original_vae_checkpoint(args)
+
+ transformer_config = {
+ "in_channels": 64,
+ "out_channels": 64,
+ "num_attention_heads": 28,
+ "attention_head_dim": 128,
+ "num_layers": 20,
+ "num_single_layers": 40,
+ "num_refiner_layers": 2,
+ "patch_size": (1, 1),
+ "qk_norm": "rms_norm",
+ "guidance_embeds": True,
+ "text_embed_dim": 3584,
+ "text_embed_2_dim": 1472,
+ "rope_theta": 256.0,
+ "rope_axes_dim": (64, 64),
+ "use_meanflow": True,
+ }
+
+ converted_transformer_state_dict, original_transformer_state_dict = (
+ convert_hunyuan_image_transformer_checkpoint_to_diffusers(
+ original_transformer_state_dict, use_byt5=True, guidance_distilled=True, use_meanflow=True
+ )
+ )
+
+ if original_transformer_state_dict:
+ logger.warning(
+ f"Unused {len(original_transformer_state_dict)} original keys for transformer: {list(original_transformer_state_dict.keys())}"
+ )
+
+ transformer = HunyuanImageTransformer2DModel(**transformer_config)
+ missing_keys, unexpected_key = transformer.load_state_dict(converted_transformer_state_dict, strict=True)
+
+ if missing_keys:
+ logger.warning(f"Missing keys for transformer: {missing_keys}")
+ if unexpected_key:
+ logger.warning(f"Unexpected keys for transformer: {unexpected_key}")
+
+ transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")
+
+ vae_config_diffusers = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "latent_channels": 64,
+ "block_out_channels": [128, 256, 512, 512, 1024, 1024],
+ "layers_per_block": 2,
+ "spatial_compression_ratio": 32,
+ "sample_size": 384,
+ "scaling_factor": 0.75289,
+ "downsample_match_channel": True,
+ "upsample_match_channel": True,
+ }
+ converted_vae_state_dict, original_vae_state_dict = convert_hunyuan_image_vae_checkpoint_to_diffusers(
+ original_vae_state_dict, block_out_channels=[128, 256, 512, 512, 1024, 1024], layers_per_block=2
+ )
+ if original_vae_state_dict:
+ logger.warning(
+ f"Unused {len(original_vae_state_dict)} original keys for vae: {list(original_vae_state_dict.keys())}"
+ )
+
+ vae = AutoencoderKLHunyuanImage(**vae_config_diffusers)
+ missing_keys, unexpected_key = vae.load_state_dict(converted_vae_state_dict, strict=True)
+
+ if missing_keys:
+ logger.warning(f"Missing keys for vae: {missing_keys}")
+ if unexpected_key:
+ logger.warning(f"Unexpected keys for vae: {unexpected_key}")
+
+ vae.to(dtype).save_pretrained(f"{args.output_path}/vae")
+
+ elif args.model_type == "hunyuanimage-refiner":
+ original_transformer_state_dict = load_original_transformer_checkpoint(args)
+ original_vae_state_dict = load_original_refiner_vae_checkpoint(args)
+
+ transformer_config = {
+ "in_channels": 128,
+ "out_channels": 64,
+ "num_layers": 20,
+ "num_single_layers": 40,
+ "rope_axes_dim": [16, 56, 56],
+ "num_attention_heads": 26,
+ "attention_head_dim": 128,
+ "mlp_ratio": 4,
+ "patch_size": (1, 1, 1),
+ "text_embed_dim": 3584,
+ "guidance_embeds": True,
+ }
+ converted_transformer_state_dict, original_transformer_state_dict = (
+ convert_hunyuan_image_transformer_checkpoint_to_diffusers(
+ original_transformer_state_dict, use_byt5=False, guidance_distilled=True
+ )
+ )
+ if original_transformer_state_dict:
+ logger.warning(
+ f"Unused {len(original_transformer_state_dict)} original keys for transformer: {list(original_transformer_state_dict.keys())}"
+ )
+
+ transformer = HunyuanImageTransformer2DModel(**transformer_config)
+ missing_keys, unexpected_key = transformer.load_state_dict(converted_transformer_state_dict, strict=True)
+ if missing_keys:
+ logger.warning(f"Missing keys for transformer: {missing_keys}")
+ if unexpected_key:
+ logger.warning(f"Unexpected keys for transformer: {unexpected_key}")
+
+ transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")
+
+ vae = AutoencoderKLHunyuanImageRefiner()
+
+ converted_vae_state_dict, original_vae_state_dict = convert_hunyuan_image_refiner_vae_checkpoint_to_diffusers(
+ original_vae_state_dict
+ )
+ if original_vae_state_dict:
+ logger.warning(
+ f"Unused {len(original_vae_state_dict)} original keys for vae: {list(original_vae_state_dict.keys())}"
+ )
+
+ missing_keys, unexpected_key = vae.load_state_dict(converted_vae_state_dict, strict=True)
+ logger.warning(f"Missing keys for vae: {missing_keys}")
+ logger.warning(f"Unexpected keys for vae: {unexpected_key}")
+
+ vae.to(dtype).save_pretrained(f"{args.output_path}/vae")
+
+
+if __name__ == "__main__":
+ main(args)
diff --git a/scripts/convert_hunyuan_video1_5_to_diffusers.py b/scripts/convert_hunyuan_video1_5_to_diffusers.py
new file mode 100644
index 000000000000..89e5cdb16956
--- /dev/null
+++ b/scripts/convert_hunyuan_video1_5_to_diffusers.py
@@ -0,0 +1,875 @@
+import argparse
+import json
+import os
+import pathlib
+
+import torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download, snapshot_download
+from safetensors.torch import load_file
+from transformers import (
+ AutoModel,
+ AutoTokenizer,
+ SiglipImageProcessor,
+ SiglipVisionModel,
+ T5EncoderModel,
+)
+
+from diffusers import (
+ AutoencoderKLHunyuanVideo15,
+ ClassifierFreeGuidance,
+ FlowMatchEulerDiscreteScheduler,
+ HunyuanVideo15ImageToVideoPipeline,
+ HunyuanVideo15Pipeline,
+ HunyuanVideo15Transformer3DModel,
+)
+
+
+# to convert only transformer
+"""
+python scripts/convert_hunyuan_video1_5_to_diffusers.py \
+ --original_state_dict_repo_id tencent/HunyuanVideo-1.5\
+ --output_path /fsx/yiyi/HunyuanVideo-1.5-Diffusers/transformer\
+ --transformer_type 480p_t2v
+"""
+
+# to convert full pipeline
+"""
+python scripts/convert_hunyuan_video1_5_to_diffusers.py \
+ --original_state_dict_repo_id tencent/HunyuanVideo-1.5\
+ --output_path /fsx/yiyi/HunyuanVideo-1.5-Diffusers \
+ --save_pipeline \
+ --byt5_path /fsx/yiyi/hy15/text_encoder/Glyph-SDXL-v2\
+ --transformer_type 480p_t2v
+"""
+
+
+TRANSFORMER_CONFIGS = {
+ "480p_t2v": {
+ "target_size": 640,
+ "task_type": "i2v",
+ },
+ "720p_t2v": {
+ "target_size": 960,
+ "task_type": "t2v",
+ },
+ "720p_i2v": {
+ "target_size": 960,
+ "task_type": "i2v",
+ },
+ "480p_t2v_distilled": {
+ "target_size": 640,
+ "task_type": "t2v",
+ },
+ "480p_i2v_distilled": {
+ "target_size": 640,
+ "task_type": "i2v",
+ },
+ "720p_i2v_distilled": {
+ "target_size": 960,
+ "task_type": "i2v",
+ },
+ "480p_i2v_step_distilled": {
+ "target_size": 640,
+ "task_type": "i2v",
+ "use_meanflow": True,
+ },
+}
+
+SCHEDULER_CONFIGS = {
+ "480p_t2v": {
+ "shift": 5.0,
+ },
+ "480p_i2v": {
+ "shift": 5.0,
+ },
+ "720p_t2v": {
+ "shift": 9.0,
+ },
+ "720p_i2v": {
+ "shift": 7.0,
+ },
+ "480p_t2v_distilled": {
+ "shift": 5.0,
+ },
+ "480p_i2v_distilled": {
+ "shift": 5.0,
+ },
+ "720p_i2v_distilled": {
+ "shift": 7.0,
+ },
+ "480p_i2v_step_distilled": {
+ "shift": 7.0,
+ },
+}
+
+GUIDANCE_CONFIGS = {
+ "480p_t2v": {
+ "guidance_scale": 6.0,
+ },
+ "480p_i2v": {
+ "guidance_scale": 6.0,
+ },
+ "720p_t2v": {
+ "guidance_scale": 6.0,
+ },
+ "720p_i2v": {
+ "guidance_scale": 6.0,
+ },
+ "480p_t2v_distilled": {
+ "guidance_scale": 1.0,
+ },
+ "480p_i2v_distilled": {
+ "guidance_scale": 1.0,
+ },
+ "720p_i2v_distilled": {
+ "guidance_scale": 1.0,
+ },
+ "480p_i2v_step_distilled": {
+ "guidance_scale": 1.0,
+ },
+}
+
+
+def swap_scale_shift(weight):
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ return new_weight
+
+
+def convert_hyvideo15_transformer_to_diffusers(original_state_dict, config=None):
+ """
+ Convert HunyuanVideo 1.5 original checkpoint to Diffusers format.
+ """
+ converted_state_dict = {}
+
+ # 1. time_embed.timestep_embedder <- time_in
+ converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
+ "time_in.mlp.0.weight"
+ )
+ converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("time_in.mlp.0.bias")
+ converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
+ "time_in.mlp.2.weight"
+ )
+ converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("time_in.mlp.2.bias")
+
+ if config.use_meanflow:
+ converted_state_dict["time_embed.timestep_embedder_r.linear_1.weight"] = original_state_dict.pop(
+ "time_r_in.mlp.0.weight"
+ )
+ converted_state_dict["time_embed.timestep_embedder_r.linear_1.bias"] = original_state_dict.pop(
+ "time_r_in.mlp.0.bias"
+ )
+ converted_state_dict["time_embed.timestep_embedder_r.linear_2.weight"] = original_state_dict.pop(
+ "time_r_in.mlp.2.weight"
+ )
+ converted_state_dict["time_embed.timestep_embedder_r.linear_2.bias"] = original_state_dict.pop(
+ "time_r_in.mlp.2.bias"
+ )
+
+ # 2. context_embedder.time_text_embed.timestep_embedder <- txt_in.t_embedder
+ converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.weight"] = (
+ original_state_dict.pop("txt_in.t_embedder.mlp.0.weight")
+ )
+ converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
+ "txt_in.t_embedder.mlp.0.bias"
+ )
+ converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.weight"] = (
+ original_state_dict.pop("txt_in.t_embedder.mlp.2.weight")
+ )
+ converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
+ "txt_in.t_embedder.mlp.2.bias"
+ )
+
+ # 3. context_embedder.time_text_embed.text_embedder <- txt_in.c_embedder
+ converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop(
+ "txt_in.c_embedder.linear_1.weight"
+ )
+ converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop(
+ "txt_in.c_embedder.linear_1.bias"
+ )
+ converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop(
+ "txt_in.c_embedder.linear_2.weight"
+ )
+ converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop(
+ "txt_in.c_embedder.linear_2.bias"
+ )
+
+ # 4. context_embedder.proj_in <- txt_in.input_embedder
+ converted_state_dict["context_embedder.proj_in.weight"] = original_state_dict.pop("txt_in.input_embedder.weight")
+ converted_state_dict["context_embedder.proj_in.bias"] = original_state_dict.pop("txt_in.input_embedder.bias")
+
+ # 5. context_embedder.token_refiner <- txt_in.individual_token_refiner
+ num_refiner_blocks = 2
+ for i in range(num_refiner_blocks):
+ block_prefix = f"context_embedder.token_refiner.refiner_blocks.{i}."
+ orig_prefix = f"txt_in.individual_token_refiner.blocks.{i}."
+
+ # norm1
+ converted_state_dict[f"{block_prefix}norm1.weight"] = original_state_dict.pop(f"{orig_prefix}norm1.weight")
+ converted_state_dict[f"{block_prefix}norm1.bias"] = original_state_dict.pop(f"{orig_prefix}norm1.bias")
+
+ # Split self_attn_qkv into to_q, to_k, to_v
+ qkv_weight = original_state_dict.pop(f"{orig_prefix}self_attn_qkv.weight")
+ qkv_bias = original_state_dict.pop(f"{orig_prefix}self_attn_qkv.bias")
+ q, k, v = torch.chunk(qkv_weight, 3, dim=0)
+ q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0)
+
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = q
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = q_bias
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = k
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = k_bias
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = v
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = v_bias
+
+ # self_attn_proj -> attn.to_out.0
+ converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
+ f"{orig_prefix}self_attn_proj.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
+ f"{orig_prefix}self_attn_proj.bias"
+ )
+
+ # norm2
+ converted_state_dict[f"{block_prefix}norm2.weight"] = original_state_dict.pop(f"{orig_prefix}norm2.weight")
+ converted_state_dict[f"{block_prefix}norm2.bias"] = original_state_dict.pop(f"{orig_prefix}norm2.bias")
+
+ # mlp -> ff
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
+ f"{orig_prefix}mlp.fc1.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
+ f"{orig_prefix}mlp.fc1.bias"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
+ f"{orig_prefix}mlp.fc2.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(f"{orig_prefix}mlp.fc2.bias")
+
+ # adaLN_modulation -> norm_out
+ converted_state_dict[f"{block_prefix}norm_out.linear.weight"] = original_state_dict.pop(
+ f"{orig_prefix}adaLN_modulation.1.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm_out.linear.bias"] = original_state_dict.pop(
+ f"{orig_prefix}adaLN_modulation.1.bias"
+ )
+
+ # 6. context_embedder_2 <- byt5_in
+ converted_state_dict["context_embedder_2.norm.weight"] = original_state_dict.pop("byt5_in.layernorm.weight")
+ converted_state_dict["context_embedder_2.norm.bias"] = original_state_dict.pop("byt5_in.layernorm.bias")
+ converted_state_dict["context_embedder_2.linear_1.weight"] = original_state_dict.pop("byt5_in.fc1.weight")
+ converted_state_dict["context_embedder_2.linear_1.bias"] = original_state_dict.pop("byt5_in.fc1.bias")
+ converted_state_dict["context_embedder_2.linear_2.weight"] = original_state_dict.pop("byt5_in.fc2.weight")
+ converted_state_dict["context_embedder_2.linear_2.bias"] = original_state_dict.pop("byt5_in.fc2.bias")
+ converted_state_dict["context_embedder_2.linear_3.weight"] = original_state_dict.pop("byt5_in.fc3.weight")
+ converted_state_dict["context_embedder_2.linear_3.bias"] = original_state_dict.pop("byt5_in.fc3.bias")
+
+ # 7. image_embedder <- vision_in
+ converted_state_dict["image_embedder.norm_in.weight"] = original_state_dict.pop("vision_in.proj.0.weight")
+ converted_state_dict["image_embedder.norm_in.bias"] = original_state_dict.pop("vision_in.proj.0.bias")
+ converted_state_dict["image_embedder.linear_1.weight"] = original_state_dict.pop("vision_in.proj.1.weight")
+ converted_state_dict["image_embedder.linear_1.bias"] = original_state_dict.pop("vision_in.proj.1.bias")
+ converted_state_dict["image_embedder.linear_2.weight"] = original_state_dict.pop("vision_in.proj.3.weight")
+ converted_state_dict["image_embedder.linear_2.bias"] = original_state_dict.pop("vision_in.proj.3.bias")
+ converted_state_dict["image_embedder.norm_out.weight"] = original_state_dict.pop("vision_in.proj.4.weight")
+ converted_state_dict["image_embedder.norm_out.bias"] = original_state_dict.pop("vision_in.proj.4.bias")
+
+ # 8. x_embedder <- img_in
+ converted_state_dict["x_embedder.proj.weight"] = original_state_dict.pop("img_in.proj.weight")
+ converted_state_dict["x_embedder.proj.bias"] = original_state_dict.pop("img_in.proj.bias")
+
+ # 9. cond_type_embed <- cond_type_embedding
+ converted_state_dict["cond_type_embed.weight"] = original_state_dict.pop("cond_type_embedding.weight")
+
+ # 10. transformer_blocks <- double_blocks
+ num_layers = 54
+ for i in range(num_layers):
+ block_prefix = f"transformer_blocks.{i}."
+ orig_prefix = f"double_blocks.{i}."
+
+ # norm1 (img_mod)
+ converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop(
+ f"{orig_prefix}img_mod.linear.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop(
+ f"{orig_prefix}img_mod.linear.bias"
+ )
+
+ # norm1_context (txt_mod)
+ converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop(
+ f"{orig_prefix}txt_mod.linear.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop(
+ f"{orig_prefix}txt_mod.linear.bias"
+ )
+
+ # img attention (to_q, to_k, to_v)
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = original_state_dict.pop(
+ f"{orig_prefix}img_attn_q.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = original_state_dict.pop(
+ f"{orig_prefix}img_attn_q.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = original_state_dict.pop(
+ f"{orig_prefix}img_attn_k.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = original_state_dict.pop(
+ f"{orig_prefix}img_attn_k.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = original_state_dict.pop(
+ f"{orig_prefix}img_attn_v.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = original_state_dict.pop(
+ f"{orig_prefix}img_attn_v.bias"
+ )
+
+ # img attention qk norm
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+ f"{orig_prefix}img_attn_q_norm.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+ f"{orig_prefix}img_attn_k_norm.weight"
+ )
+
+ # img attention output projection
+ converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
+ f"{orig_prefix}img_attn_proj.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
+ f"{orig_prefix}img_attn_proj.bias"
+ )
+
+ # txt attention (add_q_proj, add_k_proj, add_v_proj)
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = original_state_dict.pop(
+ f"{orig_prefix}txt_attn_q.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = original_state_dict.pop(
+ f"{orig_prefix}txt_attn_q.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = original_state_dict.pop(
+ f"{orig_prefix}txt_attn_k.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = original_state_dict.pop(
+ f"{orig_prefix}txt_attn_k.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = original_state_dict.pop(
+ f"{orig_prefix}txt_attn_v.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = original_state_dict.pop(
+ f"{orig_prefix}txt_attn_v.bias"
+ )
+
+ # txt attention qk norm
+ converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
+ f"{orig_prefix}txt_attn_q_norm.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
+ f"{orig_prefix}txt_attn_k_norm.weight"
+ )
+
+ # txt attention output projection
+ converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop(
+ f"{orig_prefix}txt_attn_proj.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop(
+ f"{orig_prefix}txt_attn_proj.bias"
+ )
+
+ # norm2 and norm2_context (these don't have weights in the original, they're LayerNorm with elementwise_affine=False)
+ # So we skip them
+
+ # img_mlp -> ff
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
+ f"{orig_prefix}img_mlp.fc1.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
+ f"{orig_prefix}img_mlp.fc1.bias"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
+ f"{orig_prefix}img_mlp.fc2.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
+ f"{orig_prefix}img_mlp.fc2.bias"
+ )
+
+ # txt_mlp -> ff_context
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop(
+ f"{orig_prefix}txt_mlp.fc1.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop(
+ f"{orig_prefix}txt_mlp.fc1.bias"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop(
+ f"{orig_prefix}txt_mlp.fc2.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop(
+ f"{orig_prefix}txt_mlp.fc2.bias"
+ )
+
+ # 11. norm_out and proj_out <- final_layer
+ converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+ original_state_dict.pop("final_layer.adaLN_modulation.1.weight")
+ )
+ converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+ original_state_dict.pop("final_layer.adaLN_modulation.1.bias")
+ )
+ converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
+ converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
+
+ return converted_state_dict
+
+
+def convert_hunyuan_video_15_vae_checkpoint_to_diffusers(
+ original_state_dict, block_out_channels=[128, 256, 512, 1024, 1024], layers_per_block=2
+):
+ converted = {}
+
+ # 1. Encoder
+ # 1.1 conv_in
+ converted["encoder.conv_in.conv.weight"] = original_state_dict.pop("encoder.conv_in.conv.weight")
+ converted["encoder.conv_in.conv.bias"] = original_state_dict.pop("encoder.conv_in.conv.bias")
+
+ # 1.2 Down blocks
+ for down_block_index in range(len(block_out_channels)): # 0 to 4
+ # ResNet blocks
+ for resnet_block_index in range(layers_per_block): # 0 to 1
+ converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.norm1.gamma"] = (
+ original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.norm1.gamma")
+ )
+ converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv1.conv.weight"] = (
+ original_state_dict.pop(
+ f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv1.conv.weight"
+ )
+ )
+ converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv1.conv.bias"] = (
+ original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv1.conv.bias")
+ )
+ converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.norm2.gamma"] = (
+ original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.norm2.gamma")
+ )
+ converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv2.conv.weight"] = (
+ original_state_dict.pop(
+ f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv2.conv.weight"
+ )
+ )
+ converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv2.conv.bias"] = (
+ original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv2.conv.bias")
+ )
+
+ # Downsample (if exists)
+ if f"encoder.down.{down_block_index}.downsample.conv.conv.weight" in original_state_dict:
+ converted[f"encoder.down_blocks.{down_block_index}.downsamplers.0.conv.conv.weight"] = (
+ original_state_dict.pop(f"encoder.down.{down_block_index}.downsample.conv.conv.weight")
+ )
+ converted[f"encoder.down_blocks.{down_block_index}.downsamplers.0.conv.conv.bias"] = (
+ original_state_dict.pop(f"encoder.down.{down_block_index}.downsample.conv.conv.bias")
+ )
+
+ # 1.3 Mid block
+ converted["encoder.mid_block.resnets.0.norm1.gamma"] = original_state_dict.pop("encoder.mid.block_1.norm1.gamma")
+ converted["encoder.mid_block.resnets.0.conv1.conv.weight"] = original_state_dict.pop(
+ "encoder.mid.block_1.conv1.conv.weight"
+ )
+ converted["encoder.mid_block.resnets.0.conv1.conv.bias"] = original_state_dict.pop(
+ "encoder.mid.block_1.conv1.conv.bias"
+ )
+ converted["encoder.mid_block.resnets.0.norm2.gamma"] = original_state_dict.pop("encoder.mid.block_1.norm2.gamma")
+ converted["encoder.mid_block.resnets.0.conv2.conv.weight"] = original_state_dict.pop(
+ "encoder.mid.block_1.conv2.conv.weight"
+ )
+ converted["encoder.mid_block.resnets.0.conv2.conv.bias"] = original_state_dict.pop(
+ "encoder.mid.block_1.conv2.conv.bias"
+ )
+
+ converted["encoder.mid_block.resnets.1.norm1.gamma"] = original_state_dict.pop("encoder.mid.block_2.norm1.gamma")
+ converted["encoder.mid_block.resnets.1.conv1.conv.weight"] = original_state_dict.pop(
+ "encoder.mid.block_2.conv1.conv.weight"
+ )
+ converted["encoder.mid_block.resnets.1.conv1.conv.bias"] = original_state_dict.pop(
+ "encoder.mid.block_2.conv1.conv.bias"
+ )
+ converted["encoder.mid_block.resnets.1.norm2.gamma"] = original_state_dict.pop("encoder.mid.block_2.norm2.gamma")
+ converted["encoder.mid_block.resnets.1.conv2.conv.weight"] = original_state_dict.pop(
+ "encoder.mid.block_2.conv2.conv.weight"
+ )
+ converted["encoder.mid_block.resnets.1.conv2.conv.bias"] = original_state_dict.pop(
+ "encoder.mid.block_2.conv2.conv.bias"
+ )
+
+ # Attention block
+ converted["encoder.mid_block.attentions.0.norm.gamma"] = original_state_dict.pop("encoder.mid.attn_1.norm.gamma")
+ converted["encoder.mid_block.attentions.0.to_q.weight"] = original_state_dict.pop("encoder.mid.attn_1.q.weight")
+ converted["encoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop("encoder.mid.attn_1.q.bias")
+ converted["encoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop("encoder.mid.attn_1.k.weight")
+ converted["encoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop("encoder.mid.attn_1.k.bias")
+ converted["encoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop("encoder.mid.attn_1.v.weight")
+ converted["encoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop("encoder.mid.attn_1.v.bias")
+ converted["encoder.mid_block.attentions.0.proj_out.weight"] = original_state_dict.pop(
+ "encoder.mid.attn_1.proj_out.weight"
+ )
+ converted["encoder.mid_block.attentions.0.proj_out.bias"] = original_state_dict.pop(
+ "encoder.mid.attn_1.proj_out.bias"
+ )
+
+ # 1.4 Encoder output
+ converted["encoder.norm_out.gamma"] = original_state_dict.pop("encoder.norm_out.gamma")
+ converted["encoder.conv_out.conv.weight"] = original_state_dict.pop("encoder.conv_out.conv.weight")
+ converted["encoder.conv_out.conv.bias"] = original_state_dict.pop("encoder.conv_out.conv.bias")
+
+ # 2. Decoder
+ # 2.1 conv_in
+ converted["decoder.conv_in.conv.weight"] = original_state_dict.pop("decoder.conv_in.conv.weight")
+ converted["decoder.conv_in.conv.bias"] = original_state_dict.pop("decoder.conv_in.conv.bias")
+
+ # 2.2 Mid block
+ converted["decoder.mid_block.resnets.0.norm1.gamma"] = original_state_dict.pop("decoder.mid.block_1.norm1.gamma")
+ converted["decoder.mid_block.resnets.0.conv1.conv.weight"] = original_state_dict.pop(
+ "decoder.mid.block_1.conv1.conv.weight"
+ )
+ converted["decoder.mid_block.resnets.0.conv1.conv.bias"] = original_state_dict.pop(
+ "decoder.mid.block_1.conv1.conv.bias"
+ )
+ converted["decoder.mid_block.resnets.0.norm2.gamma"] = original_state_dict.pop("decoder.mid.block_1.norm2.gamma")
+ converted["decoder.mid_block.resnets.0.conv2.conv.weight"] = original_state_dict.pop(
+ "decoder.mid.block_1.conv2.conv.weight"
+ )
+ converted["decoder.mid_block.resnets.0.conv2.conv.bias"] = original_state_dict.pop(
+ "decoder.mid.block_1.conv2.conv.bias"
+ )
+
+ converted["decoder.mid_block.resnets.1.norm1.gamma"] = original_state_dict.pop("decoder.mid.block_2.norm1.gamma")
+ converted["decoder.mid_block.resnets.1.conv1.conv.weight"] = original_state_dict.pop(
+ "decoder.mid.block_2.conv1.conv.weight"
+ )
+ converted["decoder.mid_block.resnets.1.conv1.conv.bias"] = original_state_dict.pop(
+ "decoder.mid.block_2.conv1.conv.bias"
+ )
+ converted["decoder.mid_block.resnets.1.norm2.gamma"] = original_state_dict.pop("decoder.mid.block_2.norm2.gamma")
+ converted["decoder.mid_block.resnets.1.conv2.conv.weight"] = original_state_dict.pop(
+ "decoder.mid.block_2.conv2.conv.weight"
+ )
+ converted["decoder.mid_block.resnets.1.conv2.conv.bias"] = original_state_dict.pop(
+ "decoder.mid.block_2.conv2.conv.bias"
+ )
+
+ # Decoder attention block
+ converted["decoder.mid_block.attentions.0.norm.gamma"] = original_state_dict.pop("decoder.mid.attn_1.norm.gamma")
+ converted["decoder.mid_block.attentions.0.to_q.weight"] = original_state_dict.pop("decoder.mid.attn_1.q.weight")
+ converted["decoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop("decoder.mid.attn_1.q.bias")
+ converted["decoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop("decoder.mid.attn_1.k.weight")
+ converted["decoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop("decoder.mid.attn_1.k.bias")
+ converted["decoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop("decoder.mid.attn_1.v.weight")
+ converted["decoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop("decoder.mid.attn_1.v.bias")
+ converted["decoder.mid_block.attentions.0.proj_out.weight"] = original_state_dict.pop(
+ "decoder.mid.attn_1.proj_out.weight"
+ )
+ converted["decoder.mid_block.attentions.0.proj_out.bias"] = original_state_dict.pop(
+ "decoder.mid.attn_1.proj_out.bias"
+ )
+
+ # 2.3 Up blocks
+ for up_block_index in range(len(block_out_channels)): # 0 to 5
+ # ResNet blocks
+ for resnet_block_index in range(layers_per_block + 1): # 0 to 2 (decoder has 3 resnets per level)
+ converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.norm1.gamma"] = (
+ original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.norm1.gamma")
+ )
+ converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv1.conv.weight"] = (
+ original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv1.conv.weight")
+ )
+ converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv1.conv.bias"] = (
+ original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv1.conv.bias")
+ )
+ converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.norm2.gamma"] = (
+ original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.norm2.gamma")
+ )
+ converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv2.conv.weight"] = (
+ original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv2.conv.weight")
+ )
+ converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv2.conv.bias"] = (
+ original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv2.conv.bias")
+ )
+
+ # Upsample (if exists)
+ if f"decoder.up.{up_block_index}.upsample.conv.conv.weight" in original_state_dict:
+ converted[f"decoder.up_blocks.{up_block_index}.upsamplers.0.conv.conv.weight"] = original_state_dict.pop(
+ f"decoder.up.{up_block_index}.upsample.conv.conv.weight"
+ )
+ converted[f"decoder.up_blocks.{up_block_index}.upsamplers.0.conv.conv.bias"] = original_state_dict.pop(
+ f"decoder.up.{up_block_index}.upsample.conv.conv.bias"
+ )
+
+ # 2.4 Decoder output
+ converted["decoder.norm_out.gamma"] = original_state_dict.pop("decoder.norm_out.gamma")
+ converted["decoder.conv_out.conv.weight"] = original_state_dict.pop("decoder.conv_out.conv.weight")
+ converted["decoder.conv_out.conv.bias"] = original_state_dict.pop("decoder.conv_out.conv.bias")
+
+ return converted
+
+
+def load_sharded_safetensors(dir: pathlib.Path):
+ file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors"))
+ state_dict = {}
+ for path in file_paths:
+ state_dict.update(load_file(path))
+ return state_dict
+
+
+def load_original_transformer_state_dict(args):
+ if args.original_state_dict_repo_id is not None:
+ model_dir = snapshot_download(
+ args.original_state_dict_repo_id,
+ repo_type="model",
+ allow_patterns="transformer/" + args.transformer_type + "/*",
+ )
+ elif args.original_state_dict_folder is not None:
+ model_dir = pathlib.Path(args.original_state_dict_folder)
+ else:
+ raise ValueError("Please provide either `original_state_dict_repo_id` or `original_state_dict_folder`")
+ model_dir = pathlib.Path(model_dir)
+ model_dir = model_dir / "transformer" / args.transformer_type
+ return load_sharded_safetensors(model_dir)
+
+
+def load_original_vae_state_dict(args):
+ if args.original_state_dict_repo_id is not None:
+ ckpt_path = hf_hub_download(
+ repo_id=args.original_state_dict_repo_id, filename="vae/diffusion_pytorch_model.safetensors"
+ )
+ elif args.original_state_dict_folder is not None:
+ model_dir = pathlib.Path(args.original_state_dict_folder)
+ ckpt_path = model_dir / "vae/diffusion_pytorch_model.safetensors"
+ else:
+ raise ValueError("Please provide either `original_state_dict_repo_id` or `original_state_dict_folder`")
+
+ original_state_dict = load_file(ckpt_path)
+ return original_state_dict
+
+
+def convert_transformer(args):
+ original_state_dict = load_original_transformer_state_dict(args)
+
+ config = TRANSFORMER_CONFIGS[args.transformer_type]
+ with init_empty_weights():
+ transformer = HunyuanVideo15Transformer3DModel(**config)
+ state_dict = convert_hyvideo15_transformer_to_diffusers(original_state_dict, config=transformer.config)
+ transformer.load_state_dict(state_dict, strict=True, assign=True)
+
+ return transformer
+
+
+def convert_vae(args):
+ original_state_dict = load_original_vae_state_dict(args)
+ with init_empty_weights():
+ vae = AutoencoderKLHunyuanVideo15()
+ state_dict = convert_hunyuan_video_15_vae_checkpoint_to_diffusers(original_state_dict)
+ vae.load_state_dict(state_dict, strict=True, assign=True)
+ return vae
+
+
+def load_mllm():
+ print(" loading from Qwen/Qwen2.5-VL-7B-Instruct")
+ text_encoder = AutoModel.from_pretrained(
+ "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
+ )
+ if hasattr(text_encoder, "language_model"):
+ text_encoder = text_encoder.language_model
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", padding_side="right")
+ return text_encoder, tokenizer
+
+
+# copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/910da2a829c484ea28982e8cff3bbc2cacdf1681/hyvideo/models/text_encoders/byT5/__init__.py#L89
+def add_special_token(
+ tokenizer,
+ text_encoder,
+ add_color=True,
+ add_font=True,
+ multilingual=True,
+ color_ann_path="assets/color_idx.json",
+ font_ann_path="assets/multilingual_10-lang_idx.json",
+):
+ """
+ Add special tokens for color and font to tokenizer and text encoder.
+
+ Args:
+ tokenizer: Huggingface tokenizer.
+ text_encoder: Huggingface T5 encoder.
+ add_color (bool): Whether to add color tokens.
+ add_font (bool): Whether to add font tokens.
+ color_ann_path (str): Path to color annotation JSON.
+ font_ann_path (str): Path to font annotation JSON.
+ multilingual (bool): Whether to use multilingual font tokens.
+ """
+ with open(font_ann_path, "r") as f:
+ idx_font_dict = json.load(f)
+ with open(color_ann_path, "r") as f:
+ idx_color_dict = json.load(f)
+
+ if multilingual:
+ font_token = [f"<{font_code[:2]}-font-{idx_font_dict[font_code]}>" for font_code in idx_font_dict]
+ else:
+ font_token = [f"" for i in range(len(idx_font_dict))]
+ color_token = [f"" for i in range(len(idx_color_dict))]
+ additional_special_tokens = []
+ if add_color:
+ additional_special_tokens += color_token
+ if add_font:
+ additional_special_tokens += font_token
+
+ tokenizer.add_tokens(additional_special_tokens, special_tokens=True)
+ # Set mean_resizing=False to avoid PyTorch LAPACK dependency
+ text_encoder.resize_token_embeddings(len(tokenizer), mean_resizing=False)
+
+
+def load_byt5(args):
+ """
+ Load ByT5 encoder with Glyph-SDXL-v2 weights and save in HuggingFace format.
+ """
+
+ # 1. Load base tokenizer and encoder
+ tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
+
+ # Load as T5EncoderModel
+ encoder = T5EncoderModel.from_pretrained("google/byt5-small")
+
+ byt5_checkpoint_path = os.path.join(args.byt5_path, "checkpoints/byt5_model.pt")
+ color_ann_path = os.path.join(args.byt5_path, "assets/color_idx.json")
+ font_ann_path = os.path.join(args.byt5_path, "assets/multilingual_10-lang_idx.json")
+
+ # 2. Add special tokens
+ add_special_token(
+ tokenizer=tokenizer,
+ text_encoder=encoder,
+ add_color=True,
+ add_font=True,
+ color_ann_path=color_ann_path,
+ font_ann_path=font_ann_path,
+ multilingual=True,
+ )
+
+ # 3. Load Glyph-SDXL-v2 checkpoint
+ print(f"\n3. Loading Glyph-SDXL-v2 checkpoint: {byt5_checkpoint_path}")
+ checkpoint = torch.load(byt5_checkpoint_path, map_location="cpu")
+
+ # Handle different checkpoint formats
+ if "state_dict" in checkpoint:
+ state_dict = checkpoint["state_dict"]
+ else:
+ state_dict = checkpoint
+
+ # add 'encoder.' prefix to the keys
+ # Remove 'module.text_tower.encoder.' prefix if present
+ cleaned_state_dict = {}
+ for key, value in state_dict.items():
+ if key.startswith("module.text_tower.encoder."):
+ new_key = "encoder." + key[len("module.text_tower.encoder.") :]
+ cleaned_state_dict[new_key] = value
+ else:
+ new_key = "encoder." + key
+ cleaned_state_dict[new_key] = value
+
+ # 4. Load weights
+ missing_keys, unexpected_keys = encoder.load_state_dict(cleaned_state_dict, strict=False)
+ if unexpected_keys:
+ raise ValueError(f"Unexpected keys: {unexpected_keys}")
+ if "shared.weight" in missing_keys:
+ print(" Missing shared.weight as expected")
+ missing_keys.remove("shared.weight")
+ if missing_keys:
+ raise ValueError(f"Missing keys: {missing_keys}")
+
+ return encoder, tokenizer
+
+
+def load_siglip():
+ image_encoder = SiglipVisionModel.from_pretrained(
+ "black-forest-labs/FLUX.1-Redux-dev", subfolder="image_encoder", torch_dtype=torch.bfloat16
+ )
+ feature_extractor = SiglipImageProcessor.from_pretrained(
+ "black-forest-labs/FLUX.1-Redux-dev", subfolder="feature_extractor"
+ )
+ return image_encoder, feature_extractor
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--original_state_dict_repo_id", type=str, default=None, help="Path to original hub_id for the model"
+ )
+ parser.add_argument(
+ "--original_state_dict_folder", type=str, default=None, help="Local folder name of the original state dict"
+ )
+ parser.add_argument("--output_path", type=str, required=True, help="Path where converted model(s) should be saved")
+ parser.add_argument("--transformer_type", type=str, default="480p_i2v", choices=list(TRANSFORMER_CONFIGS.keys()))
+ parser.add_argument(
+ "--byt5_path",
+ type=str,
+ default=None,
+ help=(
+ "path to the downloaded byt5 checkpoint & assets. "
+ "Note: They use Glyph-SDXL-v2 as byt5 encoder. You can download from modelscope like: "
+ "`modelscope download --model AI-ModelScope/Glyph-SDXL-v2 --local_dir ./ckpts/text_encoder/Glyph-SDXL-v2` "
+ "or manually download following the instructions on "
+ "https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/910da2a829c484ea28982e8cff3bbc2cacdf1681/checkpoints-download.md. "
+ "The path should point to the Glyph-SDXL-v2 folder which should contain an `assets` folder and a `checkpoints` folder, "
+ "like: Glyph-SDXL-v2/assets/... and Glyph-SDXL-v2/checkpoints/byt5_model.pt"
+ ),
+ )
+ parser.add_argument("--save_pipeline", action="store_true")
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = get_args()
+
+ if args.save_pipeline and args.byt5_path is None:
+ raise ValueError("Please provide --byt5_path when saving pipeline")
+
+ transformer = None
+
+ transformer = convert_transformer(args)
+ if not args.save_pipeline:
+ transformer.save_pretrained(args.output_path, safe_serialization=True)
+ else:
+ task_type = transformer.config.task_type
+
+ vae = convert_vae(args)
+
+ text_encoder, tokenizer = load_mllm()
+ text_encoder_2, tokenizer_2 = load_byt5(args)
+
+ flow_shift = SCHEDULER_CONFIGS[args.transformer_type]["shift"]
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
+
+ guidance_scale = GUIDANCE_CONFIGS[args.transformer_type]["guidance_scale"]
+ guider = ClassifierFreeGuidance(guidance_scale=guidance_scale)
+
+ if task_type == "i2v":
+ image_encoder, feature_extractor = load_siglip()
+ pipeline = HunyuanVideo15ImageToVideoPipeline(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ guider=guider,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
+ elif task_type == "t2v":
+ pipeline = HunyuanVideo15Pipeline(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ guider=guider,
+ scheduler=scheduler,
+ )
+ else:
+ raise ValueError(f"Task type {task_type} is not supported")
+
+ pipeline.save_pretrained(args.output_path, safe_serialization=True)
diff --git a/scripts/convert_hunyuandit_controlnet_to_diffusers.py b/scripts/convert_hunyuandit_controlnet_to_diffusers.py
index 1c8383690890..5cef46c98983 100644
--- a/scripts/convert_hunyuandit_controlnet_to_diffusers.py
+++ b/scripts/convert_hunyuandit_controlnet_to_diffusers.py
@@ -21,9 +21,9 @@ def main(args):
model_config = HunyuanDiT2DControlNetModel.load_config(
"Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer"
)
- model_config[
- "use_style_cond_and_image_meta_size"
- ] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
+ model_config["use_style_cond_and_image_meta_size"] = (
+ args.use_style_cond_and_image_meta_size
+ ) ### version <= v1.1: True; version >= v1.2: False
print(model_config)
for key in state_dict:
diff --git a/scripts/convert_hunyuandit_to_diffusers.py b/scripts/convert_hunyuandit_to_diffusers.py
index da3af8333ee3..65fcccb22a1a 100644
--- a/scripts/convert_hunyuandit_to_diffusers.py
+++ b/scripts/convert_hunyuandit_to_diffusers.py
@@ -13,15 +13,14 @@ def main(args):
state_dict = state_dict[args.load_key]
except KeyError:
raise KeyError(
- f"{args.load_key} not found in the checkpoint."
- f"Please load from the following keys:{state_dict.keys()}"
+ f"{args.load_key} not found in the checkpoint.Please load from the following keys:{state_dict.keys()}"
)
device = "cuda"
model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer")
- model_config[
- "use_style_cond_and_image_meta_size"
- ] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
+ model_config["use_style_cond_and_image_meta_size"] = (
+ args.use_style_cond_and_image_meta_size
+ ) ### version <= v1.1: True; version >= v1.2: False
# input_size -> sample_size, text_dim -> cross_attention_dim
for key in state_dict:
diff --git a/scripts/convert_k_upscaler_to_diffusers.py b/scripts/convert_k_upscaler_to_diffusers.py
index 62abedd73785..cff845ef8099 100644
--- a/scripts/convert_k_upscaler_to_diffusers.py
+++ b/scripts/convert_k_upscaler_to_diffusers.py
@@ -142,14 +142,14 @@ def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type):
diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}"
idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2
self_attention_prefix = f"{block_prefix}.{idx}"
- cross_attention_prefix = f"{block_prefix}.{idx }"
+ cross_attention_prefix = f"{block_prefix}.{idx}"
cross_attention_index = 1 if not attention.add_self_attention else 2
idx = (
n * attention_idx + cross_attention_index
if block_type == "up"
else n * attention_idx + cross_attention_index + 1
)
- cross_attention_prefix = f"{block_prefix}.{idx }"
+ cross_attention_prefix = f"{block_prefix}.{idx}"
diffusers_checkpoint.update(
cross_attn_to_diffusers_checkpoint(
@@ -220,9 +220,9 @@ def unet_model_from_original_config(original_config):
block_out_channels = original_config["channels"]
- assert (
- len(set(original_config["depths"])) == 1
- ), "UNet2DConditionModel currently do not support blocks with different number of layers"
+ assert len(set(original_config["depths"])) == 1, (
+ "UNet2DConditionModel currently do not support blocks with different number of layers"
+ )
layers_per_block = original_config["depths"][0]
class_labels_dim = original_config["mapping_cond_dim"]
diff --git a/scripts/convert_ltx_to_diffusers.py b/scripts/convert_ltx_to_diffusers.py
index 2e966d5d110b..19e5602039e5 100644
--- a/scripts/convert_ltx_to_diffusers.py
+++ b/scripts/convert_ltx_to_diffusers.py
@@ -7,7 +7,15 @@
from safetensors.torch import load_file
from transformers import T5EncoderModel, T5Tokenizer
-from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel
+from diffusers import (
+ AutoencoderKLLTXVideo,
+ FlowMatchEulerDiscreteScheduler,
+ LTXConditionPipeline,
+ LTXLatentUpsamplePipeline,
+ LTXPipeline,
+ LTXVideoTransformer3DModel,
+)
+from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel
def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -123,17 +131,10 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
state_dict[new_key] = state_dict.pop(old_key)
-def convert_transformer(
- ckpt_path: str,
- dtype: torch.dtype,
- version: str = "0.9.0",
-):
+def convert_transformer(ckpt_path: str, config, dtype: torch.dtype):
PREFIX_KEY = "model.diffusion_model."
original_state_dict = get_state_dict(load_file(ckpt_path))
- config = {}
- if version == "0.9.5":
- config["_use_causal_rope_fix"] = True
with init_empty_weights():
transformer = LTXVideoTransformer3DModel(**config)
@@ -180,8 +181,59 @@ def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
return vae
+def convert_spatial_latent_upsampler(ckpt_path: str, config, dtype: torch.dtype):
+ original_state_dict = get_state_dict(load_file(ckpt_path))
+
+ with init_empty_weights():
+ latent_upsampler = LTXLatentUpsamplerModel(**config)
+
+ latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True)
+ latent_upsampler.to(dtype)
+ return latent_upsampler
+
+
+def get_transformer_config(version: str) -> Dict[str, Any]:
+ if version == "0.9.7":
+ config = {
+ "in_channels": 128,
+ "out_channels": 128,
+ "patch_size": 1,
+ "patch_size_t": 1,
+ "num_attention_heads": 32,
+ "attention_head_dim": 128,
+ "cross_attention_dim": 4096,
+ "num_layers": 48,
+ "activation_fn": "gelu-approximate",
+ "qk_norm": "rms_norm_across_heads",
+ "norm_elementwise_affine": False,
+ "norm_eps": 1e-6,
+ "caption_channels": 4096,
+ "attention_bias": True,
+ "attention_out_bias": True,
+ }
+ else:
+ config = {
+ "in_channels": 128,
+ "out_channels": 128,
+ "patch_size": 1,
+ "patch_size_t": 1,
+ "num_attention_heads": 32,
+ "attention_head_dim": 64,
+ "cross_attention_dim": 2048,
+ "num_layers": 28,
+ "activation_fn": "gelu-approximate",
+ "qk_norm": "rms_norm_across_heads",
+ "norm_elementwise_affine": False,
+ "norm_eps": 1e-6,
+ "caption_channels": 4096,
+ "attention_bias": True,
+ "attention_out_bias": True,
+ }
+ return config
+
+
def get_vae_config(version: str) -> Dict[str, Any]:
- if version == "0.9.0":
+ if version in ["0.9.0"]:
config = {
"in_channels": 3,
"out_channels": 3,
@@ -210,7 +262,7 @@ def get_vae_config(version: str) -> Dict[str, Any]:
"decoder_causal": False,
"timestep_conditioning": False,
}
- elif version == "0.9.1":
+ elif version in ["0.9.1"]:
config = {
"in_channels": 3,
"out_channels": 3,
@@ -240,7 +292,39 @@ def get_vae_config(version: str) -> Dict[str, Any]:
"decoder_causal": False,
}
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
- elif version == "0.9.5":
+ elif version in ["0.9.5"]:
+ config = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "latent_channels": 128,
+ "block_out_channels": (128, 256, 512, 1024, 2048),
+ "down_block_types": (
+ "LTXVideo095DownBlock3D",
+ "LTXVideo095DownBlock3D",
+ "LTXVideo095DownBlock3D",
+ "LTXVideo095DownBlock3D",
+ ),
+ "decoder_block_out_channels": (256, 512, 1024),
+ "layers_per_block": (4, 6, 6, 2, 2),
+ "decoder_layers_per_block": (5, 5, 5, 5),
+ "spatio_temporal_scaling": (True, True, True, True),
+ "decoder_spatio_temporal_scaling": (True, True, True),
+ "decoder_inject_noise": (False, False, False, False),
+ "downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
+ "upsample_residual": (True, True, True),
+ "upsample_factor": (2, 2, 2),
+ "timestep_conditioning": True,
+ "patch_size": 4,
+ "patch_size_t": 1,
+ "resnet_norm_eps": 1e-6,
+ "scaling_factor": 1.0,
+ "encoder_causal": True,
+ "decoder_causal": False,
+ "spatial_compression_ratio": 32,
+ "temporal_compression_ratio": 8,
+ }
+ VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
+ elif version in ["0.9.7"]:
config = {
"in_channels": 3,
"out_channels": 3,
@@ -275,12 +359,42 @@ def get_vae_config(version: str) -> Dict[str, Any]:
return config
+def get_spatial_latent_upsampler_config(version: str) -> Dict[str, Any]:
+ if version == "0.9.7":
+ config = {
+ "in_channels": 128,
+ "mid_channels": 512,
+ "num_blocks_per_stage": 4,
+ "dims": 3,
+ "spatial_upsample": True,
+ "temporal_upsample": False,
+ }
+ elif version == "0.9.8":
+ config = {
+ "in_channels": 128,
+ "mid_channels": 512,
+ "num_blocks_per_stage": 4,
+ "dims": 3,
+ "spatial_upsample": True,
+ "temporal_upsample": False,
+ }
+ else:
+ raise ValueError(f"Unsupported version: {version}")
+ return config
+
+
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
+ parser.add_argument(
+ "--spatial_latent_upsampler_path",
+ type=str,
+ default=None,
+ help="Path to original spatial latent upsampler checkpoint",
+ )
parser.add_argument(
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
)
@@ -294,7 +408,11 @@ def get_args():
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
parser.add_argument(
- "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
+ "--version",
+ type=str,
+ default="0.9.0",
+ choices=["0.9.0", "0.9.1", "0.9.5", "0.9.7", "0.9.8"],
+ help="Version of the LTX model",
)
return parser.parse_args()
@@ -320,11 +438,9 @@ def get_args():
variant = VARIANT_MAPPING[args.dtype]
output_path = Path(args.output_path)
- if args.save_pipeline:
- assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
-
if args.transformer_ckpt_path is not None:
- transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
+ config = get_transformer_config(args.version)
+ transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, config, dtype)
if not args.save_pipeline:
transformer.save_pretrained(
output_path / "transformer", safe_serialization=True, max_shard_size="5GB", variant=variant
@@ -336,6 +452,16 @@ def get_args():
if not args.save_pipeline:
vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size="5GB", variant=variant)
+ if args.spatial_latent_upsampler_path is not None:
+ config = get_spatial_latent_upsampler_config(args.version)
+ latent_upsampler: LTXLatentUpsamplerModel = convert_spatial_latent_upsampler(
+ args.spatial_latent_upsampler_path, config, dtype
+ )
+ if not args.save_pipeline:
+ latent_upsampler.save_pretrained(
+ output_path / "latent_upsampler", safe_serialization=True, max_shard_size="5GB", variant=variant
+ )
+
if args.save_pipeline:
text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
@@ -348,7 +474,7 @@ def get_args():
for param in text_encoder.parameters():
param.data = param.data.contiguous()
- if args.version == "0.9.5":
+ if args.version in ["0.9.5", "0.9.7"]:
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
else:
scheduler = FlowMatchEulerDiscreteScheduler(
@@ -360,12 +486,40 @@ def get_args():
shift_terminal=0.1,
)
- pipe = LTXPipeline(
- scheduler=scheduler,
- vae=vae,
- text_encoder=text_encoder,
- tokenizer=tokenizer,
- transformer=transformer,
- )
-
- pipe.save_pretrained(args.output_path, safe_serialization=True, variant=variant, max_shard_size="5GB")
+ if args.version in ["0.9.0", "0.9.1", "0.9.5"]:
+ pipe = LTXPipeline(
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ )
+ pipe.save_pretrained(
+ output_path.as_posix(), safe_serialization=True, variant=variant, max_shard_size="5GB"
+ )
+ elif args.version in ["0.9.7"]:
+ pipe = LTXConditionPipeline(
+ scheduler=scheduler,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ )
+ pipe_upsample = LTXLatentUpsamplePipeline(
+ vae=vae,
+ latent_upsampler=latent_upsampler,
+ )
+ pipe.save_pretrained(
+ (output_path / "ltx_pipeline").as_posix(),
+ safe_serialization=True,
+ variant=variant,
+ max_shard_size="5GB",
+ )
+ pipe_upsample.save_pretrained(
+ (output_path / "ltx_upsample_pipeline").as_posix(),
+ safe_serialization=True,
+ variant=variant,
+ max_shard_size="5GB",
+ )
+ else:
+ raise ValueError(f"Unsupported version: {args.version}")
diff --git a/scripts/convert_mochi_to_diffusers.py b/scripts/convert_mochi_to_diffusers.py
index 9727deeb6b0c..64e4f69eac17 100644
--- a/scripts/convert_mochi_to_diffusers.py
+++ b/scripts/convert_mochi_to_diffusers.py
@@ -168,28 +168,28 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
# Convert block_in (MochiMidBlock3D)
for i in range(3): # layers_per_block[-1] = 3
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.0.weight"
+ f"blocks.0.{i + 1}.stack.0.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.0.bias"
+ f"blocks.0.{i + 1}.stack.0.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.2.weight"
+ f"blocks.0.{i + 1}.stack.2.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.2.bias"
+ f"blocks.0.{i + 1}.stack.2.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.3.weight"
+ f"blocks.0.{i + 1}.stack.3.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.3.bias"
+ f"blocks.0.{i + 1}.stack.3.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.5.weight"
+ f"blocks.0.{i + 1}.stack.5.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
- f"blocks.0.{i+1}.stack.5.bias"
+ f"blocks.0.{i + 1}.stack.5.bias"
)
# Convert up_blocks (MochiUpBlock3D)
@@ -197,33 +197,35 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
for block in range(3):
for i in range(down_block_layers[block]):
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.0.weight"
+ f"blocks.{block + 1}.blocks.{i}.stack.0.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.0.bias"
+ f"blocks.{block + 1}.blocks.{i}.stack.0.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.2.weight"
+ f"blocks.{block + 1}.blocks.{i}.stack.2.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.2.bias"
+ f"blocks.{block + 1}.blocks.{i}.stack.2.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.3.weight"
+ f"blocks.{block + 1}.blocks.{i}.stack.3.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.3.bias"
+ f"blocks.{block + 1}.blocks.{i}.stack.3.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.5.weight"
+ f"blocks.{block + 1}.blocks.{i}.stack.5.weight"
)
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
- f"blocks.{block+1}.blocks.{i}.stack.5.bias"
+ f"blocks.{block + 1}.blocks.{i}.stack.5.bias"
)
new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = decoder_state_dict.pop(
- f"blocks.{block+1}.proj.weight"
+ f"blocks.{block + 1}.proj.weight"
+ )
+ new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(
+ f"blocks.{block + 1}.proj.bias"
)
- new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(f"blocks.{block+1}.proj.bias")
# Convert block_out (MochiMidBlock3D)
for i in range(3): # layers_per_block[0] = 3
@@ -267,133 +269,133 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
# Convert block_in (MochiMidBlock3D)
for i in range(3): # layers_per_block[0] = 3
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.0.weight"
+ f"layers.{i + 1}.stack.0.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.0.bias"
+ f"layers.{i + 1}.stack.0.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.2.weight"
+ f"layers.{i + 1}.stack.2.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.2.bias"
+ f"layers.{i + 1}.stack.2.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.3.weight"
+ f"layers.{i + 1}.stack.3.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.3.bias"
+ f"layers.{i + 1}.stack.3.bias"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.5.weight"
+ f"layers.{i + 1}.stack.5.weight"
)
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
- f"layers.{i+1}.stack.5.bias"
+ f"layers.{i + 1}.stack.5.bias"
)
# Convert down_blocks (MochiDownBlock3D)
down_block_layers = [3, 4, 6] # layers_per_block[1], layers_per_block[2], layers_per_block[3]
for block in range(3):
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.weight"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.0.weight"
+ f"layers.{block + 4}.layers.0.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.bias"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.0.bias"
+ f"layers.{block + 4}.layers.0.bias"
)
for i in range(down_block_layers[block]):
# Convert resnets
- new_state_dict[
- f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"
- ] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.0.weight")
+ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = (
+ encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.0.weight")
+ )
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.stack.0.bias"
+ f"layers.{block + 4}.layers.{i + 1}.stack.0.bias"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.stack.2.weight"
+ f"layers.{block + 4}.layers.{i + 1}.stack.2.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.stack.2.bias"
+ f"layers.{block + 4}.layers.{i + 1}.stack.2.bias"
+ )
+ new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = (
+ encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.3.weight")
)
- new_state_dict[
- f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"
- ] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.3.weight")
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.stack.3.bias"
+ f"layers.{block + 4}.layers.{i + 1}.stack.3.bias"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.stack.5.weight"
+ f"layers.{block + 4}.layers.{i + 1}.stack.5.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.stack.5.bias"
+ f"layers.{block + 4}.layers.{i + 1}.stack.5.bias"
)
# Convert attentions
- qkv_weight = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.attn_block.attn.qkv.weight")
+ qkv_weight = encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.qkv.weight")
q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_q.weight"] = q
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_k.weight"] = k
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_v.weight"] = v
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.weight"
+ f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.bias"
+ f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.bias"
)
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.attn_block.norm.weight"
+ f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.weight"
)
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{block+4}.layers.{i+1}.attn_block.norm.bias"
+ f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.bias"
)
# Convert block_out (MochiMidBlock3D)
for i in range(3): # layers_per_block[-1] = 3
# Convert resnets
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.0.weight"
+ f"layers.{i + 7}.stack.0.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.0.bias"
+ f"layers.{i + 7}.stack.0.bias"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.2.weight"
+ f"layers.{i + 7}.stack.2.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.2.bias"
+ f"layers.{i + 7}.stack.2.bias"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.3.weight"
+ f"layers.{i + 7}.stack.3.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.3.bias"
+ f"layers.{i + 7}.stack.3.bias"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.5.weight"
+ f"layers.{i + 7}.stack.5.weight"
)
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
- f"layers.{i+7}.stack.5.bias"
+ f"layers.{i + 7}.stack.5.bias"
)
# Convert attentions
- qkv_weight = encoder_state_dict.pop(f"layers.{i+7}.attn_block.attn.qkv.weight")
+ qkv_weight = encoder_state_dict.pop(f"layers.{i + 7}.attn_block.attn.qkv.weight")
q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[f"{prefix}block_out.attentions.{i}.to_q.weight"] = q
new_state_dict[f"{prefix}block_out.attentions.{i}.to_k.weight"] = k
new_state_dict[f"{prefix}block_out.attentions.{i}.to_v.weight"] = v
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
- f"layers.{i+7}.attn_block.attn.out.weight"
+ f"layers.{i + 7}.attn_block.attn.out.weight"
)
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
- f"layers.{i+7}.attn_block.attn.out.bias"
+ f"layers.{i + 7}.attn_block.attn.out.bias"
)
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
- f"layers.{i+7}.attn_block.norm.weight"
+ f"layers.{i + 7}.attn_block.norm.weight"
)
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
- f"layers.{i+7}.attn_block.norm.bias"
+ f"layers.{i + 7}.attn_block.norm.bias"
)
# Convert output layers
diff --git a/scripts/convert_original_audioldm2_to_diffusers.py b/scripts/convert_original_audioldm2_to_diffusers.py
index 1dc7d739ea76..2c0695ce5595 100644
--- a/scripts/convert_original_audioldm2_to_diffusers.py
+++ b/scripts/convert_original_audioldm2_to_diffusers.py
@@ -662,7 +662,7 @@ def convert_open_clap_checkpoint(checkpoint):
# replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
- key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
+ key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
diff --git a/scripts/convert_original_audioldm_to_diffusers.py b/scripts/convert_original_audioldm_to_diffusers.py
index 4f8e4f8f9f80..44183f1aea29 100644
--- a/scripts/convert_original_audioldm_to_diffusers.py
+++ b/scripts/convert_original_audioldm_to_diffusers.py
@@ -636,7 +636,7 @@ def convert_open_clap_checkpoint(checkpoint):
# replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
- key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
+ key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
diff --git a/scripts/convert_original_musicldm_to_diffusers.py b/scripts/convert_original_musicldm_to_diffusers.py
index 61e5d16eea9e..00836fde2592 100644
--- a/scripts/convert_original_musicldm_to_diffusers.py
+++ b/scripts/convert_original_musicldm_to_diffusers.py
@@ -642,7 +642,7 @@ def convert_open_clap_checkpoint(checkpoint):
# replace sequential layers with list
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
- key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
+ key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
elif re.match(text_projection_pattern, key):
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
diff --git a/scripts/convert_ovis_image_to_diffusers.py b/scripts/convert_ovis_image_to_diffusers.py
new file mode 100644
index 000000000000..0d3d9cd44bf6
--- /dev/null
+++ b/scripts/convert_ovis_image_to_diffusers.py
@@ -0,0 +1,263 @@
+import argparse
+from contextlib import nullcontext
+
+import safetensors.torch
+import torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download
+
+from diffusers import OvisImageTransformer2DModel
+from diffusers.utils.import_utils import is_accelerate_available
+
+
+"""
+# Transformer
+
+python scripts/convert_ovis_image_to_diffusers.py \
+--original_state_dict_repo_id "AIDC-AI/Ovis-Image-7B" \
+--filename "ovis_image.safetensors"
+--output_path "ovis-image" \
+--transformer
+"""
+
+
+CTX = init_empty_weights if is_accelerate_available() else nullcontext
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
+parser.add_argument("--filename", default="ovis_image.safetensors", type=str)
+parser.add_argument("--checkpoint_path", default=None, type=str)
+parser.add_argument("--in_channels", type=int, default=64)
+parser.add_argument("--out_channels", type=int, default=None)
+parser.add_argument("--transformer", action="store_true")
+parser.add_argument("--output_path", type=str)
+parser.add_argument("--dtype", type=str, default="bf16")
+
+args = parser.parse_args()
+dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32
+
+
+def load_original_checkpoint(args):
+ if args.original_state_dict_repo_id is not None:
+ ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
+ elif args.checkpoint_path is not None:
+ ckpt_path = args.checkpoint_path
+ else:
+ raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
+
+ original_state_dict = safetensors.torch.load_file(ckpt_path)
+ return original_state_dict
+
+
+# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
+# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
+def swap_scale_shift(weight):
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ return new_weight
+
+
+def convert_ovis_image_transformer_checkpoint_to_diffusers(
+ original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0
+):
+ converted_state_dict = {}
+
+ ## time_text_embed.timestep_embedder <- time_in
+ converted_state_dict["timestep_embedder.linear_1.weight"] = original_state_dict.pop("time_in.in_layer.weight")
+ converted_state_dict["timestep_embedder.linear_1.bias"] = original_state_dict.pop("time_in.in_layer.bias")
+ converted_state_dict["timestep_embedder.linear_2.weight"] = original_state_dict.pop("time_in.out_layer.weight")
+ converted_state_dict["timestep_embedder.linear_2.bias"] = original_state_dict.pop("time_in.out_layer.bias")
+
+ # context_embedder
+ converted_state_dict["context_embedder_norm.weight"] = original_state_dict.pop("semantic_txt_norm.weight")
+ converted_state_dict["context_embedder.weight"] = original_state_dict.pop("semantic_txt_in.weight")
+ converted_state_dict["context_embedder.bias"] = original_state_dict.pop("semantic_txt_in.bias")
+
+ # x_embedder
+ converted_state_dict["x_embedder.weight"] = original_state_dict.pop("img_in.weight")
+ converted_state_dict["x_embedder.bias"] = original_state_dict.pop("img_in.bias")
+
+ # double transformer blocks
+ for i in range(num_layers):
+ block_prefix = f"transformer_blocks.{i}."
+ # norms.
+ ## norm1
+ converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mod.lin.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mod.lin.bias"
+ )
+ ## norm1_context
+ converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mod.lin.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mod.lin.bias"
+ )
+ # Q, K, V
+ sample_q, sample_k, sample_v = torch.chunk(
+ original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0
+ )
+ context_q, context_k, context_v = torch.chunk(
+ original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
+ )
+ sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+ original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
+ )
+ context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+ original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
+ # qk_norm
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn.norm.query_norm.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn.norm.key_norm.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn.norm.query_norm.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn.norm.key_norm.weight"
+ )
+ # ff img_mlp
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = torch.cat(
+ [
+ original_state_dict.pop(f"double_blocks.{i}.img_mlp.up_proj.weight"),
+ original_state_dict.pop(f"double_blocks.{i}.img_mlp.gate_proj.weight"),
+ ],
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = torch.cat(
+ [
+ original_state_dict.pop(f"double_blocks.{i}.img_mlp.up_proj.bias"),
+ original_state_dict.pop(f"double_blocks.{i}.img_mlp.gate_proj.bias"),
+ ],
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.down_proj.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.down_proj.bias"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = torch.cat(
+ [
+ original_state_dict.pop(f"double_blocks.{i}.txt_mlp.up_proj.weight"),
+ original_state_dict.pop(f"double_blocks.{i}.txt_mlp.gate_proj.weight"),
+ ],
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = torch.cat(
+ [
+ original_state_dict.pop(f"double_blocks.{i}.txt_mlp.up_proj.bias"),
+ original_state_dict.pop(f"double_blocks.{i}.txt_mlp.gate_proj.bias"),
+ ],
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.down_proj.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.down_proj.bias"
+ )
+ # output projections.
+ converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn.proj.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn.proj.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn.proj.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn.proj.bias"
+ )
+
+ # single transformer blocks
+ for i in range(num_single_layers):
+ block_prefix = f"single_transformer_blocks.{i}."
+ # norm.linear <- single_blocks.0.modulation.lin
+ converted_state_dict[f"{block_prefix}norm.linear.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.modulation.lin.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm.linear.bias"] = original_state_dict.pop(
+ f"single_blocks.{i}.modulation.lin.bias"
+ )
+ # Q, K, V, mlp
+ mlp_hidden_dim = int(inner_dim * mlp_ratio)
+ split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim * 2)
+ q, k, v, mlp = torch.split(original_state_dict.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(
+ original_state_dict.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
+ converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
+ converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
+ # qk norm
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.norm.query_norm.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.norm.key_norm.weight"
+ )
+ # output projections.
+ converted_state_dict[f"{block_prefix}proj_out.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear2.weight"
+ )
+ converted_state_dict[f"{block_prefix}proj_out.bias"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear2.bias"
+ )
+
+ converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
+ converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
+ converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+ original_state_dict.pop("final_layer.adaLN_modulation.1.weight")
+ )
+ converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+ original_state_dict.pop("final_layer.adaLN_modulation.1.bias")
+ )
+
+ return converted_state_dict
+
+
+def main(args):
+ original_ckpt = load_original_checkpoint(args)
+
+ if args.transformer:
+ num_layers = 6
+ num_single_layers = 27
+ inner_dim = 3072
+ mlp_ratio = 4.0
+
+ converted_transformer_state_dict = convert_ovis_image_transformer_checkpoint_to_diffusers(
+ original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
+ )
+ transformer = OvisImageTransformer2DModel(in_channels=args.in_channels, out_channels=args.out_channels)
+ transformer.load_state_dict(converted_transformer_state_dict, strict=True)
+
+ print("Saving Ovis-Image Transformer in Diffusers format.")
+ transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")
+
+
+if __name__ == "__main__":
+ main(args)
diff --git a/scripts/convert_prx_to_diffusers.py b/scripts/convert_prx_to_diffusers.py
new file mode 100644
index 000000000000..d9bde2f34d56
--- /dev/null
+++ b/scripts/convert_prx_to_diffusers.py
@@ -0,0 +1,345 @@
+#!/usr/bin/env python3
+"""
+Script to convert PRX checkpoint from original codebase to diffusers format.
+"""
+
+import argparse
+import json
+import os
+import sys
+from dataclasses import asdict, dataclass
+from typing import Dict, Tuple
+
+import torch
+from safetensors.torch import save_file
+
+from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
+from diffusers.pipelines.prx import PRXPipeline
+
+
+DEFAULT_RESOLUTION = 512
+
+
+@dataclass(frozen=True)
+class PRXBase:
+ context_in_dim: int = 2304
+ hidden_size: int = 1792
+ mlp_ratio: float = 3.5
+ num_heads: int = 28
+ depth: int = 16
+ axes_dim: Tuple[int, int] = (32, 32)
+ theta: int = 10_000
+ time_factor: float = 1000.0
+ time_max_period: int = 10_000
+
+
+@dataclass(frozen=True)
+class PRXFlux(PRXBase):
+ in_channels: int = 16
+ patch_size: int = 2
+
+
+@dataclass(frozen=True)
+class PRXDCAE(PRXBase):
+ in_channels: int = 32
+ patch_size: int = 1
+
+
+def build_config(vae_type: str) -> Tuple[dict, int]:
+ if vae_type == "flux":
+ cfg = PRXFlux()
+ elif vae_type == "dc-ae":
+ cfg = PRXDCAE()
+ else:
+ raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'")
+
+ config_dict = asdict(cfg)
+ config_dict["axes_dim"] = list(config_dict["axes_dim"]) # type: ignore[index]
+ return config_dict
+
+
+def create_parameter_mapping(depth: int) -> dict:
+ """Create mapping from old parameter names to new diffusers names."""
+
+ # Key mappings for structural changes
+ mapping = {}
+
+ # Map old structure (layers in PRXBlock) to new structure (layers in PRXAttention)
+ for i in range(depth):
+ # QKV projections moved to attention module
+ mapping[f"blocks.{i}.img_qkv_proj.weight"] = f"blocks.{i}.attention.img_qkv_proj.weight"
+ mapping[f"blocks.{i}.txt_kv_proj.weight"] = f"blocks.{i}.attention.txt_kv_proj.weight"
+
+ # QK norm moved to attention module and renamed to match Attention's qk_norm structure
+ mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.attention.norm_q.weight"
+ mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.attention.norm_k.weight"
+ mapping[f"blocks.{i}.qk_norm.query_norm.weight"] = f"blocks.{i}.attention.norm_q.weight"
+ mapping[f"blocks.{i}.qk_norm.key_norm.weight"] = f"blocks.{i}.attention.norm_k.weight"
+
+ # K norm for text tokens moved to attention module
+ mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.attention.norm_added_k.weight"
+ mapping[f"blocks.{i}.k_norm.weight"] = f"blocks.{i}.attention.norm_added_k.weight"
+
+ # Attention output projection
+ mapping[f"blocks.{i}.attn_out.weight"] = f"blocks.{i}.attention.to_out.0.weight"
+
+ return mapping
+
+
+def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth: int) -> Dict[str, torch.Tensor]:
+ """Convert old checkpoint parameters to new diffusers format."""
+
+ print("Converting checkpoint parameters...")
+
+ mapping = create_parameter_mapping(depth)
+ converted_state_dict = {}
+
+ for key, value in old_state_dict.items():
+ new_key = key
+
+ # Apply specific mappings if needed
+ if key in mapping:
+ new_key = mapping[key]
+ print(f" Mapped: {key} -> {new_key}")
+
+ converted_state_dict[new_key] = value
+
+ print(f"✓ Converted {len(converted_state_dict)} parameters")
+ return converted_state_dict
+
+
+def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PRXTransformer2DModel:
+ """Create and load PRXTransformer2DModel from old checkpoint."""
+
+ print(f"Loading checkpoint from: {checkpoint_path}")
+
+ # Load old checkpoint
+ if not os.path.exists(checkpoint_path):
+ raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
+
+ old_checkpoint = torch.load(checkpoint_path, map_location="cpu")
+
+ # Handle different checkpoint formats
+ if isinstance(old_checkpoint, dict):
+ if "model" in old_checkpoint:
+ state_dict = old_checkpoint["model"]
+ elif "state_dict" in old_checkpoint:
+ state_dict = old_checkpoint["state_dict"]
+ else:
+ state_dict = old_checkpoint
+ else:
+ state_dict = old_checkpoint
+
+ print(f"✓ Loaded checkpoint with {len(state_dict)} parameters")
+
+ # Convert parameter names if needed
+ model_depth = int(config.get("depth", 16))
+ converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth)
+
+ # Create transformer with config
+ print("Creating PRXTransformer2DModel...")
+ transformer = PRXTransformer2DModel(**config)
+
+ # Load state dict
+ print("Loading converted parameters...")
+ missing_keys, unexpected_keys = transformer.load_state_dict(converted_state_dict, strict=False)
+
+ if missing_keys:
+ print(f"⚠ Missing keys: {missing_keys}")
+ if unexpected_keys:
+ print(f"⚠ Unexpected keys: {unexpected_keys}")
+
+ if not missing_keys and not unexpected_keys:
+ print("✓ All parameters loaded successfully!")
+
+ return transformer
+
+
+def create_scheduler_config(output_path: str, shift: float):
+ """Create FlowMatchEulerDiscreteScheduler config."""
+
+ scheduler_config = {"_class_name": "FlowMatchEulerDiscreteScheduler", "num_train_timesteps": 1000, "shift": shift}
+
+ scheduler_path = os.path.join(output_path, "scheduler")
+ os.makedirs(scheduler_path, exist_ok=True)
+
+ with open(os.path.join(scheduler_path, "scheduler_config.json"), "w") as f:
+ json.dump(scheduler_config, f, indent=2)
+
+ print("✓ Created scheduler config")
+
+
+def download_and_save_vae(vae_type: str, output_path: str):
+ """Download and save VAE to local directory."""
+ from diffusers import AutoencoderDC, AutoencoderKL
+
+ vae_path = os.path.join(output_path, "vae")
+ os.makedirs(vae_path, exist_ok=True)
+
+ if vae_type == "flux":
+ print("Downloading FLUX VAE from black-forest-labs/FLUX.1-dev...")
+ vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
+ else: # dc-ae
+ print("Downloading DC-AE VAE from mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers...")
+ vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")
+
+ vae.save_pretrained(vae_path)
+ print(f"✓ Saved VAE to {vae_path}")
+
+
+def download_and_save_text_encoder(output_path: str):
+ """Download and save T5Gemma text encoder and tokenizer."""
+ from transformers import GemmaTokenizerFast
+ from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel
+
+ text_encoder_path = os.path.join(output_path, "text_encoder")
+ tokenizer_path = os.path.join(output_path, "tokenizer")
+ os.makedirs(text_encoder_path, exist_ok=True)
+ os.makedirs(tokenizer_path, exist_ok=True)
+
+ print("Downloading T5Gemma model from google/t5gemma-2b-2b-ul2...")
+ t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2")
+
+ # Extract and save only the encoder
+ t5gemma_encoder = t5gemma_model.encoder
+ t5gemma_encoder.save_pretrained(text_encoder_path)
+ print(f"✓ Saved T5GemmaEncoder to {text_encoder_path}")
+
+ print("Downloading tokenizer from google/t5gemma-2b-2b-ul2...")
+ tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
+ tokenizer.model_max_length = 256
+ tokenizer.save_pretrained(tokenizer_path)
+ print(f"✓ Saved tokenizer to {tokenizer_path}")
+
+
+def create_model_index(vae_type: str, default_image_size: int, output_path: str):
+ """Create model_index.json for the pipeline."""
+
+ if vae_type == "flux":
+ vae_class = "AutoencoderKL"
+ else: # dc-ae
+ vae_class = "AutoencoderDC"
+
+ model_index = {
+ "_class_name": "PRXPipeline",
+ "_diffusers_version": "0.31.0.dev0",
+ "_name_or_path": os.path.basename(output_path),
+ "default_sample_size": default_image_size,
+ "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
+ "text_encoder": ["prx", "T5GemmaEncoder"],
+ "tokenizer": ["transformers", "GemmaTokenizerFast"],
+ "transformer": ["diffusers", "PRXTransformer2DModel"],
+ "vae": ["diffusers", vae_class],
+ }
+
+ model_index_path = os.path.join(output_path, "model_index.json")
+ with open(model_index_path, "w") as f:
+ json.dump(model_index, f, indent=2)
+
+
+def main(args):
+ # Validate inputs
+ if not os.path.exists(args.checkpoint_path):
+ raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint_path}")
+
+ config = build_config(args.vae_type)
+
+ # Create output directory
+ os.makedirs(args.output_path, exist_ok=True)
+ print(f"✓ Output directory: {args.output_path}")
+
+ # Create transformer from checkpoint
+ transformer = create_transformer_from_checkpoint(args.checkpoint_path, config)
+
+ # Save transformer
+ transformer_path = os.path.join(args.output_path, "transformer")
+ os.makedirs(transformer_path, exist_ok=True)
+
+ # Save config
+ with open(os.path.join(transformer_path, "config.json"), "w") as f:
+ json.dump(config, f, indent=2)
+
+ # Save model weights as safetensors
+ state_dict = transformer.state_dict()
+ save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors"))
+ print(f"✓ Saved transformer to {transformer_path}")
+
+ # Create scheduler config
+ create_scheduler_config(args.output_path, args.shift)
+
+ download_and_save_vae(args.vae_type, args.output_path)
+ download_and_save_text_encoder(args.output_path)
+
+ # Create model_index.json
+ create_model_index(args.vae_type, args.resolution, args.output_path)
+
+ # Verify the pipeline can be loaded
+ try:
+ pipeline = PRXPipeline.from_pretrained(args.output_path)
+ print("Pipeline loaded successfully!")
+ print(f"Transformer: {type(pipeline.transformer).__name__}")
+ print(f"VAE: {type(pipeline.vae).__name__}")
+ print(f"Text Encoder: {type(pipeline.text_encoder).__name__}")
+ print(f"Scheduler: {type(pipeline.scheduler).__name__}")
+
+ # Display model info
+ num_params = sum(p.numel() for p in pipeline.transformer.parameters())
+ print(f"✓ Transformer parameters: {num_params:,}")
+
+ except Exception as e:
+ print(f"Pipeline verification failed: {e}")
+ return False
+
+ print("Conversion completed successfully!")
+ print(f"Converted pipeline saved to: {args.output_path}")
+ print(f"VAE type: {args.vae_type}")
+
+ return True
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Convert PRX checkpoint to diffusers format")
+
+ parser.add_argument(
+ "--checkpoint_path", type=str, required=True, help="Path to the original PRX checkpoint (.pth file )"
+ )
+
+ parser.add_argument(
+ "--output_path", type=str, required=True, help="Output directory for the converted diffusers pipeline"
+ )
+
+ parser.add_argument(
+ "--vae_type",
+ type=str,
+ choices=["flux", "dc-ae"],
+ required=True,
+ help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)",
+ )
+
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ choices=[256, 512, 1024],
+ default=DEFAULT_RESOLUTION,
+ help="Target resolution for the model (256, 512, or 1024). Affects the transformer's sample_size.",
+ )
+
+ parser.add_argument(
+ "--shift",
+ type=float,
+ default=3.0,
+ help="Shift for the scheduler",
+ )
+
+ args = parser.parse_args()
+
+ try:
+ success = main(args)
+ if not success:
+ sys.exit(1)
+ except Exception as e:
+ print(f"Conversion failed: {e}")
+ import traceback
+
+ traceback.print_exc()
+ sys.exit(1)
diff --git a/scripts/convert_sana_controlnet_to_diffusers.py b/scripts/convert_sana_controlnet_to_diffusers.py
new file mode 100644
index 000000000000..f7fcd7252576
--- /dev/null
+++ b/scripts/convert_sana_controlnet_to_diffusers.py
@@ -0,0 +1,216 @@
+#!/usr/bin/env python
+from __future__ import annotations
+
+import argparse
+from contextlib import nullcontext
+
+import torch
+from accelerate import init_empty_weights
+
+from diffusers import (
+ SanaControlNetModel,
+)
+from diffusers.models.model_loading_utils import load_model_dict_into_meta
+from diffusers.utils.import_utils import is_accelerate_available
+
+
+CTX = init_empty_weights if is_accelerate_available else nullcontext
+
+
+def main(args):
+ file_path = args.orig_ckpt_path
+
+ all_state_dict = torch.load(file_path, weights_only=True)
+ state_dict = all_state_dict.pop("state_dict")
+ converted_state_dict = {}
+
+ # Patch embeddings.
+ converted_state_dict["patch_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
+ converted_state_dict["patch_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")
+
+ # Caption projection.
+ converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
+ converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
+ converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
+ converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
+
+ # AdaLN-single LN
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
+ "t_embedder.mlp.0.weight"
+ )
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
+ "t_embedder.mlp.2.weight"
+ )
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
+
+ # Shared norm.
+ converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
+ converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias")
+
+ # y norm
+ converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
+
+ # Positional embedding interpolation scale.
+ interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
+
+ # ControlNet Input Projection.
+ converted_state_dict["input_block.weight"] = state_dict.pop("controlnet.0.before_proj.weight")
+ converted_state_dict["input_block.bias"] = state_dict.pop("controlnet.0.before_proj.bias")
+
+ for depth in range(7):
+ # Transformer blocks.
+ converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
+ f"controlnet.{depth}.copied_block.scale_shift_table"
+ )
+
+ # Linear Attention is all you need 🤘
+ # Self attention.
+ q, k, v = torch.chunk(state_dict.pop(f"controlnet.{depth}.copied_block.attn.qkv.weight"), 3, dim=0)
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
+ # Projection.
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
+ f"controlnet.{depth}.copied_block.attn.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
+ f"controlnet.{depth}.copied_block.attn.proj.bias"
+ )
+
+ # Feed-forward.
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
+ f"controlnet.{depth}.copied_block.mlp.inverted_conv.conv.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = state_dict.pop(
+ f"controlnet.{depth}.copied_block.mlp.inverted_conv.conv.bias"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = state_dict.pop(
+ f"controlnet.{depth}.copied_block.mlp.depth_conv.conv.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = state_dict.pop(
+ f"controlnet.{depth}.copied_block.mlp.depth_conv.conv.bias"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = state_dict.pop(
+ f"controlnet.{depth}.copied_block.mlp.point_conv.conv.weight"
+ )
+
+ # Cross-attention.
+ q = state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.q_linear.weight")
+ q_bias = state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.q_linear.bias")
+ k, v = torch.chunk(state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.kv_linear.weight"), 2, dim=0)
+ k_bias, v_bias = torch.chunk(
+ state_dict.pop(f"controlnet.{depth}.copied_block.cross_attn.kv_linear.bias"), 2, dim=0
+ )
+
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
+
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
+ f"controlnet.{depth}.copied_block.cross_attn.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
+ f"controlnet.{depth}.copied_block.cross_attn.proj.bias"
+ )
+
+ # ControlNet After Projection
+ converted_state_dict[f"controlnet_blocks.{depth}.weight"] = state_dict.pop(
+ f"controlnet.{depth}.after_proj.weight"
+ )
+ converted_state_dict[f"controlnet_blocks.{depth}.bias"] = state_dict.pop(f"controlnet.{depth}.after_proj.bias")
+
+ # ControlNet
+ with CTX():
+ controlnet = SanaControlNetModel(
+ num_attention_heads=model_kwargs[args.model_type]["num_attention_heads"],
+ attention_head_dim=model_kwargs[args.model_type]["attention_head_dim"],
+ num_layers=model_kwargs[args.model_type]["num_layers"],
+ num_cross_attention_heads=model_kwargs[args.model_type]["num_cross_attention_heads"],
+ cross_attention_head_dim=model_kwargs[args.model_type]["cross_attention_head_dim"],
+ cross_attention_dim=model_kwargs[args.model_type]["cross_attention_dim"],
+ caption_channels=2304,
+ sample_size=args.image_size // 32,
+ interpolation_scale=interpolation_scale[args.image_size],
+ )
+
+ if is_accelerate_available():
+ load_model_dict_into_meta(controlnet, converted_state_dict)
+ else:
+ controlnet.load_state_dict(converted_state_dict, strict=True, assign=True)
+
+ num_model_params = sum(p.numel() for p in controlnet.parameters())
+ print(f"Total number of controlnet parameters: {num_model_params}")
+
+ controlnet = controlnet.to(weight_dtype)
+ controlnet.load_state_dict(converted_state_dict, strict=True)
+
+ print(f"Saving Sana ControlNet in Diffusers format in {args.dump_path}.")
+ controlnet.save_pretrained(args.dump_path)
+
+
+DTYPE_MAPPING = {
+ "fp32": torch.float32,
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+}
+
+VARIANT_MAPPING = {
+ "fp32": None,
+ "fp16": "fp16",
+ "bf16": "bf16",
+}
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--orig_ckpt_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
+ )
+ parser.add_argument(
+ "--image_size",
+ default=1024,
+ type=int,
+ choices=[512, 1024, 2048, 4096],
+ required=False,
+ help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
+ )
+ parser.add_argument(
+ "--model_type",
+ default="SanaMS_1600M_P1_ControlNet_D7",
+ type=str,
+ choices=["SanaMS_1600M_P1_ControlNet_D7", "SanaMS_600M_P1_ControlNet_D7"],
+ )
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
+ parser.add_argument("--dtype", default="fp16", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
+
+ args = parser.parse_args()
+
+ model_kwargs = {
+ "SanaMS_1600M_P1_ControlNet_D7": {
+ "num_attention_heads": 70,
+ "attention_head_dim": 32,
+ "num_cross_attention_heads": 20,
+ "cross_attention_head_dim": 112,
+ "cross_attention_dim": 2240,
+ "num_layers": 7,
+ },
+ "SanaMS_600M_P1_ControlNet_D7": {
+ "num_attention_heads": 36,
+ "attention_head_dim": 32,
+ "num_cross_attention_heads": 16,
+ "cross_attention_head_dim": 72,
+ "cross_attention_dim": 1152,
+ "num_layers": 7,
+ },
+ }
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ weight_dtype = DTYPE_MAPPING[args.dtype]
+ variant = VARIANT_MAPPING[args.dtype]
+
+ main(args)
diff --git a/scripts/convert_sana_to_diffusers.py b/scripts/convert_sana_to_diffusers.py
index 1c40072177c6..833e64ba75e7 100644
--- a/scripts/convert_sana_to_diffusers.py
+++ b/scripts/convert_sana_to_diffusers.py
@@ -20,7 +20,7 @@
SanaTransformer2DModel,
SCMScheduler,
)
-from diffusers.models.modeling_utils import load_model_dict_into_meta
+from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.utils.import_utils import is_accelerate_available
@@ -394,7 +394,7 @@ def main(args):
help="Scheduler type to use. Use 'scm' for Sana Sprint models.",
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
- parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipelien elemets in one.")
+ parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
args = parser.parse_args()
diff --git a/scripts/convert_sana_video_to_diffusers.py b/scripts/convert_sana_video_to_diffusers.py
new file mode 100644
index 000000000000..a939a06cbd46
--- /dev/null
+++ b/scripts/convert_sana_video_to_diffusers.py
@@ -0,0 +1,327 @@
+#!/usr/bin/env python
+from __future__ import annotations
+
+import argparse
+import os
+from contextlib import nullcontext
+
+import torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download, snapshot_download
+from termcolor import colored
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from diffusers import (
+ AutoencoderKLWan,
+ DPMSolverMultistepScheduler,
+ FlowMatchEulerDiscreteScheduler,
+ SanaVideoPipeline,
+ SanaVideoTransformer3DModel,
+ UniPCMultistepScheduler,
+)
+from diffusers.utils.import_utils import is_accelerate_available
+
+
+CTX = init_empty_weights if is_accelerate_available else nullcontext
+
+ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"]
+# https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py
+
+
+def main(args):
+ cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub")
+
+ if args.orig_ckpt_path is None or args.orig_ckpt_path in ckpt_ids:
+ ckpt_id = args.orig_ckpt_path or ckpt_ids[0]
+ snapshot_download(
+ repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
+ cache_dir=cache_dir_path,
+ repo_type="model",
+ )
+ file_path = hf_hub_download(
+ repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
+ filename=f"{'/'.join(ckpt_id.split('/')[2:])}",
+ cache_dir=cache_dir_path,
+ repo_type="model",
+ )
+ else:
+ file_path = args.orig_ckpt_path
+
+ print(colored(f"Loading checkpoint from {file_path}", "green", attrs=["bold"]))
+ all_state_dict = torch.load(file_path, weights_only=True)
+ state_dict = all_state_dict.pop("state_dict")
+ converted_state_dict = {}
+
+ # Patch embeddings.
+ converted_state_dict["patch_embedding.weight"] = state_dict.pop("x_embedder.proj.weight")
+ converted_state_dict["patch_embedding.bias"] = state_dict.pop("x_embedder.proj.bias")
+
+ # Caption projection.
+ converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
+ converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
+ converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
+ converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
+
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
+ "t_embedder.mlp.0.weight"
+ )
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
+ "t_embedder.mlp.2.weight"
+ )
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
+
+ # Shared norm.
+ converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
+ converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias")
+
+ # y norm
+ converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
+
+ # scheduler
+ flow_shift = 8.0
+ if args.task == "i2v":
+ assert args.scheduler_type == "flow-euler", "Scheduler type must be flow-euler for i2v task."
+
+ # model config
+ layer_num = 20
+ # Positional embedding interpolation scale.
+ qk_norm = True
+
+ # sample size
+ if args.video_size == 480:
+ sample_size = 30 # Wan-VAE: 8xp2 downsample factor
+ patch_size = (1, 2, 2)
+ elif args.video_size == 720:
+ sample_size = 22 # Wan-VAE: 32xp1 downsample factor
+ patch_size = (1, 1, 1)
+ else:
+ raise ValueError(f"Video size {args.video_size} is not supported.")
+
+ for depth in range(layer_num):
+ # Transformer blocks.
+ converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
+ f"blocks.{depth}.scale_shift_table"
+ )
+
+ # Linear Attention is all you need 🤘
+ # Self attention.
+ q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
+ if qk_norm is not None:
+ # Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
+ f"blocks.{depth}.attn.q_norm.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
+ f"blocks.{depth}.attn.k_norm.weight"
+ )
+ # Projection.
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
+ f"blocks.{depth}.attn.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
+ f"blocks.{depth}.attn.proj.bias"
+ )
+
+ # Feed-forward.
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
+ f"blocks.{depth}.mlp.inverted_conv.conv.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = state_dict.pop(
+ f"blocks.{depth}.mlp.inverted_conv.conv.bias"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = state_dict.pop(
+ f"blocks.{depth}.mlp.depth_conv.conv.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = state_dict.pop(
+ f"blocks.{depth}.mlp.depth_conv.conv.bias"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = state_dict.pop(
+ f"blocks.{depth}.mlp.point_conv.conv.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_temp.weight"] = state_dict.pop(
+ f"blocks.{depth}.mlp.t_conv.weight"
+ )
+
+ # Cross-attention.
+ q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
+ q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
+ k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
+ k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)
+
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
+ if qk_norm is not None:
+ # Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
+ f"blocks.{depth}.cross_attn.q_norm.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
+ f"blocks.{depth}.cross_attn.k_norm.weight"
+ )
+
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
+ f"blocks.{depth}.cross_attn.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
+ f"blocks.{depth}.cross_attn.proj.bias"
+ )
+
+ # Final block.
+ converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
+ converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
+ converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")
+
+ # Transformer
+ with CTX():
+ transformer_kwargs = {
+ "in_channels": 16,
+ "out_channels": 16,
+ "num_attention_heads": 20,
+ "attention_head_dim": 112,
+ "num_layers": 20,
+ "num_cross_attention_heads": 20,
+ "cross_attention_head_dim": 112,
+ "cross_attention_dim": 2240,
+ "caption_channels": 2304,
+ "mlp_ratio": 3.0,
+ "attention_bias": False,
+ "sample_size": sample_size,
+ "patch_size": patch_size,
+ "norm_elementwise_affine": False,
+ "norm_eps": 1e-6,
+ "qk_norm": "rms_norm_across_heads",
+ "rope_max_seq_len": 1024,
+ }
+
+ transformer = SanaVideoTransformer3DModel(**transformer_kwargs)
+
+ transformer.load_state_dict(converted_state_dict, strict=True, assign=True)
+
+ try:
+ state_dict.pop("y_embedder.y_embedding")
+ state_dict.pop("pos_embed")
+ state_dict.pop("logvar_linear.weight")
+ state_dict.pop("logvar_linear.bias")
+ except KeyError:
+ print("y_embedder.y_embedding or pos_embed not found in the state_dict")
+
+ assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"
+
+ num_model_params = sum(p.numel() for p in transformer.parameters())
+ print(f"Total number of transformer parameters: {num_model_params}")
+
+ transformer = transformer.to(weight_dtype)
+
+ if not args.save_full_pipeline:
+ print(
+ colored(
+ f"Only saving transformer model of {args.model_type}. "
+ f"Set --save_full_pipeline to save the whole Pipeline",
+ "green",
+ attrs=["bold"],
+ )
+ )
+ transformer.save_pretrained(
+ os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB"
+ )
+ else:
+ print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
+ # VAE
+ vae = AutoencoderKLWan.from_pretrained(
+ "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
+ )
+
+ # Text Encoder
+ text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
+ tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
+ tokenizer.padding_side = "right"
+ text_encoder = AutoModelForCausalLM.from_pretrained(
+ text_encoder_model_path, torch_dtype=torch.bfloat16
+ ).get_decoder()
+
+ # Choose the appropriate pipeline and scheduler based on model type
+ # Original Sana scheduler
+ if args.scheduler_type == "flow-dpm_solver":
+ scheduler = DPMSolverMultistepScheduler(
+ flow_shift=flow_shift,
+ use_flow_sigmas=True,
+ prediction_type="flow_prediction",
+ )
+ elif args.scheduler_type == "flow-euler":
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
+ elif args.scheduler_type == "uni-pc":
+ scheduler = UniPCMultistepScheduler(
+ prediction_type="flow_prediction",
+ use_flow_sigmas=True,
+ num_train_timesteps=1000,
+ flow_shift=flow_shift,
+ )
+ else:
+ raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
+
+ pipe = SanaVideoPipeline(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ transformer=transformer,
+ vae=vae,
+ scheduler=scheduler,
+ )
+
+ pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB")
+
+
+DTYPE_MAPPING = {
+ "fp32": torch.float32,
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+}
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
+ )
+ parser.add_argument(
+ "--video_size",
+ default=480,
+ type=int,
+ choices=[480, 720],
+ required=False,
+ help="Video size of pretrained model, 480 or 720.",
+ )
+ parser.add_argument(
+ "--model_type",
+ default="SanaVideo",
+ type=str,
+ choices=[
+ "SanaVideo",
+ ],
+ )
+ parser.add_argument(
+ "--scheduler_type",
+ default="flow-dpm_solver",
+ type=str,
+ choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
+ help="Scheduler type to use.",
+ )
+ parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.")
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
+ parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
+ parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
+
+ args = parser.parse_args()
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ weight_dtype = DTYPE_MAPPING[args.dtype]
+
+ main(args)
diff --git a/scripts/convert_sd3_to_diffusers.py b/scripts/convert_sd3_to_diffusers.py
index 0a3569efeab0..83cb436e6e32 100644
--- a/scripts/convert_sd3_to_diffusers.py
+++ b/scripts/convert_sd3_to_diffusers.py
@@ -7,7 +7,7 @@
from diffusers import AutoencoderKL, SD3Transformer2DModel
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
-from diffusers.models.modeling_utils import load_model_dict_into_meta
+from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.utils.import_utils import is_accelerate_available
diff --git a/scripts/convert_shap_e_to_diffusers.py b/scripts/convert_shap_e_to_diffusers.py
index b903b4ee8a7f..ac6543667af9 100644
--- a/scripts/convert_shap_e_to_diffusers.py
+++ b/scripts/convert_shap_e_to_diffusers.py
@@ -984,7 +984,7 @@ def renderer(*, args, checkpoint_map_location):
return renderer_model
-# prior model will expect clip_mean and clip_std, whic are missing from the state_dict
+# prior model will expect clip_mean and clip_std, which are missing from the state_dict
PRIOR_EXPECTED_MISSING_KEYS = ["clip_mean", "clip_std"]
diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py
new file mode 100644
index 000000000000..3bc3c435685b
--- /dev/null
+++ b/scripts/convert_skyreelsv2_to_diffusers.py
@@ -0,0 +1,637 @@
+import argparse
+import os
+import pathlib
+from typing import Any, Dict
+
+import torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel
+
+from diffusers import (
+ AutoencoderKLWan,
+ SkyReelsV2DiffusionForcingPipeline,
+ SkyReelsV2ImageToVideoPipeline,
+ SkyReelsV2Pipeline,
+ SkyReelsV2Transformer3DModel,
+ UniPCMultistepScheduler,
+)
+
+
+TRANSFORMER_KEYS_RENAME_DICT = {
+ "time_embedding.0": "condition_embedder.time_embedder.linear_1",
+ "time_embedding.2": "condition_embedder.time_embedder.linear_2",
+ "text_embedding.0": "condition_embedder.text_embedder.linear_1",
+ "text_embedding.2": "condition_embedder.text_embedder.linear_2",
+ "time_projection.1": "condition_embedder.time_proj",
+ "head.modulation": "scale_shift_table",
+ "head.head": "proj_out",
+ "modulation": "scale_shift_table",
+ "ffn.0": "ffn.net.0.proj",
+ "ffn.2": "ffn.net.2",
+ "fps_projection.0": "fps_projection.net.0.proj",
+ "fps_projection.2": "fps_projection.net.2",
+ # Hack to swap the layer names
+ # The original model calls the norms in following order: norm1, norm3, norm2
+ # We convert it to: norm1, norm2, norm3
+ "norm2": "norm__placeholder",
+ "norm3": "norm2",
+ "norm__placeholder": "norm3",
+ # For the I2V model
+ "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
+ "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
+ "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
+ "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+ # for the FLF2V model
+ "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
+ # Add attention component mappings
+ "self_attn.q": "attn1.to_q",
+ "self_attn.k": "attn1.to_k",
+ "self_attn.v": "attn1.to_v",
+ "self_attn.o": "attn1.to_out.0",
+ "self_attn.norm_q": "attn1.norm_q",
+ "self_attn.norm_k": "attn1.norm_k",
+ "cross_attn.q": "attn2.to_q",
+ "cross_attn.k": "attn2.to_k",
+ "cross_attn.v": "attn2.to_v",
+ "cross_attn.o": "attn2.to_out.0",
+ "cross_attn.norm_q": "attn2.norm_q",
+ "cross_attn.norm_k": "attn2.norm_k",
+ "attn2.to_k_img": "attn2.add_k_proj",
+ "attn2.to_v_img": "attn2.add_v_proj",
+ "attn2.norm_k_img": "attn2.norm_added_k",
+}
+
+TRANSFORMER_SPECIAL_KEYS_REMAP = {}
+
+
+def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
+ state_dict[new_key] = state_dict.pop(old_key)
+
+
+def load_sharded_safetensors(dir: pathlib.Path):
+ if "720P" in str(dir):
+ file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors"))
+ else:
+ file_paths = list(dir.glob("model*.safetensors"))
+ state_dict = {}
+ for path in file_paths:
+ state_dict.update(load_file(path))
+ return state_dict
+
+
+def get_transformer_config(model_type: str) -> Dict[str, Any]:
+ if model_type == "SkyReels-V2-DF-1.3B-540P":
+ config = {
+ "model_id": "Skywork/SkyReels-V2-DF-1.3B-540P",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 8960,
+ "freq_dim": 256,
+ "in_channels": 16,
+ "num_attention_heads": 12,
+ "inject_sample_info": True,
+ "num_layers": 30,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ elif model_type == "SkyReels-V2-DF-14B-720P":
+ config = {
+ "model_id": "Skywork/SkyReels-V2-DF-14B-720P",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 16,
+ "num_attention_heads": 40,
+ "inject_sample_info": False,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ elif model_type == "SkyReels-V2-DF-14B-540P":
+ config = {
+ "model_id": "Skywork/SkyReels-V2-DF-14B-540P",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 16,
+ "num_attention_heads": 40,
+ "inject_sample_info": False,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ elif model_type == "SkyReels-V2-T2V-14B-720P":
+ config = {
+ "model_id": "Skywork/SkyReels-V2-T2V-14B-720P",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 16,
+ "num_attention_heads": 40,
+ "inject_sample_info": False,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ elif model_type == "SkyReels-V2-T2V-14B-540P":
+ config = {
+ "model_id": "Skywork/SkyReels-V2-T2V-14B-540P",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 16,
+ "num_attention_heads": 40,
+ "inject_sample_info": False,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ elif model_type == "SkyReels-V2-I2V-1.3B-540P":
+ config = {
+ "model_id": "Skywork/SkyReels-V2-I2V-1.3B-540P",
+ "diffusers_config": {
+ "added_kv_proj_dim": 1536,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 8960,
+ "freq_dim": 256,
+ "in_channels": 36,
+ "num_attention_heads": 12,
+ "inject_sample_info": False,
+ "num_layers": 30,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ "image_dim": 1280,
+ },
+ }
+ elif model_type == "SkyReels-V2-I2V-14B-540P":
+ config = {
+ "model_id": "Skywork/SkyReels-V2-I2V-14B-540P",
+ "diffusers_config": {
+ "added_kv_proj_dim": 5120,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 36,
+ "num_attention_heads": 40,
+ "inject_sample_info": False,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ "image_dim": 1280,
+ },
+ }
+ elif model_type == "SkyReels-V2-I2V-14B-720P":
+ config = {
+ "model_id": "Skywork/SkyReels-V2-I2V-14B-720P",
+ "diffusers_config": {
+ "added_kv_proj_dim": 5120,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 36,
+ "num_attention_heads": 40,
+ "inject_sample_info": False,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ "image_dim": 1280,
+ },
+ }
+ elif model_type == "SkyReels-V2-FLF2V-1.3B-540P":
+ config = {
+ "model_id": "Skywork/SkyReels-V2-I2V-1.3B-540P",
+ "diffusers_config": {
+ "added_kv_proj_dim": 1536,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 8960,
+ "freq_dim": 256,
+ "in_channels": 36,
+ "num_attention_heads": 12,
+ "inject_sample_info": False,
+ "num_layers": 30,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ "image_dim": 1280,
+ "pos_embed_seq_len": 514,
+ },
+ }
+ elif model_type == "SkyReels-V2-FLF2V-14B-540P":
+ config = {
+ "model_id": "Skywork/SkyReels-V2-I2V-14B-540P",
+ "diffusers_config": {
+ "added_kv_proj_dim": 5120,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 36,
+ "num_attention_heads": 40,
+ "inject_sample_info": False,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ "image_dim": 1280,
+ "pos_embed_seq_len": 514,
+ },
+ }
+ elif model_type == "SkyReels-V2-FLF2V-14B-720P":
+ config = {
+ "model_id": "Skywork/SkyReels-V2-I2V-14B-720P",
+ "diffusers_config": {
+ "added_kv_proj_dim": 5120,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 36,
+ "num_attention_heads": 40,
+ "inject_sample_info": False,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ "image_dim": 1280,
+ "pos_embed_seq_len": 514,
+ },
+ }
+ return config
+
+
+def convert_transformer(model_type: str):
+ config = get_transformer_config(model_type)
+ diffusers_config = config["diffusers_config"]
+ model_id = config["model_id"]
+
+ if "1.3B" in model_type:
+ original_state_dict = load_file(hf_hub_download(model_id, "model.safetensors"))
+ else:
+ os.makedirs(model_type, exist_ok=True)
+ model_dir = pathlib.Path(model_type)
+ if "720P" in model_type:
+ top_shard = 7 if "I2V" in model_type else 6
+ zeros = "0" * (4 if "I2V" or "T2V" in model_type else 3)
+ model_name = "diffusion_pytorch_model"
+ elif "540P" in model_type:
+ top_shard = 14 if "I2V" in model_type else 12
+ model_name = "model"
+
+ for i in range(1, top_shard + 1):
+ shard_path = f"{model_name}-{i:05d}-of-{zeros}{top_shard}.safetensors"
+ hf_hub_download(model_id, shard_path, local_dir=model_dir)
+ original_state_dict = load_sharded_safetensors(model_dir)
+
+ with init_empty_weights():
+ transformer = SkyReelsV2Transformer3DModel.from_config(diffusers_config)
+
+ for key in list(original_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_(original_state_dict, key, new_key)
+
+ for key in list(original_state_dict.keys()):
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, original_state_dict)
+
+ if "FLF2V" in model_type:
+ if (
+ hasattr(transformer.condition_embedder, "image_embedder")
+ and hasattr(transformer.condition_embedder.image_embedder, "pos_embed")
+ and transformer.condition_embedder.image_embedder.pos_embed is not None
+ ):
+ pos_embed_shape = transformer.condition_embedder.image_embedder.pos_embed.shape
+ original_state_dict["condition_embedder.image_embedder.pos_embed"] = torch.zeros(pos_embed_shape)
+
+ transformer.load_state_dict(original_state_dict, strict=True, assign=True)
+ return transformer
+
+
+def convert_vae():
+ vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.1-T2V-14B", "Wan2.1_VAE.pth")
+ old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
+ new_state_dict = {}
+
+ # Create mappings for specific components
+ middle_key_mapping = {
+ # Encoder middle block
+ "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
+ "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
+ "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
+ "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
+ "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
+ "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
+ "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
+ "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
+ "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
+ "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
+ "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
+ "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
+ # Decoder middle block
+ "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
+ "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
+ "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
+ "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
+ "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
+ "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
+ "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
+ "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
+ "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
+ "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
+ "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
+ "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
+ }
+
+ # Create a mapping for attention blocks
+ attention_mapping = {
+ # Encoder middle attention
+ "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
+ "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
+ "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
+ "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
+ "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
+ # Decoder middle attention
+ "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
+ "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
+ "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
+ "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
+ "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
+ }
+
+ # Create a mapping for the head components
+ head_mapping = {
+ # Encoder head
+ "encoder.head.0.gamma": "encoder.norm_out.gamma",
+ "encoder.head.2.bias": "encoder.conv_out.bias",
+ "encoder.head.2.weight": "encoder.conv_out.weight",
+ # Decoder head
+ "decoder.head.0.gamma": "decoder.norm_out.gamma",
+ "decoder.head.2.bias": "decoder.conv_out.bias",
+ "decoder.head.2.weight": "decoder.conv_out.weight",
+ }
+
+ # Create a mapping for the quant components
+ quant_mapping = {
+ "conv1.weight": "quant_conv.weight",
+ "conv1.bias": "quant_conv.bias",
+ "conv2.weight": "post_quant_conv.weight",
+ "conv2.bias": "post_quant_conv.bias",
+ }
+
+ # Process each key in the state dict
+ for key, value in old_state_dict.items():
+ # Handle middle block keys using the mapping
+ if key in middle_key_mapping:
+ new_key = middle_key_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle attention blocks using the mapping
+ elif key in attention_mapping:
+ new_key = attention_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle head keys using the mapping
+ elif key in head_mapping:
+ new_key = head_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle quant keys using the mapping
+ elif key in quant_mapping:
+ new_key = quant_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle encoder conv1
+ elif key == "encoder.conv1.weight":
+ new_state_dict["encoder.conv_in.weight"] = value
+ elif key == "encoder.conv1.bias":
+ new_state_dict["encoder.conv_in.bias"] = value
+ # Handle decoder conv1
+ elif key == "decoder.conv1.weight":
+ new_state_dict["decoder.conv_in.weight"] = value
+ elif key == "decoder.conv1.bias":
+ new_state_dict["decoder.conv_in.bias"] = value
+ # Handle encoder downsamples
+ elif key.startswith("encoder.downsamples."):
+ # Convert to down_blocks
+ new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
+
+ # Convert residual block naming but keep the original structure
+ if ".residual.0.gamma" in new_key:
+ new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
+ elif ".residual.2.bias" in new_key:
+ new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
+ elif ".residual.2.weight" in new_key:
+ new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
+ elif ".residual.3.gamma" in new_key:
+ new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
+ elif ".residual.6.bias" in new_key:
+ new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
+ elif ".residual.6.weight" in new_key:
+ new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
+ elif ".shortcut.bias" in new_key:
+ new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
+ elif ".shortcut.weight" in new_key:
+ new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
+
+ new_state_dict[new_key] = value
+
+ # Handle decoder upsamples
+ elif key.startswith("decoder.upsamples."):
+ # Convert to up_blocks
+ parts = key.split(".")
+ block_idx = int(parts[2])
+
+ # Group residual blocks
+ if "residual" in key:
+ if block_idx in [0, 1, 2]:
+ new_block_idx = 0
+ resnet_idx = block_idx
+ elif block_idx in [4, 5, 6]:
+ new_block_idx = 1
+ resnet_idx = block_idx - 4
+ elif block_idx in [8, 9, 10]:
+ new_block_idx = 2
+ resnet_idx = block_idx - 8
+ elif block_idx in [12, 13, 14]:
+ new_block_idx = 3
+ resnet_idx = block_idx - 12
+ else:
+ # Keep as is for other blocks
+ new_state_dict[key] = value
+ continue
+
+ # Convert residual block naming
+ if ".residual.0.gamma" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
+ elif ".residual.2.bias" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
+ elif ".residual.2.weight" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
+ elif ".residual.3.gamma" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
+ elif ".residual.6.bias" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
+ elif ".residual.6.weight" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
+ else:
+ new_key = key
+
+ new_state_dict[new_key] = value
+
+ # Handle shortcut connections
+ elif ".shortcut." in key:
+ if block_idx == 4:
+ new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
+ new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
+ else:
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+ new_key = new_key.replace(".shortcut.", ".conv_shortcut.")
+
+ new_state_dict[new_key] = value
+
+ # Handle upsamplers
+ elif ".resample." in key or ".time_conv." in key:
+ if block_idx == 3:
+ new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
+ elif block_idx == 7:
+ new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
+ elif block_idx == 11:
+ new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
+ else:
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+
+ new_state_dict[new_key] = value
+ else:
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+ new_state_dict[new_key] = value
+ else:
+ # Keep other keys unchanged
+ new_state_dict[key] = value
+
+ with init_empty_weights():
+ vae = AutoencoderKLWan()
+ vae.load_state_dict(new_state_dict, strict=True, assign=True)
+ return vae
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_type", type=str, default=None)
+ parser.add_argument("--output_path", type=str, required=True)
+ parser.add_argument("--dtype", default="fp32")
+ return parser.parse_args()
+
+
+DTYPE_MAPPING = {
+ "fp32": torch.float32,
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+}
+
+
+if __name__ == "__main__":
+ args = get_args()
+
+ transformer = None
+ dtype = DTYPE_MAPPING[args.dtype]
+
+ transformer = convert_transformer(args.model_type).to(dtype=dtype)
+ vae = convert_vae()
+ text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
+ tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
+ scheduler = UniPCMultistepScheduler(
+ prediction_type="flow_prediction",
+ num_train_timesteps=1000,
+ use_flow_sigmas=True,
+ )
+
+ if "I2V" in args.model_type or "FLF2V" in args.model_type:
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+ image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+ pipe = SkyReelsV2ImageToVideoPipeline(
+ transformer=transformer,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ image_processor=image_processor,
+ )
+ elif "T2V" in args.model_type:
+ pipe = SkyReelsV2Pipeline(
+ transformer=transformer,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ scheduler=scheduler,
+ )
+ elif "DF" in args.model_type:
+ pipe = SkyReelsV2DiffusionForcingPipeline(
+ transformer=transformer,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ scheduler=scheduler,
+ )
+
+ pipe.save_pretrained(
+ args.output_path,
+ safe_serialization=True,
+ max_shard_size="5GB",
+ # push_to_hub=True,
+ # repo_id=f"/{args.model_type}-Diffusers",
+ )
diff --git a/scripts/convert_stable_audio.py b/scripts/convert_stable_audio.py
index a0f9d0f87d90..c3479fc6b2bb 100644
--- a/scripts/convert_stable_audio.py
+++ b/scripts/convert_stable_audio.py
@@ -1,4 +1,4 @@
-# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
+# Run this script to convert the Stable Audio model weights to a diffusers pipeline.
import argparse
import json
import os
@@ -18,7 +18,7 @@
StableAudioPipeline,
StableAudioProjectionModel,
)
-from diffusers.models.modeling_utils import load_model_dict_into_meta
+from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.utils import is_accelerate_available
@@ -95,18 +95,18 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
# get idx of the layer
idx = int(new_key.split("coder.layers.")[1].split(".")[0])
- new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx-1}")
+ new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx - 1}")
if "encoder" in new_key:
for i in range(3):
- new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i+1}")
- new_key = new_key.replace(f"block.{idx-1}.layers.3", f"block.{idx-1}.snake1")
- new_key = new_key.replace(f"block.{idx-1}.layers.4", f"block.{idx-1}.conv1")
+ new_key = new_key.replace(f"block.{idx - 1}.layers.{i}", f"block.{idx - 1}.res_unit{i + 1}")
+ new_key = new_key.replace(f"block.{idx - 1}.layers.3", f"block.{idx - 1}.snake1")
+ new_key = new_key.replace(f"block.{idx - 1}.layers.4", f"block.{idx - 1}.conv1")
else:
for i in range(2, 5):
- new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i-1}")
- new_key = new_key.replace(f"block.{idx-1}.layers.0", f"block.{idx-1}.snake1")
- new_key = new_key.replace(f"block.{idx-1}.layers.1", f"block.{idx-1}.conv_t1")
+ new_key = new_key.replace(f"block.{idx - 1}.layers.{i}", f"block.{idx - 1}.res_unit{i - 1}")
+ new_key = new_key.replace(f"block.{idx - 1}.layers.0", f"block.{idx - 1}.snake1")
+ new_key = new_key.replace(f"block.{idx - 1}.layers.1", f"block.{idx - 1}.conv_t1")
new_key = new_key.replace("layers.0.beta", "snake1.beta")
new_key = new_key.replace("layers.0.alpha", "snake1.alpha")
@@ -118,9 +118,9 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
new_key = new_key.replace("layers.3.weight_", "conv2.weight_")
if idx == num_autoencoder_layers + 1:
- new_key = new_key.replace(f"block.{idx-1}", "snake1")
+ new_key = new_key.replace(f"block.{idx - 1}", "snake1")
elif idx == num_autoencoder_layers + 2:
- new_key = new_key.replace(f"block.{idx-1}", "conv2")
+ new_key = new_key.replace(f"block.{idx - 1}", "conv2")
else:
new_key = new_key
diff --git a/scripts/convert_stable_cascade.py b/scripts/convert_stable_cascade.py
index ce10970b0b6a..97ed18d9b4d4 100644
--- a/scripts/convert_stable_cascade.py
+++ b/scripts/convert_stable_cascade.py
@@ -20,7 +20,7 @@
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
-from diffusers.models.modeling_utils import load_model_dict_into_meta
+from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available
diff --git a/scripts/convert_stable_cascade_lite.py b/scripts/convert_stable_cascade_lite.py
index ddccaa3b2e8a..5f4804e30f74 100644
--- a/scripts/convert_stable_cascade_lite.py
+++ b/scripts/convert_stable_cascade_lite.py
@@ -20,7 +20,7 @@
)
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
from diffusers.models import StableCascadeUNet
-from diffusers.models.modeling_utils import load_model_dict_into_meta
+from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils import is_accelerate_available
diff --git a/scripts/convert_stable_diffusion_checkpoint_to_onnx.py b/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
index 96546b62b1d6..1fb9f8d35bae 100644
--- a/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
+++ b/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/scripts/convert_svd_to_diffusers.py b/scripts/convert_svd_to_diffusers.py
index 3243ce294b26..e46410ccb3bd 100644
--- a/scripts/convert_svd_to_diffusers.py
+++ b/scripts/convert_svd_to_diffusers.py
@@ -381,9 +381,9 @@ def convert_ldm_unet_checkpoint(
# TODO resnet time_mixer.mix_factor
if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
- new_checkpoint[
- f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
- ] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
+ new_checkpoint[f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
+ unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
+ )
if len(attentions):
paths = renew_attention_paths(attentions)
@@ -478,9 +478,9 @@ def convert_ldm_unet_checkpoint(
)
if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
- new_checkpoint[
- f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
- ] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
+ new_checkpoint[f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
+ unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
+ )
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
if ["conv.bias", "conv.weight"] in output_block_list.values():
diff --git a/scripts/convert_vae_pt_to_diffusers.py b/scripts/convert_vae_pt_to_diffusers.py
index a4f967c94fa6..8c7dc71ddfd8 100644
--- a/scripts/convert_vae_pt_to_diffusers.py
+++ b/scripts/convert_vae_pt_to_diffusers.py
@@ -13,6 +13,7 @@
renew_vae_attention_paths,
renew_vae_resnet_paths,
)
+from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
def custom_convert_ldm_vae_checkpoint(checkpoint, config):
@@ -52,7 +53,12 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
}
for i in range(num_down_blocks):
- resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
+ resnets = [
+ key
+ for key in down_blocks[i]
+ if f"down.{i}" in key and f"down.{i}.downsample" not in key and "attn" not in key
+ ]
+ attentions = [key for key in down_blocks[i] if f"down.{i}.attn" in key]
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
@@ -66,6 +72,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+ paths = renew_vae_attention_paths(attentions)
+ meta_path = {"old": f"down.{i}.attn", "new": f"down_blocks.{i}.attentions"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
@@ -84,8 +94,11 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
resnets = [
- key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
+ key
+ for key in up_blocks[block_id]
+ if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key and "attn" not in key
]
+ attentions = [key for key in up_blocks[block_id] if f"up.{block_id}.attn" in key]
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
@@ -99,6 +112,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+ paths = renew_vae_attention_paths(attentions)
+ meta_path = {"old": f"up.{block_id}.attn", "new": f"up_blocks.{i}.attentions"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
@@ -122,7 +139,8 @@ def vae_pt_to_vae_diffuser(
):
# Only support V1
r = requests.get(
- " https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
+ " https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml",
+ timeout=DIFFUSERS_REQUEST_TIMEOUT,
)
io_obj = io.BytesIO(r.content)
diff --git a/scripts/convert_vq_diffusion_to_diffusers.py b/scripts/convert_vq_diffusion_to_diffusers.py
index 7da6b4094986..fe62d18faff0 100644
--- a/scripts/convert_vq_diffusion_to_diffusers.py
+++ b/scripts/convert_vq_diffusion_to_diffusers.py
@@ -51,9 +51,9 @@
def vqvae_model_from_original_config(original_config):
- assert (
- original_config["target"] in PORTED_VQVAES
- ), f"{original_config['target']} has not yet been ported to diffusers."
+ assert original_config["target"] in PORTED_VQVAES, (
+ f"{original_config['target']} has not yet been ported to diffusers."
+ )
original_config = original_config["params"]
@@ -464,15 +464,15 @@ def vqvae_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_p
def transformer_model_from_original_config(
original_diffusion_config, original_transformer_config, original_content_embedding_config
):
- assert (
- original_diffusion_config["target"] in PORTED_DIFFUSIONS
- ), f"{original_diffusion_config['target']} has not yet been ported to diffusers."
- assert (
- original_transformer_config["target"] in PORTED_TRANSFORMERS
- ), f"{original_transformer_config['target']} has not yet been ported to diffusers."
- assert (
- original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS
- ), f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
+ assert original_diffusion_config["target"] in PORTED_DIFFUSIONS, (
+ f"{original_diffusion_config['target']} has not yet been ported to diffusers."
+ )
+ assert original_transformer_config["target"] in PORTED_TRANSFORMERS, (
+ f"{original_transformer_config['target']} has not yet been ported to diffusers."
+ )
+ assert original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS, (
+ f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
+ )
original_diffusion_config = original_diffusion_config["params"]
original_transformer_config = original_transformer_config["params"]
diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py
index 0b2fa872487e..06f87409262a 100644
--- a/scripts/convert_wan_to_diffusers.py
+++ b/scripts/convert_wan_to_diffusers.py
@@ -1,19 +1,30 @@
import argparse
import pathlib
-from typing import Any, Dict
+from typing import Any, Dict, Tuple
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors.torch import load_file
-from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel
+from transformers import (
+ AutoProcessor,
+ AutoTokenizer,
+ CLIPImageProcessor,
+ CLIPVisionModel,
+ CLIPVisionModelWithProjection,
+ UMT5EncoderModel,
+)
from diffusers import (
AutoencoderKLWan,
UniPCMultistepScheduler,
+ WanAnimatePipeline,
+ WanAnimateTransformer3DModel,
WanImageToVideoPipeline,
WanPipeline,
WanTransformer3DModel,
+ WanVACEPipeline,
+ WanVACETransformer3DModel,
)
@@ -39,9 +50,267 @@
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+ # for the FLF2V model
+ "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
+ # Add attention component mappings
+ "self_attn.q": "attn1.to_q",
+ "self_attn.k": "attn1.to_k",
+ "self_attn.v": "attn1.to_v",
+ "self_attn.o": "attn1.to_out.0",
+ "self_attn.norm_q": "attn1.norm_q",
+ "self_attn.norm_k": "attn1.norm_k",
+ "cross_attn.q": "attn2.to_q",
+ "cross_attn.k": "attn2.to_k",
+ "cross_attn.v": "attn2.to_v",
+ "cross_attn.o": "attn2.to_out.0",
+ "cross_attn.norm_q": "attn2.norm_q",
+ "cross_attn.norm_k": "attn2.norm_k",
+ "attn2.to_k_img": "attn2.add_k_proj",
+ "attn2.to_v_img": "attn2.add_v_proj",
+ "attn2.norm_k_img": "attn2.norm_added_k",
+}
+
+VACE_TRANSFORMER_KEYS_RENAME_DICT = {
+ "time_embedding.0": "condition_embedder.time_embedder.linear_1",
+ "time_embedding.2": "condition_embedder.time_embedder.linear_2",
+ "text_embedding.0": "condition_embedder.text_embedder.linear_1",
+ "text_embedding.2": "condition_embedder.text_embedder.linear_2",
+ "time_projection.1": "condition_embedder.time_proj",
+ "head.modulation": "scale_shift_table",
+ "head.head": "proj_out",
+ "modulation": "scale_shift_table",
+ "ffn.0": "ffn.net.0.proj",
+ "ffn.2": "ffn.net.2",
+ # Hack to swap the layer names
+ # The original model calls the norms in following order: norm1, norm3, norm2
+ # We convert it to: norm1, norm2, norm3
+ "norm2": "norm__placeholder",
+ "norm3": "norm2",
+ "norm__placeholder": "norm3",
+ # # For the I2V model
+ # "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
+ # "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
+ # "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
+ # "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+ # # for the FLF2V model
+ # "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
+ # Add attention component mappings
+ "self_attn.q": "attn1.to_q",
+ "self_attn.k": "attn1.to_k",
+ "self_attn.v": "attn1.to_v",
+ "self_attn.o": "attn1.to_out.0",
+ "self_attn.norm_q": "attn1.norm_q",
+ "self_attn.norm_k": "attn1.norm_k",
+ "cross_attn.q": "attn2.to_q",
+ "cross_attn.k": "attn2.to_k",
+ "cross_attn.v": "attn2.to_v",
+ "cross_attn.o": "attn2.to_out.0",
+ "cross_attn.norm_q": "attn2.norm_q",
+ "cross_attn.norm_k": "attn2.norm_k",
+ "attn2.to_k_img": "attn2.add_k_proj",
+ "attn2.to_v_img": "attn2.add_v_proj",
+ "attn2.norm_k_img": "attn2.norm_added_k",
+ "before_proj": "proj_in",
+ "after_proj": "proj_out",
+}
+
+ANIMATE_TRANSFORMER_KEYS_RENAME_DICT = {
+ "time_embedding.0": "condition_embedder.time_embedder.linear_1",
+ "time_embedding.2": "condition_embedder.time_embedder.linear_2",
+ "text_embedding.0": "condition_embedder.text_embedder.linear_1",
+ "text_embedding.2": "condition_embedder.text_embedder.linear_2",
+ "time_projection.1": "condition_embedder.time_proj",
+ "head.modulation": "scale_shift_table",
+ "head.head": "proj_out",
+ "modulation": "scale_shift_table",
+ "ffn.0": "ffn.net.0.proj",
+ "ffn.2": "ffn.net.2",
+ # Hack to swap the layer names
+ # The original model calls the norms in following order: norm1, norm3, norm2
+ # We convert it to: norm1, norm2, norm3
+ "norm2": "norm__placeholder",
+ "norm3": "norm2",
+ "norm__placeholder": "norm3",
+ "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
+ "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
+ "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
+ "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+ # Add attention component mappings
+ "self_attn.q": "attn1.to_q",
+ "self_attn.k": "attn1.to_k",
+ "self_attn.v": "attn1.to_v",
+ "self_attn.o": "attn1.to_out.0",
+ "self_attn.norm_q": "attn1.norm_q",
+ "self_attn.norm_k": "attn1.norm_k",
+ "cross_attn.q": "attn2.to_q",
+ "cross_attn.k": "attn2.to_k",
+ "cross_attn.v": "attn2.to_v",
+ "cross_attn.o": "attn2.to_out.0",
+ "cross_attn.norm_q": "attn2.norm_q",
+ "cross_attn.norm_k": "attn2.norm_k",
+ "cross_attn.k_img": "attn2.to_k_img",
+ "cross_attn.v_img": "attn2.to_v_img",
+ "cross_attn.norm_k_img": "attn2.norm_k_img",
+ # After cross_attn -> attn2 rename, we need to rename the img keys
+ "attn2.to_k_img": "attn2.add_k_proj",
+ "attn2.to_v_img": "attn2.add_v_proj",
+ "attn2.norm_k_img": "attn2.norm_added_k",
+ # Wan Animate-specific mappings (motion encoder, face encoder, face adapter)
+ # Motion encoder mappings
+ # The name mapping is complicated for the convolutional part so we handle that in its own function
+ "motion_encoder.enc.fc": "motion_encoder.motion_network",
+ "motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
+ # Face encoder mappings - CausalConv1d has a .conv submodule that we need to flatten
+ "face_encoder.conv1_local.conv": "face_encoder.conv1_local",
+ "face_encoder.conv2.conv": "face_encoder.conv2",
+ "face_encoder.conv3.conv": "face_encoder.conv3",
+ # Face adapter mappings are handled in a separate function
}
+
+# TODO: Verify this and simplify if possible.
+def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any], final_conv_idx: int = 8) -> None:
+ """
+ Convert all motion encoder weights for Animate model.
+
+ In the original model:
+ - All Linear layers in fc use EqualLinear
+ - All Conv2d layers in convs use EqualConv2d (except blur_conv which is initialized separately)
+ - Blur kernels are stored as buffers in Sequential modules
+ - ConvLayer is nn.Sequential with indices: [Blur (optional), EqualConv2d, FusedLeakyReLU (optional)]
+
+ Conversion strategy:
+ 1. Drop .kernel buffers (blur kernels)
+ 2. Rename sequential indices to named components (e.g., 0 -> conv2d, 1 -> bias_leaky_relu)
+ """
+ # Skip if not a weight, bias, or kernel
+ if ".weight" not in key and ".bias" not in key and ".kernel" not in key:
+ return
+
+ # Handle Blur kernel buffers from original implementation.
+ # After renaming, these appear under: motion_encoder.res_blocks.*.conv{2,skip}.blur_kernel
+ # Diffusers constructs blur kernels as a non-persistent buffer so we must drop these keys
+ if ".kernel" in key and "motion_encoder" in key:
+ # Remove unexpected blur kernel buffers to avoid strict load errors
+ state_dict.pop(key, None)
+ return
+
+ # Rename Sequential indices to named components in ConvLayer and ResBlock
+ if ".enc.net_app.convs." in key and (".weight" in key or ".bias" in key):
+ parts = key.split(".")
+
+ # Find the sequential index (digit) after convs or after conv1/conv2/skip
+ # Examples:
+ # - enc.net_app.convs.0.0.weight -> conv_in.weight (initial conv layer weight)
+ # - enc.net_app.convs.0.1.bias -> conv_in.act_fn.bias (initial conv layer bias)
+ # - enc.net_app.convs.{n:1-7}.conv1.0.weight -> res_blocks.{(n-1):0-6}.conv1.weight (conv1 weight)
+ # - e.g. enc.net_app.convs.1.conv1.0.weight -> res_blocks.0.conv1.weight
+ # - enc.net_app.convs.{n:1-7}.conv1.1.bias -> res_blocks.{(n-1):0-6}.conv1.act_fn.bias (conv1 bias)
+ # - e.g. enc.net_app.convs.1.conv1.1.bias -> res_blocks.0.conv1.act_fn.bias
+ # - enc.net_app.convs.{n:1-7}.conv2.1.weight -> res_blocks.{(n-1):0-6}.conv2.weight (conv2 weight)
+ # - enc.net_app.convs.1.conv2.2.bias -> res_blocks.0.conv2.act_fn.bias (conv2 bias)
+ # - enc.net_app.convs.{n:1-7}.skip.1.weight -> res_blocks.{(n-1):0-6}.conv_skip.weight (skip conv weight)
+ # - enc.net_app.convs.8 -> conv_out (final conv layer)
+
+ convs_idx = parts.index("convs") if "convs" in parts else -1
+ if convs_idx >= 0 and len(parts) - convs_idx >= 2:
+ bias = False
+ # The nn.Sequential index will always follow convs
+ sequential_idx = int(parts[convs_idx + 1])
+ if sequential_idx == 0:
+ if key.endswith(".weight"):
+ new_key = "motion_encoder.conv_in.weight"
+ elif key.endswith(".bias"):
+ new_key = "motion_encoder.conv_in.act_fn.bias"
+ bias = True
+ elif sequential_idx == final_conv_idx:
+ if key.endswith(".weight"):
+ new_key = "motion_encoder.conv_out.weight"
+ else:
+ # Intermediate .convs. layers, which get mapped to .res_blocks.
+ prefix = "motion_encoder.res_blocks."
+
+ layer_name = parts[convs_idx + 2]
+ if layer_name == "skip":
+ layer_name = "conv_skip"
+
+ if key.endswith(".weight"):
+ param_name = "weight"
+ elif key.endswith(".bias"):
+ param_name = "act_fn.bias"
+ bias = True
+
+ suffix_parts = [str(sequential_idx - 1), layer_name, param_name]
+ suffix = ".".join(suffix_parts)
+ new_key = prefix + suffix
+
+ param = state_dict.pop(key)
+ if bias:
+ param = param.squeeze()
+ state_dict[new_key] = param
+ return
+ return
+ return
+
+
+def convert_animate_face_adapter_weights(key: str, state_dict: Dict[str, Any]) -> None:
+ """
+ Convert face adapter weights for the Animate model.
+
+ The original model uses a fused KV projection but the diffusers models uses separate K and V projections.
+ """
+ # Skip if not a weight or bias
+ if ".weight" not in key and ".bias" not in key:
+ return
+
+ prefix = "face_adapter."
+ if ".fuser_blocks." in key:
+ parts = key.split(".")
+
+ module_list_idx = parts.index("fuser_blocks") if "fuser_blocks" in parts else -1
+ if module_list_idx >= 0 and (len(parts) - 1) - module_list_idx == 3:
+ block_idx = parts[module_list_idx + 1]
+ layer_name = parts[module_list_idx + 2]
+ param_name = parts[module_list_idx + 3]
+
+ if layer_name == "linear1_kv":
+ layer_name_k = "to_k"
+ layer_name_v = "to_v"
+
+ suffix_k = ".".join([block_idx, layer_name_k, param_name])
+ suffix_v = ".".join([block_idx, layer_name_v, param_name])
+ new_key_k = prefix + suffix_k
+ new_key_v = prefix + suffix_v
+
+ kv_proj = state_dict.pop(key)
+ k_proj, v_proj = torch.chunk(kv_proj, 2, dim=0)
+ state_dict[new_key_k] = k_proj
+ state_dict[new_key_v] = v_proj
+ return
+ else:
+ if layer_name == "q_norm":
+ new_layer_name = "norm_q"
+ elif layer_name == "k_norm":
+ new_layer_name = "norm_k"
+ elif layer_name == "linear1_q":
+ new_layer_name = "to_q"
+ elif layer_name == "linear2":
+ new_layer_name = "to_out"
+
+ suffix_parts = [block_idx, new_layer_name, param_name]
+ suffix = ".".join(suffix_parts)
+ new_key = prefix + suffix
+ state_dict[new_key] = state_dict.pop(key)
+ return
+ return
+
+
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
+VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
+ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP = {
+ "motion_encoder": convert_animate_motion_encoder_weights,
+ "face_adapter": convert_animate_face_adapter_weights,
+}
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
@@ -56,7 +325,7 @@ def load_sharded_safetensors(dir: pathlib.Path):
return state_dict
-def get_transformer_config(model_type: str) -> Dict[str, Any]:
+def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
if model_type == "Wan-T2V-1.3B":
config = {
"model_id": "StevenZhang/Wan2.1-T2V-1.3B-Diff",
@@ -76,6 +345,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
"text_dim": 4096,
},
}
+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan-T2V-14B":
config = {
"model_id": "StevenZhang/Wan2.1-T2V-14B-Diff",
@@ -95,6 +366,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
"text_dim": 4096,
},
}
+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan-I2V-14B-480p":
config = {
"model_id": "StevenZhang/Wan2.1-I2V-14B-480P-Diff",
@@ -115,6 +388,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
"text_dim": 4096,
},
}
+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan-I2V-14B-720p":
config = {
"model_id": "StevenZhang/Wan2.1-I2V-14B-720P-Diff",
@@ -135,33 +410,236 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
"text_dim": 4096,
},
}
- return config
+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
+ elif model_type == "Wan-FLF2V-14B-720P":
+ config = {
+ "model_id": "ypyp/Wan2.1-FLF2V-14B-720P", # This is just a placeholder
+ "diffusers_config": {
+ "image_dim": 1280,
+ "added_kv_proj_dim": 5120,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 36,
+ "num_attention_heads": 40,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ "rope_max_seq_len": 1024,
+ "pos_embed_seq_len": 257 * 2,
+ },
+ }
+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
+ elif model_type == "Wan-VACE-1.3B":
+ config = {
+ "model_id": "Wan-AI/Wan2.1-VACE-1.3B",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 8960,
+ "freq_dim": 256,
+ "in_channels": 16,
+ "num_attention_heads": 12,
+ "num_layers": 30,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ "vace_layers": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28],
+ "vace_in_channels": 96,
+ },
+ }
+ RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
+ SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
+ elif model_type == "Wan-VACE-14B":
+ config = {
+ "model_id": "Wan-AI/Wan2.1-VACE-14B",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 16,
+ "num_attention_heads": 40,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ "vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
+ "vace_in_channels": 96,
+ },
+ }
+ RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
+ SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
+ elif model_type == "Wan2.2-VACE-Fun-14B":
+ config = {
+ "model_id": "alibaba-pai/Wan2.2-VACE-Fun-A14B",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 16,
+ "num_attention_heads": 40,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ "vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
+ "vace_in_channels": 96,
+ },
+ }
+ RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
+ SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
+ elif model_type == "Wan2.2-I2V-14B-720p":
+ config = {
+ "model_id": "Wan-AI/Wan2.2-I2V-A14B",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 36,
+ "num_attention_heads": 40,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
+ elif model_type == "Wan2.2-T2V-A14B":
+ config = {
+ "model_id": "Wan-AI/Wan2.2-T2V-A14B",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 16,
+ "num_attention_heads": 40,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
+ elif model_type == "Wan2.2-TI2V-5B":
+ config = {
+ "model_id": "Wan-AI/Wan2.2-TI2V-5B",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 14336,
+ "freq_dim": 256,
+ "in_channels": 48,
+ "num_attention_heads": 24,
+ "num_layers": 30,
+ "out_channels": 48,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
+ elif model_type == "Wan2.2-Animate-14B":
+ config = {
+ "model_id": "Wan-AI/Wan2.2-Animate-14B",
+ "diffusers_config": {
+ "image_dim": 1280,
+ "added_kv_proj_dim": 5120,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 36,
+ "num_attention_heads": 40,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": (1, 2, 2),
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ "rope_max_seq_len": 1024,
+ "pos_embed_seq_len": None,
+ "motion_encoder_size": 512, # Start of Wan Animate-specific configs
+ "motion_style_dim": 512,
+ "motion_dim": 20,
+ "motion_encoder_dim": 512,
+ "face_encoder_hidden_dim": 1024,
+ "face_encoder_num_heads": 4,
+ "inject_face_latents_blocks": 5,
+ },
+ }
+ RENAME_DICT = ANIMATE_TRANSFORMER_KEYS_RENAME_DICT
+ SPECIAL_KEYS_REMAP = ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP
+ return config, RENAME_DICT, SPECIAL_KEYS_REMAP
+
+def convert_transformer(model_type: str, stage: str = None):
+ config, RENAME_DICT, SPECIAL_KEYS_REMAP = get_transformer_config(model_type)
-def convert_transformer(model_type: str):
- config = get_transformer_config(model_type)
diffusers_config = config["diffusers_config"]
model_id = config["model_id"]
model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model"))
+ if stage is not None:
+ model_dir = model_dir / stage
+
original_state_dict = load_sharded_safetensors(model_dir)
with init_empty_weights():
- transformer = WanTransformer3DModel.from_config(diffusers_config)
+ if "Animate" in model_type:
+ transformer = WanAnimateTransformer3DModel.from_config(diffusers_config)
+ elif "VACE" in model_type:
+ transformer = WanVACETransformer3DModel.from_config(diffusers_config)
+ else:
+ transformer = WanTransformer3DModel.from_config(diffusers_config)
for key in list(original_state_dict.keys()):
new_key = key[:]
- for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+ for replace_key, rename_key in RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
- for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ for special_key, handler_fn_inplace in SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
+ # Load state dict into the meta model, which will materialize the tensors
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
+
+ # Move to CPU to ensure all tensors are materialized
+ transformer = transformer.to("cpu")
+
return transformer
@@ -368,11 +846,315 @@ def convert_vae():
return vae
+vae22_diffusers_config = {
+ "base_dim": 160,
+ "z_dim": 48,
+ "is_residual": True,
+ "in_channels": 12,
+ "out_channels": 12,
+ "decoder_base_dim": 256,
+ "scale_factor_temporal": 4,
+ "scale_factor_spatial": 16,
+ "patch_size": 2,
+ "latents_mean": [
+ -0.2289,
+ -0.0052,
+ -0.1323,
+ -0.2339,
+ -0.2799,
+ 0.0174,
+ 0.1838,
+ 0.1557,
+ -0.1382,
+ 0.0542,
+ 0.2813,
+ 0.0891,
+ 0.1570,
+ -0.0098,
+ 0.0375,
+ -0.1825,
+ -0.2246,
+ -0.1207,
+ -0.0698,
+ 0.5109,
+ 0.2665,
+ -0.2108,
+ -0.2158,
+ 0.2502,
+ -0.2055,
+ -0.0322,
+ 0.1109,
+ 0.1567,
+ -0.0729,
+ 0.0899,
+ -0.2799,
+ -0.1230,
+ -0.0313,
+ -0.1649,
+ 0.0117,
+ 0.0723,
+ -0.2839,
+ -0.2083,
+ -0.0520,
+ 0.3748,
+ 0.0152,
+ 0.1957,
+ 0.1433,
+ -0.2944,
+ 0.3573,
+ -0.0548,
+ -0.1681,
+ -0.0667,
+ ],
+ "latents_std": [
+ 0.4765,
+ 1.0364,
+ 0.4514,
+ 1.1677,
+ 0.5313,
+ 0.4990,
+ 0.4818,
+ 0.5013,
+ 0.8158,
+ 1.0344,
+ 0.5894,
+ 1.0901,
+ 0.6885,
+ 0.6165,
+ 0.8454,
+ 0.4978,
+ 0.5759,
+ 0.3523,
+ 0.7135,
+ 0.6804,
+ 0.5833,
+ 1.4146,
+ 0.8986,
+ 0.5659,
+ 0.7069,
+ 0.5338,
+ 0.4889,
+ 0.4917,
+ 0.4069,
+ 0.4999,
+ 0.6866,
+ 0.4093,
+ 0.5709,
+ 0.6065,
+ 0.6415,
+ 0.4944,
+ 0.5726,
+ 1.2042,
+ 0.5458,
+ 1.6887,
+ 0.3971,
+ 1.0600,
+ 0.3943,
+ 0.5537,
+ 0.5444,
+ 0.4089,
+ 0.7468,
+ 0.7744,
+ ],
+ "clip_output": False,
+}
+
+
+def convert_vae_22():
+ vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.2-TI2V-5B", "Wan2.2_VAE.pth")
+ old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
+ new_state_dict = {}
+
+ # Create mappings for specific components
+ middle_key_mapping = {
+ # Encoder middle block
+ "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
+ "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
+ "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
+ "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
+ "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
+ "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
+ "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
+ "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
+ "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
+ "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
+ "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
+ "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
+ # Decoder middle block
+ "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
+ "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
+ "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
+ "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
+ "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
+ "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
+ "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
+ "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
+ "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
+ "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
+ "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
+ "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
+ }
+
+ # Create a mapping for attention blocks
+ attention_mapping = {
+ # Encoder middle attention
+ "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
+ "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
+ "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
+ "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
+ "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
+ # Decoder middle attention
+ "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
+ "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
+ "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
+ "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
+ "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
+ }
+
+ # Create a mapping for the head components
+ head_mapping = {
+ # Encoder head
+ "encoder.head.0.gamma": "encoder.norm_out.gamma",
+ "encoder.head.2.bias": "encoder.conv_out.bias",
+ "encoder.head.2.weight": "encoder.conv_out.weight",
+ # Decoder head
+ "decoder.head.0.gamma": "decoder.norm_out.gamma",
+ "decoder.head.2.bias": "decoder.conv_out.bias",
+ "decoder.head.2.weight": "decoder.conv_out.weight",
+ }
+
+ # Create a mapping for the quant components
+ quant_mapping = {
+ "conv1.weight": "quant_conv.weight",
+ "conv1.bias": "quant_conv.bias",
+ "conv2.weight": "post_quant_conv.weight",
+ "conv2.bias": "post_quant_conv.bias",
+ }
+
+ # Process each key in the state dict
+ for key, value in old_state_dict.items():
+ # Handle middle block keys using the mapping
+ if key in middle_key_mapping:
+ new_key = middle_key_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle attention blocks using the mapping
+ elif key in attention_mapping:
+ new_key = attention_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle head keys using the mapping
+ elif key in head_mapping:
+ new_key = head_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle quant keys using the mapping
+ elif key in quant_mapping:
+ new_key = quant_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle encoder conv1
+ elif key == "encoder.conv1.weight":
+ new_state_dict["encoder.conv_in.weight"] = value
+ elif key == "encoder.conv1.bias":
+ new_state_dict["encoder.conv_in.bias"] = value
+ # Handle decoder conv1
+ elif key == "decoder.conv1.weight":
+ new_state_dict["decoder.conv_in.weight"] = value
+ elif key == "decoder.conv1.bias":
+ new_state_dict["decoder.conv_in.bias"] = value
+ # Handle encoder downsamples
+ elif key.startswith("encoder.downsamples."):
+ # Change encoder.downsamples to encoder.down_blocks
+ new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
+
+ # Handle residual blocks - change downsamples to resnets and rename components
+ if "residual" in new_key or "shortcut" in new_key:
+ # Change the second downsamples to resnets
+ new_key = new_key.replace(".downsamples.", ".resnets.")
+
+ # Rename residual components
+ if ".residual.0.gamma" in new_key:
+ new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
+ elif ".residual.2.weight" in new_key:
+ new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
+ elif ".residual.2.bias" in new_key:
+ new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
+ elif ".residual.3.gamma" in new_key:
+ new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
+ elif ".residual.6.weight" in new_key:
+ new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
+ elif ".residual.6.bias" in new_key:
+ new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
+ elif ".shortcut.weight" in new_key:
+ new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
+ elif ".shortcut.bias" in new_key:
+ new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
+
+ # Handle resample blocks - change downsamples to downsampler and remove index
+ elif "resample" in new_key or "time_conv" in new_key:
+ # Change the second downsamples to downsampler and remove the index
+ parts = new_key.split(".")
+ # Find the pattern: encoder.down_blocks.X.downsamples.Y.resample...
+ # We want to change it to: encoder.down_blocks.X.downsampler.resample...
+ if len(parts) >= 4 and parts[3] == "downsamples":
+ # Remove the index (parts[4]) and change downsamples to downsampler
+ new_parts = parts[:3] + ["downsampler"] + parts[5:]
+ new_key = ".".join(new_parts)
+
+ new_state_dict[new_key] = value
+
+ # Handle decoder upsamples
+ elif key.startswith("decoder.upsamples."):
+ # Change decoder.upsamples to decoder.up_blocks
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+
+ # Handle residual blocks - change upsamples to resnets and rename components
+ if "residual" in new_key or "shortcut" in new_key:
+ # Change the second upsamples to resnets
+ new_key = new_key.replace(".upsamples.", ".resnets.")
+
+ # Rename residual components
+ if ".residual.0.gamma" in new_key:
+ new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
+ elif ".residual.2.weight" in new_key:
+ new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
+ elif ".residual.2.bias" in new_key:
+ new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
+ elif ".residual.3.gamma" in new_key:
+ new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
+ elif ".residual.6.weight" in new_key:
+ new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
+ elif ".residual.6.bias" in new_key:
+ new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
+ elif ".shortcut.weight" in new_key:
+ new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
+ elif ".shortcut.bias" in new_key:
+ new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
+
+ # Handle resample blocks - change upsamples to upsampler and remove index
+ elif "resample" in new_key or "time_conv" in new_key:
+ # Change the second upsamples to upsampler and remove the index
+ parts = new_key.split(".")
+ # Find the pattern: encoder.down_blocks.X.downsamples.Y.resample...
+ # We want to change it to: encoder.down_blocks.X.downsampler.resample...
+ if len(parts) >= 4 and parts[3] == "upsamples":
+ # Remove the index (parts[4]) and change upsamples to upsampler
+ new_parts = parts[:3] + ["upsampler"] + parts[5:]
+ new_key = ".".join(new_parts)
+
+ new_state_dict[new_key] = value
+ else:
+ # Keep other keys unchanged
+ new_state_dict[key] = value
+
+ with init_empty_weights():
+ vae = AutoencoderKLWan(**vae22_diffusers_config)
+ vae.load_state_dict(new_state_dict, strict=True, assign=True)
+ return vae
+
+
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_type", type=str, default=None)
parser.add_argument("--output_path", type=str, required=True)
- parser.add_argument("--dtype", default="fp32")
+ parser.add_argument("--dtype", default="fp32", choices=["fp32", "fp16", "bf16", "none"])
return parser.parse_args()
@@ -386,18 +1168,67 @@ def get_args():
if __name__ == "__main__":
args = get_args()
- transformer = None
- dtype = DTYPE_MAPPING[args.dtype]
+ if "Wan2.2" in args.model_type and "TI2V" not in args.model_type and "Animate" not in args.model_type:
+ transformer = convert_transformer(args.model_type, stage="high_noise_model")
+ transformer_2 = convert_transformer(args.model_type, stage="low_noise_model")
+ else:
+ transformer = convert_transformer(args.model_type)
+ transformer_2 = None
+
+ if "Wan2.2" in args.model_type and "TI2V" in args.model_type:
+ vae = convert_vae_22()
+ else:
+ vae = convert_vae()
- transformer = convert_transformer(args.model_type).to(dtype=dtype)
- vae = convert_vae()
- text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
+ text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
+ if "FLF2V" in args.model_type:
+ flow_shift = 16.0
+ elif "TI2V" in args.model_type or "Animate" in args.model_type:
+ flow_shift = 5.0
+ else:
+ flow_shift = 3.0
scheduler = UniPCMultistepScheduler(
- prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=3.0
+ prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
)
- if "I2V" in args.model_type:
+ # If user has specified "none", we keep the original dtypes of the state dict without any conversion
+ if args.dtype != "none":
+ dtype = DTYPE_MAPPING[args.dtype]
+ transformer.to(dtype)
+ if transformer_2 is not None:
+ transformer_2.to(dtype)
+
+ if "Wan2.2" and "I2V" in args.model_type and "TI2V" not in args.model_type:
+ pipe = WanImageToVideoPipeline(
+ transformer=transformer,
+ transformer_2=transformer_2,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ scheduler=scheduler,
+ boundary_ratio=0.9,
+ )
+ elif "Wan2.2" and "T2V" in args.model_type:
+ pipe = WanPipeline(
+ transformer=transformer,
+ transformer_2=transformer_2,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ scheduler=scheduler,
+ boundary_ratio=0.875,
+ )
+ elif "Wan2.2" and "TI2V" in args.model_type:
+ pipe = WanPipeline(
+ transformer=transformer,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ scheduler=scheduler,
+ expand_timesteps=True,
+ )
+ elif "I2V" in args.model_type or "FLF2V" in args.model_type:
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
)
@@ -411,6 +1242,39 @@ def get_args():
image_encoder=image_encoder,
image_processor=image_processor,
)
+ elif "Wan2.2-VACE" in args.model_type:
+ pipe = WanVACEPipeline(
+ transformer=transformer,
+ transformer_2=transformer_2,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ scheduler=scheduler,
+ boundary_ratio=0.875,
+ )
+ elif "Wan-VACE" in args.model_type:
+ pipe = WanVACEPipeline(
+ transformer=transformer,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ scheduler=scheduler,
+ )
+ elif "Animate" in args.model_type:
+ image_encoder = CLIPVisionModel.from_pretrained(
+ "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
+ )
+ image_processor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+
+ pipe = WanAnimatePipeline(
+ transformer=transformer,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ image_processor=image_processor,
+ )
else:
pipe = WanPipeline(
transformer=transformer,
diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py
index 23d45d3dd6ad..826b9b208181 100644
--- a/scripts/convert_wuerstchen.py
+++ b/scripts/convert_wuerstchen.py
@@ -55,8 +55,8 @@
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
else:
state_dict[key] = orig_state_dict[key]
-deocder = WuerstchenDiffNeXt()
-deocder.load_state_dict(state_dict)
+decoder = WuerstchenDiffNeXt()
+decoder.load_state_dict(state_dict)
# Prior
orig_state_dict = torch.load(os.path.join(model_path, "model_v3_stage_c.pt"), map_location=device)["ema_state_dict"]
@@ -94,7 +94,7 @@
prior_pipeline.save_pretrained("warp-ai/wuerstchen-prior")
decoder_pipeline = WuerstchenDecoderPipeline(
- text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, decoder=deocder, scheduler=scheduler
+ text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, decoder=decoder, scheduler=scheduler
)
decoder_pipeline.save_pretrained("warp-ai/wuerstchen")
@@ -103,7 +103,7 @@
# Decoder
text_encoder=gen_text_encoder,
tokenizer=gen_tokenizer,
- decoder=deocder,
+ decoder=decoder,
scheduler=scheduler,
vqgan=vqmodel,
# Prior
diff --git a/setup.py b/setup.py
index fdc166a81ecf..8d346ddfecca 100644
--- a/setup.py
+++ b/setup.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -102,7 +102,8 @@
"filelock",
"flax>=0.4.1",
"hf-doc-builder>=0.3.0",
- "huggingface-hub>=0.27.0",
+ "httpx<1.0.0",
+ "huggingface-hub>=0.34.0,<2.0",
"requests-mock==1.10.0",
"importlib_metadata",
"invisible-watermark>=0.2.0",
@@ -110,19 +111,19 @@
"jax>=0.4.1",
"jaxlib>=0.4.1",
"Jinja2",
- "k-diffusion>=0.0.12",
+ "k-diffusion==0.0.12",
"torchsde",
"note_seq",
"librosa",
"numpy",
"parameterized",
- "peft>=0.6.0",
+ "peft>=0.17.0",
"protobuf>=3.20.3,<4",
"pytest",
"pytest-timeout",
"pytest-xdist",
"python>=3.8.0",
- "ruff==0.1.5",
+ "ruff==0.9.10",
"safetensors>=0.3.1",
"sentencepiece>=0.1.91,!=0.1.92",
"GitPython<3.1.19",
@@ -132,6 +133,7 @@
"gguf>=0.10.0",
"torchao>=0.7.0",
"bitsandbytes>=0.43.3",
+ "nvidia_modelopt[hf]>=0.33.1",
"regex!=2019.12.17",
"requests",
"tensorboard",
@@ -142,6 +144,8 @@
"urllib3<=2.0.0",
"black",
"phonemizer",
+ "opencv-python",
+ "timm",
]
# this is a lookup table with items like:
@@ -215,7 +219,7 @@ def run(self):
extras = {}
extras["quality"] = deps_list("urllib3", "isort", "ruff", "hf-doc-builder")
extras["docs"] = deps_list("hf-doc-builder")
-extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2", "peft")
+extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2", "peft", "timm")
extras["test"] = deps_list(
"compel",
"GitPython",
@@ -243,6 +247,7 @@ def run(self):
extras["gguf"] = deps_list("gguf", "accelerate")
extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
extras["torchao"] = deps_list("torchao", "accelerate")
+extras["nvidia_modelopt"] = deps_list("nvidia_modelopt[hf]")
if os.name == "nt": # windows
extras["flax"] = [] # jax is not supported on windows
@@ -256,6 +261,7 @@ def run(self):
install_requires = [
deps["importlib_metadata"],
deps["filelock"],
+ deps["httpx"],
deps["huggingface-hub"],
deps["numpy"],
deps["regex"],
@@ -268,7 +274,7 @@ def run(self):
setup(
name="diffusers",
- version="0.33.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+ version="0.36.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 4d373b2a5ded..81cdd31615df 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.33.0.dev0"
+__version__ = "0.36.0.dev0"
from typing import TYPE_CHECKING
@@ -13,7 +13,9 @@
is_k_diffusion_available,
is_librosa_available,
is_note_seq_available,
+ is_nvidia_modelopt_available,
is_onnx_available,
+ is_opencv_available,
is_optimum_quanto_available,
is_scipy_available,
is_sentencepiece_available,
@@ -32,10 +34,13 @@
_import_structure = {
"configuration_utils": ["ConfigMixin"],
+ "guiders": [],
"hooks": [],
"loaders": ["FromOriginalModelMixin"],
"models": [],
+ "modular_pipelines": [],
"pipelines": [],
+ "quantizers.pipe_quant_config": ["PipelineQuantizationConfig"],
"quantizers.quantization_config": [],
"schedulers": [],
"utils": [
@@ -122,6 +127,18 @@
else:
_import_structure["quantizers.quantization_config"].append("QuantoConfig")
+try:
+ if not is_torch_available() and not is_accelerate_available() and not is_nvidia_modelopt_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_nvidia_modelopt_objects
+
+ _import_structure["utils.dummy_nvidia_modelopt_objects"] = [
+ name for name in dir(dummy_nvidia_modelopt_objects) if not name.startswith("_")
+ ]
+else:
+ _import_structure["quantizers.quantization_config"].append("NVIDIAModelOptConfig")
+
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -146,52 +163,94 @@
]
else:
+ _import_structure["guiders"].extend(
+ [
+ "AdaptiveProjectedGuidance",
+ "AdaptiveProjectedMixGuidance",
+ "AutoGuidance",
+ "BaseGuidance",
+ "ClassifierFreeGuidance",
+ "ClassifierFreeZeroStarGuidance",
+ "FrequencyDecoupledGuidance",
+ "PerturbedAttentionGuidance",
+ "SkipLayerGuidance",
+ "SmoothedEnergyGuidance",
+ "TangentialClassifierFreeGuidance",
+ ]
+ )
_import_structure["hooks"].extend(
[
"FasterCacheConfig",
+ "FirstBlockCacheConfig",
"HookRegistry",
+ "LayerSkipConfig",
"PyramidAttentionBroadcastConfig",
+ "SmoothedEnergyGuidanceConfig",
+ "TaylorSeerCacheConfig",
"apply_faster_cache",
+ "apply_first_block_cache",
+ "apply_layer_skip",
"apply_pyramid_attention_broadcast",
+ "apply_taylorseer_cache",
]
)
_import_structure["models"].extend(
[
"AllegroTransformer3DModel",
"AsymmetricAutoencoderKL",
+ "AttentionBackendName",
"AuraFlowTransformer2DModel",
"AutoencoderDC",
"AutoencoderKL",
"AutoencoderKLAllegro",
"AutoencoderKLCogVideoX",
+ "AutoencoderKLCosmos",
+ "AutoencoderKLFlux2",
+ "AutoencoderKLHunyuanImage",
+ "AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo",
+ "AutoencoderKLHunyuanVideo15",
"AutoencoderKLLTXVideo",
"AutoencoderKLMagvit",
"AutoencoderKLMochi",
+ "AutoencoderKLQwenImage",
"AutoencoderKLTemporalDecoder",
"AutoencoderKLWan",
"AutoencoderOobleck",
"AutoencoderTiny",
+ "AutoModel",
+ "BriaFiboTransformer2DModel",
+ "BriaTransformer2DModel",
"CacheMixin",
+ "ChromaTransformer2DModel",
+ "ChronoEditTransformer3DModel",
"CogVideoXTransformer3DModel",
"CogView3PlusTransformer2DModel",
"CogView4Transformer2DModel",
"ConsisIDTransformer3DModel",
"ConsistencyDecoderVAE",
+ "ContextParallelConfig",
"ControlNetModel",
"ControlNetUnionModel",
"ControlNetXSAdapter",
+ "CosmosTransformer3DModel",
"DiTTransformer2DModel",
"EasyAnimateTransformer3DModel",
+ "Flux2Transformer2DModel",
"FluxControlNetModel",
"FluxMultiControlNetModel",
"FluxTransformer2DModel",
+ "HiDreamImageTransformer2DModel",
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
"HunyuanDiT2DMultiControlNetModel",
+ "HunyuanImageTransformer2DModel",
+ "HunyuanVideo15Transformer3DModel",
+ "HunyuanVideoFramepackTransformer3DModel",
"HunyuanVideoTransformer3DModel",
"I2VGenXLUNet",
"Kandinsky3UNet",
+ "Kandinsky5Transformer3DModel",
"LatteTransformer3DModel",
"LTXVideoTransformer3DModel",
"Lumina2Transformer2DModel",
@@ -202,19 +261,28 @@
"MultiAdapter",
"MultiControlNetModel",
"OmniGenTransformer2DModel",
+ "OvisImageTransformer2DModel",
+ "ParallelConfig",
"PixArtTransformer2DModel",
"PriorTransformer",
+ "PRXTransformer2DModel",
+ "QwenImageControlNetModel",
+ "QwenImageMultiControlNetModel",
+ "QwenImageTransformer2DModel",
"SanaControlNetModel",
"SanaTransformer2DModel",
+ "SanaVideoTransformer3DModel",
"SD3ControlNetModel",
"SD3MultiControlNetModel",
"SD3Transformer2DModel",
+ "SkyReelsV2Transformer3DModel",
"SparseControlNetModel",
"StableAudioDiTModel",
"StableCascadeUNet",
"T2IAdapter",
"T5FilmDecoder",
"Transformer2DModel",
+ "TransformerTemporalModel",
"UNet1DModel",
"UNet2DConditionModel",
"UNet2DModel",
@@ -224,7 +292,19 @@
"UNetSpatioTemporalConditionModel",
"UVit2DModel",
"VQModel",
+ "WanAnimateTransformer3DModel",
"WanTransformer3DModel",
+ "WanVACETransformer3DModel",
+ "ZImageTransformer2DModel",
+ "attention_backend",
+ ]
+ )
+ _import_structure["modular_pipelines"].extend(
+ [
+ "ComponentsManager",
+ "ComponentSpec",
+ "ModularPipeline",
+ "ModularPipelineBlocks",
]
)
_import_structure["optimization"] = [
@@ -281,6 +361,7 @@
"EulerDiscreteScheduler",
"FlowMatchEulerDiscreteScheduler",
"FlowMatchHeunDiscreteScheduler",
+ "FlowMatchLCMScheduler",
"HeunDiscreteScheduler",
"IPNDMScheduler",
"KarrasVeScheduler",
@@ -344,6 +425,25 @@
]
else:
+ _import_structure["modular_pipelines"].extend(
+ [
+ "FluxAutoBlocks",
+ "FluxKontextAutoBlocks",
+ "FluxKontextModularPipeline",
+ "FluxModularPipeline",
+ "QwenImageAutoBlocks",
+ "QwenImageEditAutoBlocks",
+ "QwenImageEditModularPipeline",
+ "QwenImageEditPlusAutoBlocks",
+ "QwenImageEditPlusModularPipeline",
+ "QwenImageModularPipeline",
+ "StableDiffusionXLAutoBlocks",
+ "StableDiffusionXLModularPipeline",
+ "Wan22AutoBlocks",
+ "WanAutoBlocks",
+ "WanModularPipeline",
+ ]
+ )
_import_structure["pipelines"].extend(
[
"AllegroPipeline",
@@ -366,6 +466,11 @@
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
+ "BriaFiboPipeline",
+ "BriaPipeline",
+ "ChromaImg2ImgPipeline",
+ "ChromaPipeline",
+ "ChronoEditPipeline",
"CLIPImageProjection",
"CogVideoXFunControlPipeline",
"CogVideoXImageToVideoPipeline",
@@ -375,10 +480,15 @@
"CogView4ControlPipeline",
"CogView4Pipeline",
"ConsisIDPipeline",
+ "Cosmos2TextToImagePipeline",
+ "Cosmos2VideoToWorldPipeline",
+ "CosmosTextToWorldPipeline",
+ "CosmosVideoToWorldPipeline",
"CycleDiffusionPipeline",
"EasyAnimateControlPipeline",
"EasyAnimateInpaintPipeline",
"EasyAnimatePipeline",
+ "Flux2Pipeline",
"FluxControlImg2ImgPipeline",
"FluxControlInpaintPipeline",
"FluxControlNetImg2ImgPipeline",
@@ -388,12 +498,20 @@
"FluxFillPipeline",
"FluxImg2ImgPipeline",
"FluxInpaintPipeline",
+ "FluxKontextInpaintPipeline",
+ "FluxKontextPipeline",
"FluxPipeline",
"FluxPriorReduxPipeline",
+ "HiDreamImagePipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline",
+ "HunyuanImagePipeline",
+ "HunyuanImageRefinerPipeline",
"HunyuanSkyreelsImageToVideoPipeline",
+ "HunyuanVideo15ImageToVideoPipeline",
+ "HunyuanVideo15Pipeline",
+ "HunyuanVideoFramepackPipeline",
"HunyuanVideoImageToVideoPipeline",
"HunyuanVideoPipeline",
"I2VGenXLPipeline",
@@ -406,6 +524,10 @@
"ImageTextPipelineOutput",
"Kandinsky3Img2ImgPipeline",
"Kandinsky3Pipeline",
+ "Kandinsky5I2IPipeline",
+ "Kandinsky5I2VPipeline",
+ "Kandinsky5T2IPipeline",
+ "Kandinsky5T2VPipeline",
"KandinskyCombinedPipeline",
"KandinskyImg2ImgCombinedPipeline",
"KandinskyImg2ImgPipeline",
@@ -431,7 +553,9 @@
"LEditsPPPipelineStableDiffusionXL",
"LTXConditionPipeline",
"LTXImageToVideoPipeline",
+ "LTXLatentUpsamplePipeline",
"LTXPipeline",
+ "LucyEditPipeline",
"Lumina2Pipeline",
"Lumina2Text2ImgPipeline",
"LuminaPipeline",
@@ -442,19 +566,38 @@
"MochiPipeline",
"MusicLDMPipeline",
"OmniGenPipeline",
+ "OvisImagePipeline",
"PaintByExamplePipeline",
"PIAPipeline",
"PixArtAlphaPipeline",
"PixArtSigmaPAGPipeline",
"PixArtSigmaPipeline",
+ "PRXPipeline",
+ "QwenImageControlNetInpaintPipeline",
+ "QwenImageControlNetPipeline",
+ "QwenImageEditInpaintPipeline",
+ "QwenImageEditPipeline",
+ "QwenImageEditPlusPipeline",
+ "QwenImageImg2ImgPipeline",
+ "QwenImageInpaintPipeline",
+ "QwenImagePipeline",
"ReduxImageEncoder",
"SanaControlNetPipeline",
+ "SanaImageToVideoPipeline",
"SanaPAGPipeline",
"SanaPipeline",
+ "SanaSprintImg2ImgPipeline",
"SanaSprintPipeline",
+ "SanaVideoPipeline",
+ "SanaVideoPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
+ "SkyReelsV2DiffusionForcingImageToVideoPipeline",
+ "SkyReelsV2DiffusionForcingPipeline",
+ "SkyReelsV2DiffusionForcingVideoToVideoPipeline",
+ "SkyReelsV2ImageToVideoPipeline",
+ "SkyReelsV2Pipeline",
"StableAudioPipeline",
"StableAudioProjectionModel",
"StableCascadeCombinedPipeline",
@@ -531,16 +674,36 @@
"VersatileDiffusionPipeline",
"VersatileDiffusionTextToImagePipeline",
"VideoToVideoSDPipeline",
+ "VisualClozeGenerationPipeline",
+ "VisualClozePipeline",
"VQDiffusionPipeline",
+ "WanAnimatePipeline",
"WanImageToVideoPipeline",
"WanPipeline",
+ "WanVACEPipeline",
"WanVideoToVideoPipeline",
"WuerstchenCombinedPipeline",
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
+ "ZImageImg2ImgPipeline",
+ "ZImagePipeline",
]
)
+
+try:
+ if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_torch_and_transformers_and_opencv_objects # noqa F403
+
+ _import_structure["utils.dummy_torch_and_transformers_and_opencv_objects"] = [
+ name for name in dir(dummy_torch_and_transformers_and_opencv_objects) if not name.startswith("_")
+ ]
+
+else:
+ _import_structure["pipelines"].extend(["ConsisIDPipeline"])
+
try:
if not (
is_torch_available()
@@ -720,6 +883,7 @@
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .configuration_utils import ConfigMixin
+ from .quantizers import PipelineQuantizationConfig
try:
if not is_bitsandbytes_available():
@@ -753,6 +917,14 @@
else:
from .quantizers.quantization_config import QuantoConfig
+ try:
+ if not is_nvidia_modelopt_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_nvidia_modelopt_objects import *
+ else:
+ from .quantizers.quantization_config import NVIDIAModelOptConfig
+
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -767,49 +939,89 @@
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
+ from .guiders import (
+ AdaptiveProjectedGuidance,
+ AdaptiveProjectedMixGuidance,
+ AutoGuidance,
+ BaseGuidance,
+ ClassifierFreeGuidance,
+ ClassifierFreeZeroStarGuidance,
+ FrequencyDecoupledGuidance,
+ PerturbedAttentionGuidance,
+ SkipLayerGuidance,
+ SmoothedEnergyGuidance,
+ TangentialClassifierFreeGuidance,
+ )
from .hooks import (
FasterCacheConfig,
+ FirstBlockCacheConfig,
HookRegistry,
+ LayerSkipConfig,
PyramidAttentionBroadcastConfig,
+ SmoothedEnergyGuidanceConfig,
+ TaylorSeerCacheConfig,
apply_faster_cache,
+ apply_first_block_cache,
+ apply_layer_skip,
apply_pyramid_attention_broadcast,
+ apply_taylorseer_cache,
)
from .models import (
AllegroTransformer3DModel,
AsymmetricAutoencoderKL,
+ AttentionBackendName,
AuraFlowTransformer2DModel,
AutoencoderDC,
AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
+ AutoencoderKLCosmos,
+ AutoencoderKLFlux2,
+ AutoencoderKLHunyuanImage,
+ AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
+ AutoencoderKLHunyuanVideo15,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
+ AutoencoderKLQwenImage,
AutoencoderKLTemporalDecoder,
AutoencoderKLWan,
AutoencoderOobleck,
AutoencoderTiny,
+ AutoModel,
+ BriaFiboTransformer2DModel,
+ BriaTransformer2DModel,
CacheMixin,
+ ChromaTransformer2DModel,
+ ChronoEditTransformer3DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
CogView4Transformer2DModel,
ConsisIDTransformer3DModel,
ConsistencyDecoderVAE,
+ ContextParallelConfig,
ControlNetModel,
ControlNetUnionModel,
ControlNetXSAdapter,
+ CosmosTransformer3DModel,
DiTTransformer2DModel,
EasyAnimateTransformer3DModel,
+ Flux2Transformer2DModel,
FluxControlNetModel,
FluxMultiControlNetModel,
FluxTransformer2DModel,
+ HiDreamImageTransformer2DModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
HunyuanDiT2DMultiControlNetModel,
+ HunyuanImageTransformer2DModel,
+ HunyuanVideo15Transformer3DModel,
+ HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
I2VGenXLUNet,
Kandinsky3UNet,
+ Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
@@ -820,18 +1032,27 @@
MultiAdapter,
MultiControlNetModel,
OmniGenTransformer2DModel,
+ OvisImageTransformer2DModel,
+ ParallelConfig,
PixArtTransformer2DModel,
PriorTransformer,
+ PRXTransformer2DModel,
+ QwenImageControlNetModel,
+ QwenImageMultiControlNetModel,
+ QwenImageTransformer2DModel,
SanaControlNetModel,
SanaTransformer2DModel,
+ SanaVideoTransformer3DModel,
SD3ControlNetModel,
SD3MultiControlNetModel,
SD3Transformer2DModel,
+ SkyReelsV2Transformer3DModel,
SparseControlNetModel,
StableAudioDiTModel,
T2IAdapter,
T5FilmDecoder,
Transformer2DModel,
+ TransformerTemporalModel,
UNet1DModel,
UNet2DConditionModel,
UNet2DModel,
@@ -841,8 +1062,13 @@
UNetSpatioTemporalConditionModel,
UVit2DModel,
VQModel,
+ WanAnimateTransformer3DModel,
WanTransformer3DModel,
+ WanVACETransformer3DModel,
+ ZImageTransformer2DModel,
+ attention_backend,
)
+ from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
@@ -897,6 +1123,7 @@
EulerDiscreteScheduler,
FlowMatchEulerDiscreteScheduler,
FlowMatchHeunDiscreteScheduler,
+ FlowMatchLCMScheduler,
HeunDiscreteScheduler,
IPNDMScheduler,
KarrasVeScheduler,
@@ -938,6 +1165,23 @@
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
+ from .modular_pipelines import (
+ FluxAutoBlocks,
+ FluxKontextAutoBlocks,
+ FluxKontextModularPipeline,
+ FluxModularPipeline,
+ QwenImageAutoBlocks,
+ QwenImageEditAutoBlocks,
+ QwenImageEditModularPipeline,
+ QwenImageEditPlusAutoBlocks,
+ QwenImageEditPlusModularPipeline,
+ QwenImageModularPipeline,
+ StableDiffusionXLAutoBlocks,
+ StableDiffusionXLModularPipeline,
+ Wan22AutoBlocks,
+ WanAutoBlocks,
+ WanModularPipeline,
+ )
from .pipelines import (
AllegroPipeline,
AltDiffusionImg2ImgPipeline,
@@ -957,6 +1201,11 @@
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
AuraFlowPipeline,
+ BriaFiboPipeline,
+ BriaPipeline,
+ ChromaImg2ImgPipeline,
+ ChromaPipeline,
+ ChronoEditPipeline,
CLIPImageProjection,
CogVideoXFunControlPipeline,
CogVideoXImageToVideoPipeline,
@@ -966,10 +1215,15 @@
CogView4ControlPipeline,
CogView4Pipeline,
ConsisIDPipeline,
+ Cosmos2TextToImagePipeline,
+ Cosmos2VideoToWorldPipeline,
+ CosmosTextToWorldPipeline,
+ CosmosVideoToWorldPipeline,
CycleDiffusionPipeline,
EasyAnimateControlPipeline,
EasyAnimateInpaintPipeline,
EasyAnimatePipeline,
+ Flux2Pipeline,
FluxControlImg2ImgPipeline,
FluxControlInpaintPipeline,
FluxControlNetImg2ImgPipeline,
@@ -979,12 +1233,20 @@
FluxFillPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
+ FluxKontextInpaintPipeline,
+ FluxKontextPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
+ HiDreamImagePipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
+ HunyuanImagePipeline,
+ HunyuanImageRefinerPipeline,
HunyuanSkyreelsImageToVideoPipeline,
+ HunyuanVideo15ImageToVideoPipeline,
+ HunyuanVideo15Pipeline,
+ HunyuanVideoFramepackPipeline,
HunyuanVideoImageToVideoPipeline,
HunyuanVideoPipeline,
I2VGenXLPipeline,
@@ -997,6 +1259,10 @@
ImageTextPipelineOutput,
Kandinsky3Img2ImgPipeline,
Kandinsky3Pipeline,
+ Kandinsky5I2IPipeline,
+ Kandinsky5I2VPipeline,
+ Kandinsky5T2IPipeline,
+ Kandinsky5T2VPipeline,
KandinskyCombinedPipeline,
KandinskyImg2ImgCombinedPipeline,
KandinskyImg2ImgPipeline,
@@ -1022,7 +1288,9 @@
LEditsPPPipelineStableDiffusionXL,
LTXConditionPipeline,
LTXImageToVideoPipeline,
+ LTXLatentUpsamplePipeline,
LTXPipeline,
+ LucyEditPipeline,
Lumina2Pipeline,
Lumina2Text2ImgPipeline,
LuminaPipeline,
@@ -1033,19 +1301,37 @@
MochiPipeline,
MusicLDMPipeline,
OmniGenPipeline,
+ OvisImagePipeline,
PaintByExamplePipeline,
PIAPipeline,
PixArtAlphaPipeline,
PixArtSigmaPAGPipeline,
PixArtSigmaPipeline,
+ PRXPipeline,
+ QwenImageControlNetInpaintPipeline,
+ QwenImageControlNetPipeline,
+ QwenImageEditInpaintPipeline,
+ QwenImageEditPipeline,
+ QwenImageEditPlusPipeline,
+ QwenImageImg2ImgPipeline,
+ QwenImageInpaintPipeline,
+ QwenImagePipeline,
ReduxImageEncoder,
SanaControlNetPipeline,
+ SanaImageToVideoPipeline,
SanaPAGPipeline,
SanaPipeline,
+ SanaSprintImg2ImgPipeline,
SanaSprintPipeline,
+ SanaVideoPipeline,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
+ SkyReelsV2DiffusionForcingImageToVideoPipeline,
+ SkyReelsV2DiffusionForcingPipeline,
+ SkyReelsV2DiffusionForcingVideoToVideoPipeline,
+ SkyReelsV2ImageToVideoPipeline,
+ SkyReelsV2Pipeline,
StableAudioPipeline,
StableAudioProjectionModel,
StableCascadeCombinedPipeline,
@@ -1121,13 +1407,19 @@
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
VideoToVideoSDPipeline,
+ VisualClozeGenerationPipeline,
+ VisualClozePipeline,
VQDiffusionPipeline,
+ WanAnimatePipeline,
WanImageToVideoPipeline,
WanPipeline,
+ WanVACEPipeline,
WanVideoToVideoPipeline,
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
+ ZImageImg2ImgPipeline,
+ ZImagePipeline,
)
try:
@@ -1156,6 +1448,15 @@
from .utils.dummy_torch_and_transformers_and_sentencepiece_objects import * # noqa F403
else:
from .pipelines import KolorsImg2ImgPipeline, KolorsPAGPipeline, KolorsPipeline
+
+ try:
+ if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_torch_and_transformers_and_opencv_objects import * # noqa F403
+ else:
+ from .pipelines import ConsisIDPipeline
+
try:
if not (
is_torch_available() and is_transformers_available() and is_onnx_available()
diff --git a/src/diffusers/callbacks.py b/src/diffusers/callbacks.py
index 4b8b15368c47..2a08f091d9f3 100644
--- a/src/diffusers/callbacks.py
+++ b/src/diffusers/callbacks.py
@@ -207,3 +207,38 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s
if step_index == cutoff_step:
pipeline.set_ip_adapter_scale(0.0)
return callback_kwargs
+
+
+class SD3CFGCutoffCallback(PipelineCallback):
+ """
+ Callback function for Stable Diffusion 3 Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
+ `cutoff_step_index`), this callback will disable the CFG.
+
+ Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
+ """
+
+ tensor_inputs = ["prompt_embeds", "pooled_prompt_embeds"]
+
+ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
+ cutoff_step_ratio = self.config.cutoff_step_ratio
+ cutoff_step_index = self.config.cutoff_step_index
+
+ # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
+ cutoff_step = (
+ cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
+ )
+
+ if step_index == cutoff_step:
+ prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
+ prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens.
+
+ pooled_prompt_embeds = callback_kwargs[self.tensor_inputs[1]]
+ pooled_prompt_embeds = pooled_prompt_embeds[
+ -1:
+ ] # "-1" denotes the embeddings for conditional pooled text tokens.
+
+ pipeline._guidance_scale = 0.0
+
+ callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
+ callback_kwargs[self.tensor_inputs[1]] = pooled_prompt_embeds
+ return callback_kwargs
diff --git a/src/diffusers/commands/__init__.py b/src/diffusers/commands/__init__.py
index 8208283f6e40..3a8c10147e8b 100644
--- a/src/diffusers/commands/__init__.py
+++ b/src/diffusers/commands/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/commands/custom_blocks.py b/src/diffusers/commands/custom_blocks.py
new file mode 100644
index 000000000000..43d9ea88577a
--- /dev/null
+++ b/src/diffusers/commands/custom_blocks.py
@@ -0,0 +1,134 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Usage example:
+ TODO
+"""
+
+import ast
+import importlib.util
+import os
+from argparse import ArgumentParser, Namespace
+from pathlib import Path
+
+from ..utils import logging
+from . import BaseDiffusersCLICommand
+
+
+EXPECTED_PARENT_CLASSES = ["ModularPipelineBlocks"]
+CONFIG = "config.json"
+
+
+def conversion_command_factory(args: Namespace):
+ return CustomBlocksCommand(args.block_module_name, args.block_class_name)
+
+
+class CustomBlocksCommand(BaseDiffusersCLICommand):
+ @staticmethod
+ def register_subcommand(parser: ArgumentParser):
+ conversion_parser = parser.add_parser("custom_blocks")
+ conversion_parser.add_argument(
+ "--block_module_name",
+ type=str,
+ default="block.py",
+ help="Module filename in which the custom block will be implemented.",
+ )
+ conversion_parser.add_argument(
+ "--block_class_name",
+ type=str,
+ default=None,
+ help="Name of the custom block. If provided None, we will try to infer it.",
+ )
+ conversion_parser.set_defaults(func=conversion_command_factory)
+
+ def __init__(self, block_module_name: str = "block.py", block_class_name: str = None):
+ self.logger = logging.get_logger("diffusers-cli/custom_blocks")
+ self.block_module_name = Path(block_module_name)
+ self.block_class_name = block_class_name
+
+ def run(self):
+ # determine the block to be saved.
+ out = self._get_class_names(self.block_module_name)
+ classes_found = list({cls for cls, _ in out})
+
+ if self.block_class_name is not None:
+ child_class, parent_class = self._choose_block(out, self.block_class_name)
+ if child_class is None and parent_class is None:
+ raise ValueError(
+ "`block_class_name` could not be retrieved. Available classes from "
+ f"{self.block_module_name}:\n{classes_found}"
+ )
+ else:
+ self.logger.info(
+ f"Found classes: {classes_found} will be using {classes_found[0]}. "
+ "If this needs to be changed, re-run the command specifying `block_class_name`."
+ )
+ child_class, parent_class = out[0][0], out[0][1]
+
+ # dynamically get the custom block and initialize it to call `save_pretrained` in the current directory.
+ # the user is responsible for running it, so I guess that is safe?
+ module_name = f"__dynamic__{self.block_module_name.stem}"
+ spec = importlib.util.spec_from_file_location(module_name, str(self.block_module_name))
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ getattr(module, child_class)().save_pretrained(os.getcwd())
+
+ # or, we could create it manually.
+ # automap = self._create_automap(parent_class=parent_class, child_class=child_class)
+ # with open(CONFIG, "w") as f:
+ # json.dump(automap, f)
+ with open("requirements.txt", "w") as f:
+ f.write("")
+
+ def _choose_block(self, candidates, chosen=None):
+ for cls, base in candidates:
+ if cls == chosen:
+ return cls, base
+ return None, None
+
+ def _get_class_names(self, file_path):
+ source = file_path.read_text(encoding="utf-8")
+ try:
+ tree = ast.parse(source, filename=file_path)
+ except SyntaxError as e:
+ raise ValueError(f"Could not parse {file_path!r}: {e}") from e
+
+ results: list[tuple[str, str]] = []
+ for node in tree.body:
+ if not isinstance(node, ast.ClassDef):
+ continue
+
+ # extract all base names for this class
+ base_names = [bname for b in node.bases if (bname := self._get_base_name(b)) is not None]
+
+ # for each allowed base that appears in the class's bases, emit a tuple
+ for allowed in EXPECTED_PARENT_CLASSES:
+ if allowed in base_names:
+ results.append((node.name, allowed))
+
+ return results
+
+ def _get_base_name(self, node: ast.expr):
+ if isinstance(node, ast.Name):
+ return node.id
+ elif isinstance(node, ast.Attribute):
+ val = self._get_base_name(node.value)
+ return f"{val}.{node.attr}" if val else node.attr
+ return None
+
+ def _create_automap(self, parent_class, child_class):
+ module = str(self.block_module_name).replace(".py", "").rsplit(".", 1)[-1]
+ auto_map = {f"{parent_class}": f"{module}.{child_class}"}
+ return {"auto_map": auto_map}
diff --git a/src/diffusers/commands/diffusers_cli.py b/src/diffusers/commands/diffusers_cli.py
index f582c3bcd0df..a27ac24f2a3e 100644
--- a/src/diffusers/commands/diffusers_cli.py
+++ b/src/diffusers/commands/diffusers_cli.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,6 +15,7 @@
from argparse import ArgumentParser
+from .custom_blocks import CustomBlocksCommand
from .env import EnvironmentCommand
from .fp16_safetensors import FP16SafetensorsCommand
@@ -26,6 +27,7 @@ def main():
# Register commands
EnvironmentCommand.register_subcommand(commands_parser)
FP16SafetensorsCommand.register_subcommand(commands_parser)
+ CustomBlocksCommand.register_subcommand(commands_parser)
# Let's go
args = parser.parse_args()
diff --git a/src/diffusers/commands/env.py b/src/diffusers/commands/env.py
index d0af30bf1c65..58f31d478bf3 100644
--- a/src/diffusers/commands/env.py
+++ b/src/diffusers/commands/env.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/commands/fp16_safetensors.py b/src/diffusers/commands/fp16_safetensors.py
index b26b8816bc4c..41739261e553 100644
--- a/src/diffusers/commands/fp16_safetensors.py
+++ b/src/diffusers/commands/fp16_safetensors.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -59,7 +59,7 @@ def register_subcommand(parser: ArgumentParser):
conversion_parser.add_argument(
"--use_auth_token",
action="store_true",
- help="When working with checkpoints having private visibility. When used `huggingface-cli login` needs to be run beforehand.",
+ help="When working with checkpoints having private visibility. When used `hf auth login` needs to be run beforehand.",
)
conversion_parser.set_defaults(func=conversion_command_factory)
diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index f9b652bbc021..1c4ee33acbfd 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -30,11 +30,11 @@
from huggingface_hub import DDUFEntry, create_repo, hf_hub_download
from huggingface_hub.utils import (
EntryNotFoundError,
+ HfHubHTTPError,
RepositoryNotFoundError,
RevisionNotFoundError,
validate_hf_hub_args,
)
-from requests import HTTPError
from typing_extensions import Self
from . import __version__
@@ -176,6 +176,7 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool
token = kwargs.pop("token", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
+ subfolder = kwargs.pop("subfolder", None)
self._upload_folder(
save_directory,
@@ -183,6 +184,7 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool
token=token,
commit_message=commit_message,
create_pr=create_pr,
+ subfolder=subfolder,
)
@classmethod
@@ -405,7 +407,7 @@ def load_config(
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
" listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
- " token having permission to this repo with `token` or log in with `huggingface-cli login`."
+ " token having permission to this repo with `token` or log in with `hf auth login`."
)
except RevisionNotFoundError:
raise EnvironmentError(
@@ -417,7 +419,7 @@ def load_config(
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
)
- except HTTPError as err:
+ except HfHubHTTPError as err:
raise EnvironmentError(
"There was a specific connection error when trying to load"
f" {pretrained_model_name_or_path}:\n{err}"
@@ -601,6 +603,10 @@ def to_json_saveable(value):
value = value.tolist()
elif isinstance(value, Path):
value = value.as_posix()
+ elif hasattr(value, "to_dict") and callable(value.to_dict):
+ value = value.to_dict()
+ elif isinstance(value, list):
+ value = [to_json_saveable(v) for v in value]
return value
if "quantization_config" in config_dict:
@@ -757,4 +763,7 @@ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_un
# resolve remapping
remapped_class = _fetch_remapped_cls_from_config(config, cls)
- return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
+ if remapped_class is cls:
+ return super(LegacyConfigMixin, remapped_class).from_config(config, return_unused_kwargs, **kwargs)
+ else:
+ return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
diff --git a/src/diffusers/dependency_versions_check.py b/src/diffusers/dependency_versions_check.py
index 0728b3a7c093..e3670b136af4 100644
--- a/src/diffusers/dependency_versions_check.py
+++ b/src/diffusers/dependency_versions_check.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py
index 8ec95ed6fc8d..6e5ac630ab08 100644
--- a/src/diffusers/dependency_versions_table.py
+++ b/src/diffusers/dependency_versions_table.py
@@ -9,7 +9,8 @@
"filelock": "filelock",
"flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
- "huggingface-hub": "huggingface-hub>=0.27.0",
+ "httpx": "httpx<1.0.0",
+ "huggingface-hub": "huggingface-hub>=0.34.0,<2.0",
"requests-mock": "requests-mock==1.10.0",
"importlib_metadata": "importlib_metadata",
"invisible-watermark": "invisible-watermark>=0.2.0",
@@ -17,19 +18,19 @@
"jax": "jax>=0.4.1",
"jaxlib": "jaxlib>=0.4.1",
"Jinja2": "Jinja2",
- "k-diffusion": "k-diffusion>=0.0.12",
+ "k-diffusion": "k-diffusion==0.0.12",
"torchsde": "torchsde",
"note_seq": "note_seq",
"librosa": "librosa",
"numpy": "numpy",
"parameterized": "parameterized",
- "peft": "peft>=0.6.0",
+ "peft": "peft>=0.17.0",
"protobuf": "protobuf>=3.20.3,<4",
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
"python": "python>=3.8.0",
- "ruff": "ruff==0.1.5",
+ "ruff": "ruff==0.9.10",
"safetensors": "safetensors>=0.3.1",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
"GitPython": "GitPython<3.1.19",
@@ -39,6 +40,7 @@
"gguf": "gguf>=0.10.0",
"torchao": "torchao>=0.7.0",
"bitsandbytes": "bitsandbytes>=0.43.3",
+ "nvidia_modelopt[hf]": "nvidia_modelopt[hf]>=0.33.1",
"regex": "regex!=2019.12.17",
"requests": "requests",
"tensorboard": "tensorboard",
@@ -49,4 +51,6 @@
"urllib3": "urllib3<=2.0.0",
"black": "black",
"phonemizer": "phonemizer",
+ "opencv-python": "opencv-python",
+ "timm": "timm",
}
diff --git a/src/diffusers/experimental/README.md b/src/diffusers/experimental/README.md
index 81a9de81c737..77594b14dbfc 100644
--- a/src/diffusers/experimental/README.md
+++ b/src/diffusers/experimental/README.md
@@ -2,4 +2,4 @@
We are adding experimental code to support novel applications and usages of the Diffusers library.
Currently, the following experiments are supported:
-* Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model.
\ No newline at end of file
+* Reinforcement learning via an implementation of the [Diffuser](https://huggingface.co/papers/2205.09991) model.
\ No newline at end of file
diff --git a/src/diffusers/experimental/rl/value_guided_sampling.py b/src/diffusers/experimental/rl/value_guided_sampling.py
index 2f9de857480e..c69d308ecc68 100644
--- a/src/diffusers/experimental/rl/value_guided_sampling.py
+++ b/src/diffusers/experimental/rl/value_guided_sampling.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py
new file mode 100644
index 000000000000..4e53c373c4f4
--- /dev/null
+++ b/src/diffusers/guiders/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Union
+
+from ..utils import is_torch_available, logging
+
+
+if is_torch_available():
+ from .adaptive_projected_guidance import AdaptiveProjectedGuidance
+ from .adaptive_projected_guidance_mix import AdaptiveProjectedMixGuidance
+ from .auto_guidance import AutoGuidance
+ from .classifier_free_guidance import ClassifierFreeGuidance
+ from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
+ from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
+ from .guider_utils import BaseGuidance
+ from .perturbed_attention_guidance import PerturbedAttentionGuidance
+ from .skip_layer_guidance import SkipLayerGuidance
+ from .smoothed_energy_guidance import SmoothedEnergyGuidance
+ from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py
new file mode 100644
index 000000000000..8ec30d02d758
--- /dev/null
+++ b/src/diffusers/guiders/adaptive_projected_guidance.py
@@ -0,0 +1,235 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import register_to_config
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+
+class AdaptiveProjectedGuidance(BaseGuidance):
+ """
+ Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416
+
+ Args:
+ guidance_scale (`float`, defaults to `7.5`):
+ The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
+ prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
+ deterioration of image quality.
+ adaptive_projected_guidance_momentum (`float`, defaults to `None`):
+ The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
+ adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
+ The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+ guidance_rescale (`float`, defaults to `0.0`):
+ The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+ overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+ use_original_formulation (`bool`, defaults to `False`):
+ Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+ we use the diffusers-native implementation that has been in the codebase for a long time. See
+ [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+ start (`float`, defaults to `0.0`):
+ The fraction of the total number of denoising steps after which guidance starts.
+ stop (`float`, defaults to `1.0`):
+ The fraction of the total number of denoising steps after which guidance stops.
+ """
+
+ _input_predictions = ["pred_cond", "pred_uncond"]
+
+ @register_to_config
+ def __init__(
+ self,
+ guidance_scale: float = 7.5,
+ adaptive_projected_guidance_momentum: Optional[float] = None,
+ adaptive_projected_guidance_rescale: float = 15.0,
+ eta: float = 1.0,
+ guidance_rescale: float = 0.0,
+ use_original_formulation: bool = False,
+ start: float = 0.0,
+ stop: float = 1.0,
+ enabled: bool = True,
+ ):
+ super().__init__(start, stop, enabled)
+
+ self.guidance_scale = guidance_scale
+ self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
+ self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
+ self.eta = eta
+ self.guidance_rescale = guidance_rescale
+ self.use_original_formulation = use_original_formulation
+ self.momentum_buffer = None
+
+ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
+ if self._step == 0:
+ if self.adaptive_projected_guidance_momentum is not None:
+ self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
+ tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+ data_batches = []
+ for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+ data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
+ data_batches.append(data_batch)
+ return data_batches
+
+ def prepare_inputs_from_block_state(
+ self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
+ ) -> List["BlockState"]:
+ if self._step == 0:
+ if self.adaptive_projected_guidance_momentum is not None:
+ self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
+ tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+ data_batches = []
+ for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+ data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
+ data_batches.append(data_batch)
+ return data_batches
+
+ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
+ pred = None
+
+ if not self._is_apg_enabled():
+ pred = pred_cond
+ else:
+ pred = normalized_guidance(
+ pred_cond,
+ pred_uncond,
+ self.guidance_scale,
+ self.momentum_buffer,
+ self.eta,
+ self.adaptive_projected_guidance_rescale,
+ self.use_original_formulation,
+ )
+
+ if self.guidance_rescale > 0.0:
+ pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+ return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
+
+ @property
+ def is_conditional(self) -> bool:
+ return self._count_prepared == 1
+
+ @property
+ def num_conditions(self) -> int:
+ num_conditions = 1
+ if self._is_apg_enabled():
+ num_conditions += 1
+ return num_conditions
+
+ def _is_apg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self._start * self._num_inference_steps)
+ skip_stop_step = int(self._stop * self._num_inference_steps)
+ is_within_range = skip_start_step <= self._step < skip_stop_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = math.isclose(self.guidance_scale, 0.0)
+ else:
+ is_close = math.isclose(self.guidance_scale, 1.0)
+
+ return is_within_range and not is_close
+
+
+class MomentumBuffer:
+ def __init__(self, momentum: float):
+ self.momentum = momentum
+ self.running_average = 0
+
+ def update(self, update_value: torch.Tensor):
+ new_average = self.momentum * self.running_average
+ self.running_average = update_value + new_average
+
+ def __repr__(self) -> str:
+ """
+ Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
+ """
+ if isinstance(self.running_average, torch.Tensor):
+ shape = tuple(self.running_average.shape)
+
+ # Calculate statistics
+ with torch.no_grad():
+ stats = {
+ "mean": self.running_average.mean().item(),
+ "std": self.running_average.std().item(),
+ "min": self.running_average.min().item(),
+ "max": self.running_average.max().item(),
+ }
+
+ # Get a slice (max 3 elements per dimension)
+ slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
+ sliced_data = self.running_average[slice_indices]
+
+ # Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
+ slice_str = str(sliced_data.detach().float().cpu().numpy())
+ if len(slice_str) > 200: # Truncate if too long
+ slice_str = slice_str[:200] + "..."
+
+ stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
+
+ return (
+ f"MomentumBuffer(\n"
+ f" momentum={self.momentum},\n"
+ f" shape={shape},\n"
+ f" stats=[{stats_str}],\n"
+ f" slice={slice_str}\n"
+ f")"
+ )
+ else:
+ return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
+
+
+def normalized_guidance(
+ pred_cond: torch.Tensor,
+ pred_uncond: torch.Tensor,
+ guidance_scale: float,
+ momentum_buffer: Optional[MomentumBuffer] = None,
+ eta: float = 1.0,
+ norm_threshold: float = 0.0,
+ use_original_formulation: bool = False,
+):
+ diff = pred_cond - pred_uncond
+ dim = [-i for i in range(1, len(diff.shape))]
+
+ if momentum_buffer is not None:
+ momentum_buffer.update(diff)
+ diff = momentum_buffer.running_average
+
+ if norm_threshold > 0:
+ ones = torch.ones_like(diff)
+ diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
+ scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
+ diff = diff * scale_factor
+
+ v0, v1 = diff.double(), pred_cond.double()
+ v1 = torch.nn.functional.normalize(v1, dim=dim)
+ v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
+ v0_orthogonal = v0 - v0_parallel
+ diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
+ normalized_update = diff_orthogonal + eta * diff_parallel
+
+ pred = pred_cond if use_original_formulation else pred_uncond
+ pred = pred + guidance_scale * normalized_update
+
+ return pred
diff --git a/src/diffusers/guiders/adaptive_projected_guidance_mix.py b/src/diffusers/guiders/adaptive_projected_guidance_mix.py
new file mode 100644
index 000000000000..bdc97bcf6269
--- /dev/null
+++ b/src/diffusers/guiders/adaptive_projected_guidance_mix.py
@@ -0,0 +1,297 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import register_to_config
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+
+class AdaptiveProjectedMixGuidance(BaseGuidance):
+ """
+ Adaptive Projected Guidance (APG) https://huggingface.co/papers/2410.02416 combined with Classifier-Free Guidance
+ (CFG). This guider is used in HunyuanImage2.1 https://github.com/Tencent-Hunyuan/HunyuanImage-2.1
+
+ Args:
+ guidance_scale (`float`, defaults to `7.5`):
+ The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
+ prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
+ deterioration of image quality.
+ adaptive_projected_guidance_momentum (`float`, defaults to `None`):
+ The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
+ adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
+ The rescale factor applied to the noise predictions for adaptive projected guidance. This is used to
+ improve image quality and fix
+ guidance_rescale (`float`, defaults to `0.0`):
+ The rescale factor applied to the noise predictions for classifier-free guidance. This is used to improve
+ image quality and fix overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample
+ Steps are Flawed](https://huggingface.co/papers/2305.08891).
+ use_original_formulation (`bool`, defaults to `False`):
+ Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+ we use the diffusers-native implementation that has been in the codebase for a long time. See
+ [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+ start (`float`, defaults to `0.0`):
+ The fraction of the total number of denoising steps after which the classifier-free guidance starts.
+ stop (`float`, defaults to `1.0`):
+ The fraction of the total number of denoising steps after which the classifier-free guidance stops.
+ adaptive_projected_guidance_start_step (`int`, defaults to `5`):
+ The step at which the adaptive projected guidance starts (before this step, classifier-free guidance is
+ used, and momentum buffer is updated).
+ enabled (`bool`, defaults to `True`):
+ Whether this guidance is enabled.
+ """
+
+ _input_predictions = ["pred_cond", "pred_uncond"]
+
+ @register_to_config
+ def __init__(
+ self,
+ guidance_scale: float = 3.5,
+ guidance_rescale: float = 0.0,
+ adaptive_projected_guidance_scale: float = 10.0,
+ adaptive_projected_guidance_momentum: float = -0.5,
+ adaptive_projected_guidance_rescale: float = 10.0,
+ eta: float = 0.0,
+ use_original_formulation: bool = False,
+ start: float = 0.0,
+ stop: float = 1.0,
+ adaptive_projected_guidance_start_step: int = 5,
+ enabled: bool = True,
+ ):
+ super().__init__(start, stop, enabled)
+
+ self.guidance_scale = guidance_scale
+ self.guidance_rescale = guidance_rescale
+ self.adaptive_projected_guidance_scale = adaptive_projected_guidance_scale
+ self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
+ self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
+ self.eta = eta
+ self.adaptive_projected_guidance_start_step = adaptive_projected_guidance_start_step
+ self.use_original_formulation = use_original_formulation
+ self.momentum_buffer = None
+
+ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
+ if self._step == 0:
+ if self.adaptive_projected_guidance_momentum is not None:
+ self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
+ tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+ data_batches = []
+ for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+ data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
+ data_batches.append(data_batch)
+ return data_batches
+
+ def prepare_inputs_from_block_state(
+ self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
+ ) -> List["BlockState"]:
+ if self._step == 0:
+ if self.adaptive_projected_guidance_momentum is not None:
+ self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
+ tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+ data_batches = []
+ for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+ data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
+ data_batches.append(data_batch)
+ return data_batches
+
+ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
+ pred = None
+
+ # no guidance
+ if not self._is_cfg_enabled():
+ pred = pred_cond
+
+ # CFG + update momentum buffer
+ elif not self._is_apg_enabled():
+ if self.momentum_buffer is not None:
+ update_momentum_buffer(pred_cond, pred_uncond, self.momentum_buffer)
+ # CFG + update momentum buffer
+ shift = pred_cond - pred_uncond
+ pred = pred_cond if self.use_original_formulation else pred_uncond
+ pred = pred + self.guidance_scale * shift
+
+ # APG
+ elif self._is_apg_enabled():
+ pred = normalized_guidance(
+ pred_cond,
+ pred_uncond,
+ self.adaptive_projected_guidance_scale,
+ self.momentum_buffer,
+ self.eta,
+ self.adaptive_projected_guidance_rescale,
+ self.use_original_formulation,
+ )
+
+ if self.guidance_rescale > 0.0:
+ pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+ return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
+
+ @property
+ def is_conditional(self) -> bool:
+ return self._count_prepared == 1
+
+ @property
+ def num_conditions(self) -> int:
+ num_conditions = 1
+ if self._is_apg_enabled() or self._is_cfg_enabled():
+ num_conditions += 1
+ return num_conditions
+
+ # Copied from diffusers.guiders.classifier_free_guidance.ClassifierFreeGuidance._is_cfg_enabled
+ def _is_cfg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self._start * self._num_inference_steps)
+ skip_stop_step = int(self._stop * self._num_inference_steps)
+ is_within_range = skip_start_step <= self._step < skip_stop_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = math.isclose(self.guidance_scale, 0.0)
+ else:
+ is_close = math.isclose(self.guidance_scale, 1.0)
+
+ return is_within_range and not is_close
+
+ def _is_apg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ if not self._is_cfg_enabled():
+ return False
+
+ is_within_range = False
+ if self._step is not None:
+ is_within_range = self._step > self.adaptive_projected_guidance_start_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = math.isclose(self.adaptive_projected_guidance_scale, 0.0)
+ else:
+ is_close = math.isclose(self.adaptive_projected_guidance_scale, 1.0)
+
+ return is_within_range and not is_close
+
+ def get_state(self):
+ state = super().get_state()
+ state["momentum_buffer"] = self.momentum_buffer
+ state["is_apg_enabled"] = self._is_apg_enabled()
+ state["is_cfg_enabled"] = self._is_cfg_enabled()
+ return state
+
+
+# Copied from diffusers.guiders.adaptive_projected_guidance.MomentumBuffer
+class MomentumBuffer:
+ def __init__(self, momentum: float):
+ self.momentum = momentum
+ self.running_average = 0
+
+ def update(self, update_value: torch.Tensor):
+ new_average = self.momentum * self.running_average
+ self.running_average = update_value + new_average
+
+ def __repr__(self) -> str:
+ """
+ Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
+ """
+ if isinstance(self.running_average, torch.Tensor):
+ shape = tuple(self.running_average.shape)
+
+ # Calculate statistics
+ with torch.no_grad():
+ stats = {
+ "mean": self.running_average.mean().item(),
+ "std": self.running_average.std().item(),
+ "min": self.running_average.min().item(),
+ "max": self.running_average.max().item(),
+ }
+
+ # Get a slice (max 3 elements per dimension)
+ slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
+ sliced_data = self.running_average[slice_indices]
+
+ # Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
+ slice_str = str(sliced_data.detach().float().cpu().numpy())
+ if len(slice_str) > 200: # Truncate if too long
+ slice_str = slice_str[:200] + "..."
+
+ stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
+
+ return (
+ f"MomentumBuffer(\n"
+ f" momentum={self.momentum},\n"
+ f" shape={shape},\n"
+ f" stats=[{stats_str}],\n"
+ f" slice={slice_str}\n"
+ f")"
+ )
+ else:
+ return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
+
+
+def update_momentum_buffer(
+ pred_cond: torch.Tensor,
+ pred_uncond: torch.Tensor,
+ momentum_buffer: Optional[MomentumBuffer] = None,
+):
+ diff = pred_cond - pred_uncond
+ if momentum_buffer is not None:
+ momentum_buffer.update(diff)
+
+
+def normalized_guidance(
+ pred_cond: torch.Tensor,
+ pred_uncond: torch.Tensor,
+ guidance_scale: float,
+ momentum_buffer: Optional[MomentumBuffer] = None,
+ eta: float = 1.0,
+ norm_threshold: float = 0.0,
+ use_original_formulation: bool = False,
+):
+ if momentum_buffer is not None:
+ update_momentum_buffer(pred_cond, pred_uncond, momentum_buffer)
+ diff = momentum_buffer.running_average
+ else:
+ diff = pred_cond - pred_uncond
+
+ dim = [-i for i in range(1, len(diff.shape))]
+
+ if norm_threshold > 0:
+ ones = torch.ones_like(diff)
+ diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
+ scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
+ diff = diff * scale_factor
+
+ v0, v1 = diff.double(), pred_cond.double()
+ v1 = torch.nn.functional.normalize(v1, dim=dim)
+ v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
+ v0_orthogonal = v0 - v0_parallel
+ diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
+ normalized_update = diff_orthogonal + eta * diff_parallel
+
+ pred = pred_cond if use_original_formulation else pred_uncond
+ pred = pred + guidance_scale * normalized_update
+
+ return pred
diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py
new file mode 100644
index 000000000000..b7f62e2f4a6e
--- /dev/null
+++ b/src/diffusers/guiders/auto_guidance.py
@@ -0,0 +1,196 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import register_to_config
+from ..hooks import HookRegistry, LayerSkipConfig
+from ..hooks.layer_skip import _apply_layer_skip_hook
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+
+class AutoGuidance(BaseGuidance):
+ """
+ AutoGuidance: https://huggingface.co/papers/2406.02507
+
+ Args:
+ guidance_scale (`float`, defaults to `7.5`):
+ The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
+ prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
+ deterioration of image quality.
+ auto_guidance_layers (`int` or `List[int]`, *optional*):
+ The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
+ provided, `skip_layer_config` must be provided.
+ auto_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
+ The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
+ `LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
+ dropout (`float`, *optional*):
+ The dropout probability for autoguidance on the enabled skip layers (either with `auto_guidance_layers` or
+ `auto_guidance_config`). If not provided, the dropout probability will be set to 1.0.
+ guidance_rescale (`float`, defaults to `0.0`):
+ The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+ overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+ use_original_formulation (`bool`, defaults to `False`):
+ Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+ we use the diffusers-native implementation that has been in the codebase for a long time. See
+ [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+ start (`float`, defaults to `0.0`):
+ The fraction of the total number of denoising steps after which guidance starts.
+ stop (`float`, defaults to `1.0`):
+ The fraction of the total number of denoising steps after which guidance stops.
+ """
+
+ _input_predictions = ["pred_cond", "pred_uncond"]
+
+ @register_to_config
+ def __init__(
+ self,
+ guidance_scale: float = 7.5,
+ auto_guidance_layers: Optional[Union[int, List[int]]] = None,
+ auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
+ dropout: Optional[float] = None,
+ guidance_rescale: float = 0.0,
+ use_original_formulation: bool = False,
+ start: float = 0.0,
+ stop: float = 1.0,
+ enabled: bool = True,
+ ):
+ super().__init__(start, stop, enabled)
+
+ self.guidance_scale = guidance_scale
+ self.auto_guidance_layers = auto_guidance_layers
+ self.auto_guidance_config = auto_guidance_config
+ self.dropout = dropout
+ self.guidance_rescale = guidance_rescale
+ self.use_original_formulation = use_original_formulation
+
+ is_layer_or_config_provided = auto_guidance_layers is not None or auto_guidance_config is not None
+ is_layer_and_config_provided = auto_guidance_layers is not None and auto_guidance_config is not None
+ if not is_layer_or_config_provided:
+ raise ValueError(
+ "Either `auto_guidance_layers` or `auto_guidance_config` must be provided to enable AutoGuidance."
+ )
+ if is_layer_and_config_provided:
+ raise ValueError("Only one of `auto_guidance_layers` or `auto_guidance_config` can be provided.")
+ if auto_guidance_config is None and dropout is None:
+ raise ValueError("`dropout` must be provided if `auto_guidance_layers` is provided.")
+
+ if auto_guidance_layers is not None:
+ if isinstance(auto_guidance_layers, int):
+ auto_guidance_layers = [auto_guidance_layers]
+ if not isinstance(auto_guidance_layers, list):
+ raise ValueError(
+ f"Expected `auto_guidance_layers` to be an int or a list of ints, but got {type(auto_guidance_layers)}."
+ )
+ auto_guidance_config = [
+ LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers
+ ]
+
+ if isinstance(auto_guidance_config, dict):
+ auto_guidance_config = LayerSkipConfig.from_dict(auto_guidance_config)
+
+ if isinstance(auto_guidance_config, LayerSkipConfig):
+ auto_guidance_config = [auto_guidance_config]
+
+ if not isinstance(auto_guidance_config, list):
+ raise ValueError(
+ f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}."
+ )
+ elif isinstance(next(iter(auto_guidance_config), None), dict):
+ auto_guidance_config = [LayerSkipConfig.from_dict(config) for config in auto_guidance_config]
+
+ self.auto_guidance_config = auto_guidance_config
+ self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))]
+
+ def prepare_models(self, denoiser: torch.nn.Module) -> None:
+ self._count_prepared += 1
+ if self._is_ag_enabled() and self.is_unconditional:
+ for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config):
+ _apply_layer_skip_hook(denoiser, config, name=name)
+
+ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
+ if self._is_ag_enabled() and self.is_unconditional:
+ for name in self._auto_guidance_hook_names:
+ registry = HookRegistry.check_if_exists_or_initialize(denoiser)
+ registry.remove_hook(name, recurse=True)
+
+ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
+ tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+ data_batches = []
+ for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+ data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
+ data_batches.append(data_batch)
+ return data_batches
+
+ def prepare_inputs_from_block_state(
+ self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
+ ) -> List["BlockState"]:
+ tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+ data_batches = []
+ for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+ data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
+ data_batches.append(data_batch)
+ return data_batches
+
+ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
+ pred = None
+
+ if not self._is_ag_enabled():
+ pred = pred_cond
+ else:
+ shift = pred_cond - pred_uncond
+ pred = pred_cond if self.use_original_formulation else pred_uncond
+ pred = pred + self.guidance_scale * shift
+
+ if self.guidance_rescale > 0.0:
+ pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+ return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
+
+ @property
+ def is_conditional(self) -> bool:
+ return self._count_prepared == 1
+
+ @property
+ def num_conditions(self) -> int:
+ num_conditions = 1
+ if self._is_ag_enabled():
+ num_conditions += 1
+ return num_conditions
+
+ def _is_ag_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self._start * self._num_inference_steps)
+ skip_stop_step = int(self._stop * self._num_inference_steps)
+ is_within_range = skip_start_step <= self._step < skip_stop_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = math.isclose(self.guidance_scale, 0.0)
+ else:
+ is_close = math.isclose(self.guidance_scale, 1.0)
+
+ return is_within_range and not is_close
diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py
new file mode 100644
index 000000000000..5e55d4d869c1
--- /dev/null
+++ b/src/diffusers/guiders/classifier_free_guidance.py
@@ -0,0 +1,154 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import register_to_config
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+
+class ClassifierFreeGuidance(BaseGuidance):
+ """
+ Implements Classifier-Free Guidance (CFG) for diffusion models.
+
+ Reference: https://huggingface.co/papers/2207.12598
+
+ CFG improves generation quality and prompt adherence by jointly training models on both conditional and
+ unconditional data, then combining predictions during inference. This allows trading off between quality (high
+ guidance) and diversity (low guidance).
+
+ **Two CFG Formulations:**
+
+ 1. **Original formulation** (from paper):
+ ```
+ x_pred = x_cond + guidance_scale * (x_cond - x_uncond)
+ ```
+ Moves conditional predictions further from unconditional ones.
+
+ 2. **Diffusers-native formulation** (default, from Imagen paper):
+ ```
+ x_pred = x_uncond + guidance_scale * (x_cond - x_uncond)
+ ```
+ Moves unconditional predictions toward conditional ones, effectively suppressing negative features (e.g., "bad
+ quality", "watermarks"). Equivalent in theory but more intuitive.
+
+ Use `use_original_formulation=True` to switch to the original formulation.
+
+ Args:
+ guidance_scale (`float`, defaults to `7.5`):
+ CFG scale applied by this guider during post-processing. Higher values = stronger prompt conditioning but
+ may reduce quality. Typical range: 1.0-20.0.
+ guidance_rescale (`float`, defaults to `0.0`):
+ Rescaling factor to prevent overexposure from high guidance scales. Based on [Common Diffusion Noise
+ Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Range: 0.0 (no rescaling)
+ to 1.0 (full rescaling).
+ use_original_formulation (`bool`, defaults to `False`):
+ If `True`, uses the original CFG formulation from the paper. If `False` (default), uses the
+ diffusers-native formulation from the Imagen paper.
+ start (`float`, defaults to `0.0`):
+ Fraction of denoising steps (0.0-1.0) after which CFG starts. Use > 0.0 to disable CFG in early denoising
+ steps.
+ stop (`float`, defaults to `1.0`):
+ Fraction of denoising steps (0.0-1.0) after which CFG stops. Use < 1.0 to disable CFG in late denoising
+ steps.
+ enabled (`bool`, defaults to `True`):
+ Whether CFG is enabled. Set to `False` to disable CFG entirely (uses only conditional predictions).
+ """
+
+ _input_predictions = ["pred_cond", "pred_uncond"]
+
+ @register_to_config
+ def __init__(
+ self,
+ guidance_scale: float = 7.5,
+ guidance_rescale: float = 0.0,
+ use_original_formulation: bool = False,
+ start: float = 0.0,
+ stop: float = 1.0,
+ enabled: bool = True,
+ ):
+ super().__init__(start, stop, enabled)
+
+ self.guidance_scale = guidance_scale
+ self.guidance_rescale = guidance_rescale
+ self.use_original_formulation = use_original_formulation
+
+ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
+ tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+ data_batches = []
+ for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+ data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
+ data_batches.append(data_batch)
+ return data_batches
+
+ def prepare_inputs_from_block_state(
+ self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
+ ) -> List["BlockState"]:
+ tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+ data_batches = []
+ for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+ data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
+ data_batches.append(data_batch)
+ return data_batches
+
+ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
+ pred = None
+
+ if not self._is_cfg_enabled():
+ pred = pred_cond
+ else:
+ shift = pred_cond - pred_uncond
+ pred = pred_cond if self.use_original_formulation else pred_uncond
+ pred = pred + self.guidance_scale * shift
+
+ if self.guidance_rescale > 0.0:
+ pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+ return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
+
+ @property
+ def is_conditional(self) -> bool:
+ return self._count_prepared == 1
+
+ @property
+ def num_conditions(self) -> int:
+ num_conditions = 1
+ if self._is_cfg_enabled():
+ num_conditions += 1
+ return num_conditions
+
+ def _is_cfg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self._start * self._num_inference_steps)
+ skip_stop_step = int(self._stop * self._num_inference_steps)
+ is_within_range = skip_start_step <= self._step < skip_stop_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = math.isclose(self.guidance_scale, 0.0)
+ else:
+ is_close = math.isclose(self.guidance_scale, 1.0)
+
+ return is_within_range and not is_close
diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py
new file mode 100644
index 000000000000..23b492e51b02
--- /dev/null
+++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py
@@ -0,0 +1,162 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import register_to_config
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+
+class ClassifierFreeZeroStarGuidance(BaseGuidance):
+ """
+ Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886
+
+ This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free
+ guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion
+ process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the
+ quality of generated images.
+
+ The authors of the paper suggest setting zero initialization in the first 4% of the inference steps.
+
+ Args:
+ guidance_scale (`float`, defaults to `7.5`):
+ The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
+ prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
+ deterioration of image quality.
+ zero_init_steps (`int`, defaults to `1`):
+ The number of inference steps for which the noise predictions are zeroed out (see Section 4.2).
+ guidance_rescale (`float`, defaults to `0.0`):
+ The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+ overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+ use_original_formulation (`bool`, defaults to `False`):
+ Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+ we use the diffusers-native implementation that has been in the codebase for a long time. See
+ [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+ start (`float`, defaults to `0.01`):
+ The fraction of the total number of denoising steps after which guidance starts.
+ stop (`float`, defaults to `0.2`):
+ The fraction of the total number of denoising steps after which guidance stops.
+ """
+
+ _input_predictions = ["pred_cond", "pred_uncond"]
+
+ @register_to_config
+ def __init__(
+ self,
+ guidance_scale: float = 7.5,
+ zero_init_steps: int = 1,
+ guidance_rescale: float = 0.0,
+ use_original_formulation: bool = False,
+ start: float = 0.0,
+ stop: float = 1.0,
+ enabled: bool = True,
+ ):
+ super().__init__(start, stop, enabled)
+
+ self.guidance_scale = guidance_scale
+ self.zero_init_steps = zero_init_steps
+ self.guidance_rescale = guidance_rescale
+ self.use_original_formulation = use_original_formulation
+
+ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
+ tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+ data_batches = []
+ for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+ data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
+ data_batches.append(data_batch)
+ return data_batches
+
+ def prepare_inputs_from_block_state(
+ self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
+ ) -> List["BlockState"]:
+ tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+ data_batches = []
+ for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+ data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
+ data_batches.append(data_batch)
+ return data_batches
+
+ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
+ pred = None
+
+ # YiYi Notes: add default behavior for self._enabled == False
+ if not self._enabled:
+ pred = pred_cond
+
+ elif self._step < self.zero_init_steps:
+ pred = torch.zeros_like(pred_cond)
+ elif not self._is_cfg_enabled():
+ pred = pred_cond
+ else:
+ pred_cond_flat = pred_cond.flatten(1)
+ pred_uncond_flat = pred_uncond.flatten(1)
+ alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat)
+ alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
+ pred_uncond = pred_uncond * alpha
+ shift = pred_cond - pred_uncond
+ pred = pred_cond if self.use_original_formulation else pred_uncond
+ pred = pred + self.guidance_scale * shift
+
+ if self.guidance_rescale > 0.0:
+ pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+ return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
+
+ @property
+ def is_conditional(self) -> bool:
+ return self._count_prepared == 1
+
+ @property
+ def num_conditions(self) -> int:
+ num_conditions = 1
+ if self._is_cfg_enabled():
+ num_conditions += 1
+ return num_conditions
+
+ def _is_cfg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self._start * self._num_inference_steps)
+ skip_stop_step = int(self._stop * self._num_inference_steps)
+ is_within_range = skip_start_step <= self._step < skip_stop_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = math.isclose(self.guidance_scale, 0.0)
+ else:
+ is_close = math.isclose(self.guidance_scale, 1.0)
+
+ return is_within_range and not is_close
+
+
+def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
+ cond_dtype = cond.dtype
+ cond = cond.float()
+ uncond = uncond.float()
+ dot_product = torch.sum(cond * uncond, dim=1, keepdim=True)
+ squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps
+ # st_star = v_cond^T * v_uncond / ||v_uncond||^2
+ scale = dot_product / squared_norm
+ return scale.to(dtype=cond_dtype)
diff --git a/src/diffusers/guiders/frequency_decoupled_guidance.py b/src/diffusers/guiders/frequency_decoupled_guidance.py
new file mode 100644
index 000000000000..4ec6e2d36da9
--- /dev/null
+++ b/src/diffusers/guiders/frequency_decoupled_guidance.py
@@ -0,0 +1,333 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import register_to_config
+from ..utils import is_kornia_available
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+
+_CAN_USE_KORNIA = is_kornia_available()
+
+
+if _CAN_USE_KORNIA:
+ from kornia.geometry import pyrup as upsample_and_blur_func
+ from kornia.geometry.transform import build_laplacian_pyramid as build_laplacian_pyramid_func
+else:
+ upsample_and_blur_func = None
+ build_laplacian_pyramid_func = None
+
+
+def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from paper
+ (Algorithm 2).
+ """
+ # v0 shape: [B, ...]
+ # v1 shape: [B, ...]
+ # Assume first dim is a batch dim and all other dims are channel or "spatial" dims
+ all_dims_but_first = list(range(1, len(v0.shape)))
+ if upcast_to_double:
+ dtype = v0.dtype
+ v0, v1 = v0.double(), v1.double()
+ v1 = torch.nn.functional.normalize(v1, dim=all_dims_but_first)
+ v0_parallel = (v0 * v1).sum(dim=all_dims_but_first, keepdim=True) * v1
+ v0_orthogonal = v0 - v0_parallel
+ if upcast_to_double:
+ v0_parallel = v0_parallel.to(dtype)
+ v0_orthogonal = v0_orthogonal.to(dtype)
+ return v0_parallel, v0_orthogonal
+
+
+def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor:
+ """
+ Recovers the data space latents from the Laplacian pyramid frequency space. Implementation from the paper
+ (Algorithm 2).
+ """
+ # pyramid shapes: [[B, C, H, W], [B, C, H/2, W/2], ...]
+ img = pyramid[-1]
+ for i in range(len(pyramid) - 2, -1, -1):
+ img = upsample_and_blur_func(img) + pyramid[i]
+ return img
+
+
+class FrequencyDecoupledGuidance(BaseGuidance):
+ """
+ Frequency-Decoupled Guidance (FDG): https://huggingface.co/papers/2506.19713
+
+ FDG is a technique similar to (and based on) classifier-free guidance (CFG) which is used to improve generation
+ quality and condition-following in diffusion models. Like CFG, during training we jointly train the model on both
+ conditional and unconditional data, and use a combination of the two during inference. (If you want more details on
+ how CFG works, you can check out the CFG guider.)
+
+ FDG differs from CFG in that the normal CFG prediction is instead decoupled into low- and high-frequency components
+ using a frequency transform (such as a Laplacian pyramid). The CFG update is then performed in frequency space
+ separately for the low- and high-frequency components with different guidance scales. Finally, the inverse
+ frequency transform is used to map the CFG frequency predictions back to data space (e.g. pixel space for images)
+ to form the final FDG prediction.
+
+ For images, the FDG authors found that using low guidance scales for the low-frequency components retains sample
+ diversity and realistic color composition, while using high guidance scales for high-frequency components enhances
+ sample quality (such as better visual details). Therefore, they recommend using low guidance scales (low w_low) for
+ the low-frequency components and high guidance scales (high w_high) for the high-frequency components. As an
+ example, they suggest w_low = 5.0 and w_high = 10.0 for Stable Diffusion XL (see Table 8 in the paper).
+
+ As with CFG, Diffusers implements the scaling and shifting on the unconditional prediction based on the [Imagen
+ paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original CFG paper proposed in
+ theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
+
+ The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
+ paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
+
+ Args:
+ guidance_scales (`List[float]`, defaults to `[10.0, 5.0]`):
+ The scale parameter for frequency-decoupled guidance for each frequency component, listed from highest
+ frequency level to lowest. Higher values result in stronger conditioning on the text prompt, while lower
+ values allow for more freedom in generation. Higher values may lead to saturation and deterioration of
+ image quality. The FDG authors recommend using higher guidance scales for higher frequency components and
+ lower guidance scales for lower frequency components (so `guidance_scales` should typically be sorted in
+ descending order).
+ guidance_rescale (`float` or `List[float]`, defaults to `0.0`):
+ The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+ overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891). If a list is supplied, it should be the same length as
+ `guidance_scales`.
+ parallel_weights (`float` or `List[float]`, *optional*):
+ Optional weights for the parallel component of each frequency component of the projected CFG shift. If not
+ set, the weights will default to `1.0` for all components, which corresponds to using the normal CFG shift
+ (that is, equal weights for the parallel and orthogonal components). If set, a value in `[0, 1]` is
+ recommended. If a list is supplied, it should be the same length as `guidance_scales`.
+ use_original_formulation (`bool`, defaults to `False`):
+ Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+ we use the diffusers-native implementation that has been in the codebase for a long time. See
+ [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+ start (`float` or `List[float]`, defaults to `0.0`):
+ The fraction of the total number of denoising steps after which guidance starts. If a list is supplied, it
+ should be the same length as `guidance_scales`.
+ stop (`float` or `List[float]`, defaults to `1.0`):
+ The fraction of the total number of denoising steps after which guidance stops. If a list is supplied, it
+ should be the same length as `guidance_scales`.
+ guidance_rescale_space (`str`, defaults to `"data"`):
+ Whether to performance guidance rescaling in `"data"` space (after the full FDG update in data space) or in
+ `"freq"` space (right after the CFG update, for each freq level). Note that frequency space rescaling is
+ speculative and may not produce expected results. If `"data"` is set, the first `guidance_rescale` value
+ will be used; otherwise, per-frequency-level guidance rescale values will be used if available.
+ upcast_to_double (`bool`, defaults to `True`):
+ Whether to upcast certain operations, such as the projection operation when using `parallel_weights`, to
+ float64 when performing guidance. This may result in better performance at the cost of increased runtime.
+ """
+
+ _input_predictions = ["pred_cond", "pred_uncond"]
+
+ @register_to_config
+ def __init__(
+ self,
+ guidance_scales: Union[List[float], Tuple[float]] = [10.0, 5.0],
+ guidance_rescale: Union[float, List[float], Tuple[float]] = 0.0,
+ parallel_weights: Optional[Union[float, List[float], Tuple[float]]] = None,
+ use_original_formulation: bool = False,
+ start: Union[float, List[float], Tuple[float]] = 0.0,
+ stop: Union[float, List[float], Tuple[float]] = 1.0,
+ guidance_rescale_space: str = "data",
+ upcast_to_double: bool = True,
+ enabled: bool = True,
+ ):
+ if not _CAN_USE_KORNIA:
+ raise ImportError(
+ "The `FrequencyDecoupledGuidance` guider cannot be instantiated because the `kornia` library on which "
+ "it depends is not available in the current environment. You can install `kornia` with `pip install "
+ "kornia`."
+ )
+
+ # Set start to earliest start for any freq component and stop to latest stop for any freq component
+ min_start = start if isinstance(start, float) else min(start)
+ max_stop = stop if isinstance(stop, float) else max(stop)
+ super().__init__(min_start, max_stop, enabled)
+
+ self.guidance_scales = guidance_scales
+ self.levels = len(guidance_scales)
+
+ if isinstance(guidance_rescale, float):
+ self.guidance_rescale = [guidance_rescale] * self.levels
+ elif len(guidance_rescale) == self.levels:
+ self.guidance_rescale = guidance_rescale
+ else:
+ raise ValueError(
+ f"`guidance_rescale` has length {len(guidance_rescale)} but should have the same length as "
+ f"`guidance_scales` ({len(self.guidance_scales)})"
+ )
+ # Whether to perform guidance rescaling in frequency space (right after the CFG update) or data space (after
+ # transforming from frequency space back to data space)
+ if guidance_rescale_space not in ["data", "freq"]:
+ raise ValueError(
+ f"Guidance rescale space is {guidance_rescale_space} but must be one of `data` or `freq`."
+ )
+ self.guidance_rescale_space = guidance_rescale_space
+
+ if parallel_weights is None:
+ # Use normal CFG shift (equal weights for parallel and orthogonal components)
+ self.parallel_weights = [1.0] * self.levels
+ elif isinstance(parallel_weights, float):
+ self.parallel_weights = [parallel_weights] * self.levels
+ elif len(parallel_weights) == self.levels:
+ self.parallel_weights = parallel_weights
+ else:
+ raise ValueError(
+ f"`parallel_weights` has length {len(parallel_weights)} but should have the same length as "
+ f"`guidance_scales` ({len(self.guidance_scales)})"
+ )
+
+ self.use_original_formulation = use_original_formulation
+ self.upcast_to_double = upcast_to_double
+
+ if isinstance(start, float):
+ self.guidance_start = [start] * self.levels
+ elif len(start) == self.levels:
+ self.guidance_start = start
+ else:
+ raise ValueError(
+ f"`start` has length {len(start)} but should have the same length as `guidance_scales` "
+ f"({len(self.guidance_scales)})"
+ )
+ if isinstance(stop, float):
+ self.guidance_stop = [stop] * self.levels
+ elif len(stop) == self.levels:
+ self.guidance_stop = stop
+ else:
+ raise ValueError(
+ f"`stop` has length {len(stop)} but should have the same length as `guidance_scales` "
+ f"({len(self.guidance_scales)})"
+ )
+
+ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
+ tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+ data_batches = []
+ for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+ data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
+ data_batches.append(data_batch)
+ return data_batches
+
+ def prepare_inputs_from_block_state(
+ self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
+ ) -> List["BlockState"]:
+ tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+ data_batches = []
+ for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+ data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
+ data_batches.append(data_batch)
+ return data_batches
+
+ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
+ pred = None
+
+ if not self._is_fdg_enabled():
+ pred = pred_cond
+ else:
+ # Apply the frequency transform (e.g. Laplacian pyramid) to the conditional and unconditional predictions.
+ pred_cond_pyramid = build_laplacian_pyramid_func(pred_cond, self.levels)
+ pred_uncond_pyramid = build_laplacian_pyramid_func(pred_uncond, self.levels)
+
+ # From high frequencies to low frequencies, following the paper implementation
+ pred_guided_pyramid = []
+ parameters = zip(self.guidance_scales, self.parallel_weights, self.guidance_rescale)
+ for level, (guidance_scale, parallel_weight, guidance_rescale) in enumerate(parameters):
+ if self._is_fdg_enabled_for_level(level):
+ # Get the cond/uncond preds (in freq space) at the current frequency level
+ pred_cond_freq = pred_cond_pyramid[level]
+ pred_uncond_freq = pred_uncond_pyramid[level]
+
+ shift = pred_cond_freq - pred_uncond_freq
+
+ # Apply parallel weights, if used (1.0 corresponds to using the normal CFG shift)
+ if not math.isclose(parallel_weight, 1.0):
+ shift_parallel, shift_orthogonal = project(shift, pred_cond_freq, self.upcast_to_double)
+ shift = parallel_weight * shift_parallel + shift_orthogonal
+
+ # Apply CFG update for the current frequency level
+ pred = pred_cond_freq if self.use_original_formulation else pred_uncond_freq
+ pred = pred + guidance_scale * shift
+
+ if self.guidance_rescale_space == "freq" and guidance_rescale > 0.0:
+ pred = rescale_noise_cfg(pred, pred_cond_freq, guidance_rescale)
+
+ # Add the current FDG guided level to the FDG prediction pyramid
+ pred_guided_pyramid.append(pred)
+ else:
+ # Add the current pred_cond_pyramid level as the "non-FDG" prediction
+ pred_guided_pyramid.append(pred_cond_freq)
+
+ # Convert from frequency space back to data (e.g. pixel) space by applying inverse freq transform
+ pred = build_image_from_pyramid(pred_guided_pyramid)
+
+ # If rescaling in data space, use the first elem of self.guidance_rescale as the "global" rescale value
+ # across all freq levels
+ if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0:
+ pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0])
+
+ return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
+
+ @property
+ def is_conditional(self) -> bool:
+ return self._count_prepared == 1
+
+ @property
+ def num_conditions(self) -> int:
+ num_conditions = 1
+ if self._is_fdg_enabled():
+ num_conditions += 1
+ return num_conditions
+
+ def _is_fdg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self._start * self._num_inference_steps)
+ skip_stop_step = int(self._stop * self._num_inference_steps)
+ is_within_range = skip_start_step <= self._step < skip_stop_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = all(math.isclose(guidance_scale, 0.0) for guidance_scale in self.guidance_scales)
+ else:
+ is_close = all(math.isclose(guidance_scale, 1.0) for guidance_scale in self.guidance_scales)
+
+ return is_within_range and not is_close
+
+ def _is_fdg_enabled_for_level(self, level: int) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self.guidance_start[level] * self._num_inference_steps)
+ skip_stop_step = int(self.guidance_stop[level] * self._num_inference_steps)
+ is_within_range = skip_start_step <= self._step < skip_stop_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = math.isclose(self.guidance_scales[level], 0.0)
+ else:
+ is_close = math.isclose(self.guidance_scales[level], 1.0)
+
+ return is_within_range and not is_close
diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py
new file mode 100644
index 000000000000..6c328328fc3b
--- /dev/null
+++ b/src/diffusers/guiders/guider_utils.py
@@ -0,0 +1,394 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from huggingface_hub.utils import validate_hf_hub_args
+from typing_extensions import Self
+
+from ..configuration_utils import ConfigMixin
+from ..utils import BaseOutput, PushToHubMixin, get_logger
+
+
+if TYPE_CHECKING:
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+
+GUIDER_CONFIG_NAME = "guider_config.json"
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+
+class BaseGuidance(ConfigMixin, PushToHubMixin):
+ r"""Base class providing the skeleton for implementing guidance techniques."""
+
+ config_name = GUIDER_CONFIG_NAME
+ _input_predictions = None
+ _identifier_key = "__guidance_identifier__"
+
+ def __init__(self, start: float = 0.0, stop: float = 1.0, enabled: bool = True):
+ logger.warning(
+ "Guiders are currently an experimental feature under active development. The API is subject to breaking changes in future releases."
+ )
+
+ self._start = start
+ self._stop = stop
+ self._step: int = None
+ self._num_inference_steps: int = None
+ self._timestep: torch.LongTensor = None
+ self._count_prepared = 0
+ self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
+ self._enabled = enabled
+
+ if not (0.0 <= start < 1.0):
+ raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
+ if not (start <= stop <= 1.0):
+ raise ValueError(f"Expected `stop` to be between {start} and 1.0, but got {stop}.")
+
+ if self._input_predictions is None or not isinstance(self._input_predictions, list):
+ raise ValueError(
+ "`_input_predictions` must be a list of required prediction names for the guidance technique."
+ )
+
+ def new(self, **kwargs):
+ """
+ Creates a copy of this guider instance, optionally with modified configuration parameters.
+
+ Args:
+ **kwargs: Configuration parameters to override in the new instance. If no kwargs are provided,
+ returns an exact copy with the same configuration.
+
+ Returns:
+ A new guider instance with the same (or updated) configuration.
+
+ Example:
+ ```python
+ # Create a CFG guider
+ guider = ClassifierFreeGuidance(guidance_scale=3.5)
+
+ # Create an exact copy
+ same_guider = guider.new()
+
+ # Create a copy with different start step, keeping other config the same
+ new_guider = guider.new(guidance_scale=5)
+ ```
+ """
+ return self.__class__.from_config(self.config, **kwargs)
+
+ def disable(self):
+ self._enabled = False
+
+ def enable(self):
+ self._enabled = True
+
+ def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
+ self._step = step
+ self._num_inference_steps = num_inference_steps
+ self._timestep = timestep
+ self._count_prepared = 0
+
+ def get_state(self) -> Dict[str, Any]:
+ """
+ Returns the current state of the guidance technique as a dictionary. The state variables will be included in
+ the __repr__ method. Returns:
+ `Dict[str, Any]`: A dictionary containing the current state variables including:
+ - step: Current inference step
+ - num_inference_steps: Total number of inference steps
+ - timestep: Current timestep tensor
+ - count_prepared: Number of times prepare_models has been called
+ - enabled: Whether the guidance is enabled
+ - num_conditions: Number of conditions
+ """
+ state = {
+ "step": self._step,
+ "num_inference_steps": self._num_inference_steps,
+ "timestep": self._timestep,
+ "count_prepared": self._count_prepared,
+ "enabled": self._enabled,
+ "num_conditions": self.num_conditions,
+ }
+ return state
+
+ def __repr__(self) -> str:
+ """
+ Returns a string representation of the guidance object including both config and current state.
+ """
+ # Get ConfigMixin's __repr__
+ str_repr = super().__repr__()
+
+ # Get current state
+ state = self.get_state()
+
+ # Format each state variable on its own line with indentation
+ state_lines = []
+ for k, v in state.items():
+ # Convert value to string and handle multi-line values
+ v_str = str(v)
+ if "\n" in v_str:
+ # For multi-line values (like MomentumBuffer), indent subsequent lines
+ v_lines = v_str.split("\n")
+ v_str = v_lines[0] + "\n" + "\n".join([" " + line for line in v_lines[1:]])
+ state_lines.append(f" {k}: {v_str}")
+
+ state_str = "\n".join(state_lines)
+
+ return f"{str_repr}\nState:\n{state_str}"
+
+ def prepare_models(self, denoiser: torch.nn.Module) -> None:
+ """
+ Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
+ subclasses to implement specific model preparation logic.
+ """
+ self._count_prepared += 1
+
+ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
+ """
+ Cleans up the models for the guidance technique after a given batch of data. This method should be overridden
+ in subclasses to implement specific model cleanup logic. It is useful for removing any hooks or other stateful
+ modifications made during `prepare_models`.
+ """
+ pass
+
+ def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
+ raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
+
+ def prepare_inputs_from_block_state(
+ self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
+ ) -> List["BlockState"]:
+ raise NotImplementedError("BaseGuidance::prepare_inputs_from_block_state must be implemented in subclasses.")
+
+ def __call__(self, data: List["BlockState"]) -> Any:
+ if not all(hasattr(d, "noise_pred") for d in data):
+ raise ValueError("Expected all data to have `noise_pred` attribute.")
+ if len(data) != self.num_conditions:
+ raise ValueError(
+ f"Expected {self.num_conditions} data items, but got {len(data)}. Please check the input data."
+ )
+ forward_inputs = {getattr(d, self._identifier_key): d.noise_pred for d in data}
+ return self.forward(**forward_inputs)
+
+ def forward(self, *args, **kwargs) -> Any:
+ raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")
+
+ @property
+ def is_conditional(self) -> bool:
+ raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
+
+ @property
+ def is_unconditional(self) -> bool:
+ return not self.is_conditional
+
+ @property
+ def num_conditions(self) -> int:
+ raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")
+
+ @classmethod
+ def _prepare_batch(
+ cls,
+ data: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
+ tuple_index: int,
+ identifier: str,
+ ) -> "BlockState":
+ """
+ Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
+ `BaseGuidance` class. It prepares the batch based on the provided tuple index.
+
+ Args:
+ input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
+ A dictionary where the keys are the names of the fields that will be used to store the data once it is
+ prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
+ to look up the required data provided for preparation. If a string is provided, it will be used as the
+ conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
+ length 2 is provided, the first element must be the conditional data identifier and the second element
+ must be the unconditional data identifier or None.
+ data (`BlockState`):
+ The input data to be prepared.
+ tuple_index (`int`):
+ The index to use when accessing input fields that are tuples.
+
+ Returns:
+ `BlockState`: The prepared batch of data.
+ """
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+ data_batch = {}
+ for key, value in data.items():
+ try:
+ if isinstance(value, torch.Tensor):
+ data_batch[key] = value
+ elif isinstance(value, tuple):
+ data_batch[key] = value[tuple_index]
+ else:
+ raise ValueError(f"Invalid value type: {type(value)}")
+ except ValueError:
+ logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
+ data_batch[cls._identifier_key] = identifier
+ return BlockState(**data_batch)
+
+ @classmethod
+ def _prepare_batch_from_block_state(
+ cls,
+ input_fields: Dict[str, Union[str, Tuple[str, str]]],
+ data: "BlockState",
+ tuple_index: int,
+ identifier: str,
+ ) -> "BlockState":
+ """
+ Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
+ `BaseGuidance` class. It prepares the batch based on the provided tuple index.
+
+ Args:
+ input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
+ A dictionary where the keys are the names of the fields that will be used to store the data once it is
+ prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
+ to look up the required data provided for preparation. If a string is provided, it will be used as the
+ conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
+ length 2 is provided, the first element must be the conditional data identifier and the second element
+ must be the unconditional data identifier or None.
+ data (`BlockState`):
+ The input data to be prepared.
+ tuple_index (`int`):
+ The index to use when accessing input fields that are tuples.
+
+ Returns:
+ `BlockState`: The prepared batch of data.
+ """
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+ data_batch = {}
+ for key, value in input_fields.items():
+ try:
+ if isinstance(value, str):
+ data_batch[key] = getattr(data, value)
+ elif isinstance(value, tuple):
+ data_batch[key] = getattr(data, value[tuple_index])
+ else:
+ # We've already checked that value is a string or a tuple of strings with length 2
+ pass
+ except AttributeError:
+ logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
+ data_batch[cls._identifier_key] = identifier
+ return BlockState(**data_batch)
+
+ @classmethod
+ @validate_hf_hub_args
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+ subfolder: Optional[str] = None,
+ return_unused_kwargs=False,
+ **kwargs,
+ ) -> Self:
+ r"""
+ Instantiate a guider from a pre-defined JSON configuration file in a local directory or Hub repository.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the guider configuration
+ saved with [`~BaseGuidance.save_pretrained`].
+ subfolder (`str`, *optional*):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ Whether kwargs that are not consumed by the Python class should be returned or not.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+
+ > [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
+ with `hf > auth login`. You can also activate the special >
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a >
+ firewalled environment.
+
+ """
+ config, kwargs, commit_hash = cls.load_config(
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ subfolder=subfolder,
+ return_unused_kwargs=True,
+ return_commit_hash=True,
+ **kwargs,
+ )
+ return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save a guider configuration object to a directory so that it can be reloaded using the
+ [`~BaseGuidance.from_pretrained`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+ namespace).
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+ """
+ self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
+
+
+class GuiderOutput(BaseOutput):
+ pred: torch.Tensor
+ pred_cond: Optional[torch.Tensor]
+ pred_uncond: Optional[torch.Tensor]
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
diff --git a/src/diffusers/guiders/perturbed_attention_guidance.py b/src/diffusers/guiders/perturbed_attention_guidance.py
new file mode 100644
index 000000000000..f233e90ca410
--- /dev/null
+++ b/src/diffusers/guiders/perturbed_attention_guidance.py
@@ -0,0 +1,287 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import register_to_config
+from ..hooks import HookRegistry, LayerSkipConfig
+from ..hooks.layer_skip import _apply_layer_skip_hook
+from ..utils import get_logger
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+
+class PerturbedAttentionGuidance(BaseGuidance):
+ """
+ Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377
+
+ The intution behind PAG can be thought of as moving the CFG predicted distribution estimates further away from
+ worse versions of the conditional distribution estimates. PAG was one of the first techniques to introduce the idea
+ of using a worse version of the trained model for better guiding itself in the denoising process. It perturbs the
+ attention scores of the latent stream by replacing the score matrix with an identity matrix for selectively chosen
+ layers.
+
+ Additional reading:
+ - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
+
+ PAG is implemented with similar implementation to SkipLayerGuidance due to overlap in the configuration parameters
+ and implementation details.
+
+ Args:
+ guidance_scale (`float`, defaults to `7.5`):
+ The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
+ prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
+ deterioration of image quality.
+ perturbed_guidance_scale (`float`, defaults to `2.8`):
+ The scale parameter for perturbed attention guidance.
+ perturbed_guidance_start (`float`, defaults to `0.01`):
+ The fraction of the total number of denoising steps after which perturbed attention guidance starts.
+ perturbed_guidance_stop (`float`, defaults to `0.2`):
+ The fraction of the total number of denoising steps after which perturbed attention guidance stops.
+ perturbed_guidance_layers (`int` or `List[int]`, *optional*):
+ The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of integers.
+ If not provided, `perturbed_guidance_config` must be provided.
+ perturbed_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
+ The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of
+ `LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided.
+ guidance_rescale (`float`, defaults to `0.0`):
+ The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+ overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+ use_original_formulation (`bool`, defaults to `False`):
+ Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+ we use the diffusers-native implementation that has been in the codebase for a long time. See
+ [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+ start (`float`, defaults to `0.01`):
+ The fraction of the total number of denoising steps after which guidance starts.
+ stop (`float`, defaults to `0.2`):
+ The fraction of the total number of denoising steps after which guidance stops.
+ """
+
+ # NOTE: The current implementation does not account for joint latent conditioning (text + image/video tokens in
+ # the same latent stream). It assumes the entire latent is a single stream of visual tokens. It would be very
+ # complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation
+ # for each model architecture.
+
+ _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
+
+ @register_to_config
+ def __init__(
+ self,
+ guidance_scale: float = 7.5,
+ perturbed_guidance_scale: float = 2.8,
+ perturbed_guidance_start: float = 0.01,
+ perturbed_guidance_stop: float = 0.2,
+ perturbed_guidance_layers: Optional[Union[int, List[int]]] = None,
+ perturbed_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ use_original_formulation: bool = False,
+ start: float = 0.0,
+ stop: float = 1.0,
+ enabled: bool = True,
+ ):
+ super().__init__(start, stop, enabled)
+
+ self.guidance_scale = guidance_scale
+ self.skip_layer_guidance_scale = perturbed_guidance_scale
+ self.skip_layer_guidance_start = perturbed_guidance_start
+ self.skip_layer_guidance_stop = perturbed_guidance_stop
+ self.guidance_rescale = guidance_rescale
+ self.use_original_formulation = use_original_formulation
+
+ if perturbed_guidance_config is None:
+ if perturbed_guidance_layers is None:
+ raise ValueError(
+ "`perturbed_guidance_layers` must be provided if `perturbed_guidance_config` is not specified."
+ )
+ perturbed_guidance_config = LayerSkipConfig(
+ indices=perturbed_guidance_layers,
+ fqn="auto",
+ skip_attention=False,
+ skip_attention_scores=True,
+ skip_ff=False,
+ )
+ else:
+ if perturbed_guidance_layers is not None:
+ raise ValueError(
+ "`perturbed_guidance_layers` should not be provided if `perturbed_guidance_config` is specified."
+ )
+
+ if isinstance(perturbed_guidance_config, dict):
+ perturbed_guidance_config = LayerSkipConfig.from_dict(perturbed_guidance_config)
+
+ if isinstance(perturbed_guidance_config, LayerSkipConfig):
+ perturbed_guidance_config = [perturbed_guidance_config]
+
+ if not isinstance(perturbed_guidance_config, list):
+ raise ValueError(
+ "`perturbed_guidance_config` must be a `LayerSkipConfig`, a list of `LayerSkipConfig`, or a dict that can be converted to a `LayerSkipConfig`."
+ )
+ elif isinstance(next(iter(perturbed_guidance_config), None), dict):
+ perturbed_guidance_config = [LayerSkipConfig.from_dict(config) for config in perturbed_guidance_config]
+
+ for config in perturbed_guidance_config:
+ if config.skip_attention or not config.skip_attention_scores or config.skip_ff:
+ logger.warning(
+ "Perturbed Attention Guidance is designed to perturb attention scores, so `skip_attention` should be False, `skip_attention_scores` should be True, and `skip_ff` should be False. "
+ "Please check your configuration. Modifying the config to match the expected values."
+ )
+ config.skip_attention = False
+ config.skip_attention_scores = True
+ config.skip_ff = False
+
+ self.skip_layer_config = perturbed_guidance_config
+ self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
+
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_models
+ def prepare_models(self, denoiser: torch.nn.Module) -> None:
+ self._count_prepared += 1
+ if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
+ for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
+ _apply_layer_skip_hook(denoiser, config, name=name)
+
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.cleanup_models
+ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
+ if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
+ registry = HookRegistry.check_if_exists_or_initialize(denoiser)
+ # Remove the hooks after inference
+ for hook_name in self._skip_layer_hook_names:
+ registry.remove_hook(hook_name, recurse=True)
+
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
+ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
+ if self.num_conditions == 1:
+ tuple_indices = [0]
+ input_predictions = ["pred_cond"]
+ elif self.num_conditions == 2:
+ tuple_indices = [0, 1]
+ input_predictions = (
+ ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
+ )
+ else:
+ tuple_indices = [0, 1, 0]
+ input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
+ data_batches = []
+ for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
+ data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
+ data_batches.append(data_batch)
+ return data_batches
+
+ def prepare_inputs_from_block_state(
+ self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
+ ) -> List["BlockState"]:
+ if self.num_conditions == 1:
+ tuple_indices = [0]
+ input_predictions = ["pred_cond"]
+ elif self.num_conditions == 2:
+ tuple_indices = [0, 1]
+ input_predictions = (
+ ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
+ )
+ else:
+ tuple_indices = [0, 1, 0]
+ input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
+ data_batches = []
+ for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
+ data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
+ data_batches.append(data_batch)
+ return data_batches
+
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
+ def forward(
+ self,
+ pred_cond: torch.Tensor,
+ pred_uncond: Optional[torch.Tensor] = None,
+ pred_cond_skip: Optional[torch.Tensor] = None,
+ ) -> GuiderOutput:
+ pred = None
+
+ if not self._is_cfg_enabled() and not self._is_slg_enabled():
+ pred = pred_cond
+ elif not self._is_cfg_enabled():
+ shift = pred_cond - pred_cond_skip
+ pred = pred_cond if self.use_original_formulation else pred_cond_skip
+ pred = pred + self.skip_layer_guidance_scale * shift
+ elif not self._is_slg_enabled():
+ shift = pred_cond - pred_uncond
+ pred = pred_cond if self.use_original_formulation else pred_uncond
+ pred = pred + self.guidance_scale * shift
+ else:
+ shift = pred_cond - pred_uncond
+ shift_skip = pred_cond - pred_cond_skip
+ pred = pred_cond if self.use_original_formulation else pred_uncond
+ pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
+
+ if self.guidance_rescale > 0.0:
+ pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+ return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
+
+ @property
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional
+ def is_conditional(self) -> bool:
+ return self._count_prepared == 1 or self._count_prepared == 3
+
+ @property
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.num_conditions
+ def num_conditions(self) -> int:
+ num_conditions = 1
+ if self._is_cfg_enabled():
+ num_conditions += 1
+ if self._is_slg_enabled():
+ num_conditions += 1
+ return num_conditions
+
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_cfg_enabled
+ def _is_cfg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self._start * self._num_inference_steps)
+ skip_stop_step = int(self._stop * self._num_inference_steps)
+ is_within_range = skip_start_step <= self._step < skip_stop_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = math.isclose(self.guidance_scale, 0.0)
+ else:
+ is_close = math.isclose(self.guidance_scale, 1.0)
+
+ return is_within_range and not is_close
+
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_slg_enabled
+ def _is_slg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
+ skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
+ is_within_range = skip_start_step < self._step < skip_stop_step
+
+ is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
+
+ return is_within_range and not is_zero
diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py
new file mode 100644
index 000000000000..e6109300d99c
--- /dev/null
+++ b/src/diffusers/guiders/skip_layer_guidance.py
@@ -0,0 +1,278 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import register_to_config
+from ..hooks import HookRegistry, LayerSkipConfig
+from ..hooks.layer_skip import _apply_layer_skip_hook
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+
+class SkipLayerGuidance(BaseGuidance):
+ """
+ Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5
+
+ Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664
+
+ SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by
+ skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional
+ batch of data, apart from the conditional and unconditional batches already used in CFG
+ ([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions
+ based on the difference between conditional without skipping and conditional with skipping predictions.
+
+ The intution behind SLG can be thought of as moving the CFG predicted distribution estimates further away from
+ worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse
+ version of the model for the conditional prediction).
+
+ STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving
+ generation quality in video diffusion models.
+
+ Additional reading:
+ - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
+
+ The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are
+ defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium.
+
+ Args:
+ guidance_scale (`float`, defaults to `7.5`):
+ The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
+ prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
+ deterioration of image quality.
+ skip_layer_guidance_scale (`float`, defaults to `2.8`):
+ The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher
+ values, but it may also lead to overexposure and saturation.
+ skip_layer_guidance_start (`float`, defaults to `0.01`):
+ The fraction of the total number of denoising steps after which skip layer guidance starts.
+ skip_layer_guidance_stop (`float`, defaults to `0.2`):
+ The fraction of the total number of denoising steps after which skip layer guidance stops.
+ skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
+ The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
+ provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
+ 3.5 Medium.
+ skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
+ The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
+ `LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
+ guidance_rescale (`float`, defaults to `0.0`):
+ The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+ overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+ use_original_formulation (`bool`, defaults to `False`):
+ Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+ we use the diffusers-native implementation that has been in the codebase for a long time. See
+ [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+ start (`float`, defaults to `0.01`):
+ The fraction of the total number of denoising steps after which guidance starts.
+ stop (`float`, defaults to `0.2`):
+ The fraction of the total number of denoising steps after which guidance stops.
+ """
+
+ _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
+
+ @register_to_config
+ def __init__(
+ self,
+ guidance_scale: float = 7.5,
+ skip_layer_guidance_scale: float = 2.8,
+ skip_layer_guidance_start: float = 0.01,
+ skip_layer_guidance_stop: float = 0.2,
+ skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
+ skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ use_original_formulation: bool = False,
+ start: float = 0.0,
+ stop: float = 1.0,
+ enabled: bool = True,
+ ):
+ super().__init__(start, stop, enabled)
+
+ self.guidance_scale = guidance_scale
+ self.skip_layer_guidance_scale = skip_layer_guidance_scale
+ self.skip_layer_guidance_start = skip_layer_guidance_start
+ self.skip_layer_guidance_stop = skip_layer_guidance_stop
+ self.guidance_rescale = guidance_rescale
+ self.use_original_formulation = use_original_formulation
+
+ if not (0.0 <= skip_layer_guidance_start < 1.0):
+ raise ValueError(
+ f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}."
+ )
+ if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0):
+ raise ValueError(
+ f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}."
+ )
+
+ if skip_layer_guidance_layers is None and skip_layer_config is None:
+ raise ValueError(
+ "Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
+ )
+ if skip_layer_guidance_layers is not None and skip_layer_config is not None:
+ raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.")
+
+ if skip_layer_guidance_layers is not None:
+ if isinstance(skip_layer_guidance_layers, int):
+ skip_layer_guidance_layers = [skip_layer_guidance_layers]
+ if not isinstance(skip_layer_guidance_layers, list):
+ raise ValueError(
+ f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}."
+ )
+ skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]
+
+ if isinstance(skip_layer_config, dict):
+ skip_layer_config = LayerSkipConfig.from_dict(skip_layer_config)
+
+ if isinstance(skip_layer_config, LayerSkipConfig):
+ skip_layer_config = [skip_layer_config]
+
+ if not isinstance(skip_layer_config, list):
+ raise ValueError(
+ f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
+ )
+ elif isinstance(next(iter(skip_layer_config), None), dict):
+ skip_layer_config = [LayerSkipConfig.from_dict(config) for config in skip_layer_config]
+
+ self.skip_layer_config = skip_layer_config
+ self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
+
+ def prepare_models(self, denoiser: torch.nn.Module) -> None:
+ self._count_prepared += 1
+ if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
+ for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
+ _apply_layer_skip_hook(denoiser, config, name=name)
+
+ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
+ if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
+ registry = HookRegistry.check_if_exists_or_initialize(denoiser)
+ # Remove the hooks after inference
+ for hook_name in self._skip_layer_hook_names:
+ registry.remove_hook(hook_name, recurse=True)
+
+ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
+ if self.num_conditions == 1:
+ tuple_indices = [0]
+ input_predictions = ["pred_cond"]
+ elif self.num_conditions == 2:
+ tuple_indices = [0, 1]
+ input_predictions = (
+ ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
+ )
+ else:
+ tuple_indices = [0, 1, 0]
+ input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
+ data_batches = []
+ for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
+ data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
+ data_batches.append(data_batch)
+ return data_batches
+
+ def prepare_inputs_from_block_state(
+ self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
+ ) -> List["BlockState"]:
+ if self.num_conditions == 1:
+ tuple_indices = [0]
+ input_predictions = ["pred_cond"]
+ elif self.num_conditions == 2:
+ tuple_indices = [0, 1]
+ input_predictions = (
+ ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
+ )
+ else:
+ tuple_indices = [0, 1, 0]
+ input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
+ data_batches = []
+ for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
+ data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
+ data_batches.append(data_batch)
+ return data_batches
+
+ def forward(
+ self,
+ pred_cond: torch.Tensor,
+ pred_uncond: Optional[torch.Tensor] = None,
+ pred_cond_skip: Optional[torch.Tensor] = None,
+ ) -> GuiderOutput:
+ pred = None
+
+ if not self._is_cfg_enabled() and not self._is_slg_enabled():
+ pred = pred_cond
+ elif not self._is_cfg_enabled():
+ shift = pred_cond - pred_cond_skip
+ pred = pred_cond if self.use_original_formulation else pred_cond_skip
+ pred = pred + self.skip_layer_guidance_scale * shift
+ elif not self._is_slg_enabled():
+ shift = pred_cond - pred_uncond
+ pred = pred_cond if self.use_original_formulation else pred_uncond
+ pred = pred + self.guidance_scale * shift
+ else:
+ shift = pred_cond - pred_uncond
+ shift_skip = pred_cond - pred_cond_skip
+ pred = pred_cond if self.use_original_formulation else pred_uncond
+ pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
+
+ if self.guidance_rescale > 0.0:
+ pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+ return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
+
+ @property
+ def is_conditional(self) -> bool:
+ return self._count_prepared == 1 or self._count_prepared == 3
+
+ @property
+ def num_conditions(self) -> int:
+ num_conditions = 1
+ if self._is_cfg_enabled():
+ num_conditions += 1
+ if self._is_slg_enabled():
+ num_conditions += 1
+ return num_conditions
+
+ def _is_cfg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self._start * self._num_inference_steps)
+ skip_stop_step = int(self._stop * self._num_inference_steps)
+ is_within_range = skip_start_step <= self._step < skip_stop_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = math.isclose(self.guidance_scale, 0.0)
+ else:
+ is_close = math.isclose(self.guidance_scale, 1.0)
+
+ return is_within_range and not is_close
+
+ def _is_slg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
+ skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
+ is_within_range = skip_start_step < self._step < skip_stop_step
+
+ is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
+
+ return is_within_range and not is_zero
diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py
new file mode 100644
index 000000000000..6c3906e820e0
--- /dev/null
+++ b/src/diffusers/guiders/smoothed_energy_guidance.py
@@ -0,0 +1,267 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import register_to_config
+from ..hooks import HookRegistry
+from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+
+class SmoothedEnergyGuidance(BaseGuidance):
+ """
+ Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760
+
+ SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the
+ future without warning or guarantee of reproducibility. This implementation assumes:
+ - Generated images are square (height == width)
+ - The model does not combine different modalities together (e.g., text and image latent streams are not combined
+ together such as Flux)
+
+ Args:
+ guidance_scale (`float`, defaults to `7.5`):
+ The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
+ prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
+ deterioration of image quality.
+ seg_guidance_scale (`float`, defaults to `3.0`):
+ The scale parameter for smoothed energy guidance. Anatomy and structure coherence may improve with higher
+ values, but it may also lead to overexposure and saturation.
+ seg_blur_sigma (`float`, defaults to `9999999.0`):
+ The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in
+ infinite blur, which means uniform queries. Controlling it exponentially is empirically effective.
+ seg_blur_threshold_inf (`float`, defaults to `9999.0`):
+ The threshold above which the blur is considered infinite.
+ seg_guidance_start (`float`, defaults to `0.0`):
+ The fraction of the total number of denoising steps after which smoothed energy guidance starts.
+ seg_guidance_stop (`float`, defaults to `1.0`):
+ The fraction of the total number of denoising steps after which smoothed energy guidance stops.
+ seg_guidance_layers (`int` or `List[int]`, *optional*):
+ The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If
+ not provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable
+ Diffusion 3.5 Medium.
+ seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*):
+ The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or
+ a list of `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided.
+ guidance_rescale (`float`, defaults to `0.0`):
+ The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+ overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+ use_original_formulation (`bool`, defaults to `False`):
+ Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+ we use the diffusers-native implementation that has been in the codebase for a long time. See
+ [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+ start (`float`, defaults to `0.01`):
+ The fraction of the total number of denoising steps after which guidance starts.
+ stop (`float`, defaults to `0.2`):
+ The fraction of the total number of denoising steps after which guidance stops.
+ """
+
+ _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
+
+ @register_to_config
+ def __init__(
+ self,
+ guidance_scale: float = 7.5,
+ seg_guidance_scale: float = 2.8,
+ seg_blur_sigma: float = 9999999.0,
+ seg_blur_threshold_inf: float = 9999.0,
+ seg_guidance_start: float = 0.0,
+ seg_guidance_stop: float = 1.0,
+ seg_guidance_layers: Optional[Union[int, List[int]]] = None,
+ seg_guidance_config: Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]] = None,
+ guidance_rescale: float = 0.0,
+ use_original_formulation: bool = False,
+ start: float = 0.0,
+ stop: float = 1.0,
+ enabled: bool = True,
+ ):
+ super().__init__(start, stop, enabled)
+
+ self.guidance_scale = guidance_scale
+ self.seg_guidance_scale = seg_guidance_scale
+ self.seg_blur_sigma = seg_blur_sigma
+ self.seg_blur_threshold_inf = seg_blur_threshold_inf
+ self.seg_guidance_start = seg_guidance_start
+ self.seg_guidance_stop = seg_guidance_stop
+ self.guidance_rescale = guidance_rescale
+ self.use_original_formulation = use_original_formulation
+
+ if not (0.0 <= seg_guidance_start < 1.0):
+ raise ValueError(f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}.")
+ if not (seg_guidance_start <= seg_guidance_stop <= 1.0):
+ raise ValueError(f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}.")
+
+ if seg_guidance_layers is None and seg_guidance_config is None:
+ raise ValueError(
+ "Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance."
+ )
+ if seg_guidance_layers is not None and seg_guidance_config is not None:
+ raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.")
+
+ if seg_guidance_layers is not None:
+ if isinstance(seg_guidance_layers, int):
+ seg_guidance_layers = [seg_guidance_layers]
+ if not isinstance(seg_guidance_layers, list):
+ raise ValueError(
+ f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}."
+ )
+ seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers]
+
+ if isinstance(seg_guidance_config, dict):
+ seg_guidance_config = SmoothedEnergyGuidanceConfig.from_dict(seg_guidance_config)
+
+ if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig):
+ seg_guidance_config = [seg_guidance_config]
+
+ if not isinstance(seg_guidance_config, list):
+ raise ValueError(
+ f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}."
+ )
+ elif isinstance(next(iter(seg_guidance_config), None), dict):
+ seg_guidance_config = [SmoothedEnergyGuidanceConfig.from_dict(config) for config in seg_guidance_config]
+
+ self.seg_guidance_config = seg_guidance_config
+ self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))]
+
+ def prepare_models(self, denoiser: torch.nn.Module) -> None:
+ if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
+ for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
+ _apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)
+
+ def cleanup_models(self, denoiser: torch.nn.Module):
+ if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
+ registry = HookRegistry.check_if_exists_or_initialize(denoiser)
+ # Remove the hooks after inference
+ for hook_name in self._seg_layer_hook_names:
+ registry.remove_hook(hook_name, recurse=True)
+
+ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
+ if self.num_conditions == 1:
+ tuple_indices = [0]
+ input_predictions = ["pred_cond"]
+ elif self.num_conditions == 2:
+ tuple_indices = [0, 1]
+ input_predictions = (
+ ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
+ )
+ else:
+ tuple_indices = [0, 1, 0]
+ input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
+ data_batches = []
+ for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
+ data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
+ data_batches.append(data_batch)
+ return data_batches
+
+ def prepare_inputs_from_block_state(
+ self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
+ ) -> List["BlockState"]:
+ if self.num_conditions == 1:
+ tuple_indices = [0]
+ input_predictions = ["pred_cond"]
+ elif self.num_conditions == 2:
+ tuple_indices = [0, 1]
+ input_predictions = (
+ ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
+ )
+ else:
+ tuple_indices = [0, 1, 0]
+ input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
+ data_batches = []
+ for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
+ data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
+ data_batches.append(data_batch)
+ return data_batches
+
+ def forward(
+ self,
+ pred_cond: torch.Tensor,
+ pred_uncond: Optional[torch.Tensor] = None,
+ pred_cond_seg: Optional[torch.Tensor] = None,
+ ) -> GuiderOutput:
+ pred = None
+
+ if not self._is_cfg_enabled() and not self._is_seg_enabled():
+ pred = pred_cond
+ elif not self._is_cfg_enabled():
+ shift = pred_cond - pred_cond_seg
+ pred = pred_cond if self.use_original_formulation else pred_cond_seg
+ pred = pred + self.seg_guidance_scale * shift
+ elif not self._is_seg_enabled():
+ shift = pred_cond - pred_uncond
+ pred = pred_cond if self.use_original_formulation else pred_uncond
+ pred = pred + self.guidance_scale * shift
+ else:
+ shift = pred_cond - pred_uncond
+ shift_seg = pred_cond - pred_cond_seg
+ pred = pred_cond if self.use_original_formulation else pred_uncond
+ pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg
+
+ if self.guidance_rescale > 0.0:
+ pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+ return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
+
+ @property
+ def is_conditional(self) -> bool:
+ return self._count_prepared == 1 or self._count_prepared == 3
+
+ @property
+ def num_conditions(self) -> int:
+ num_conditions = 1
+ if self._is_cfg_enabled():
+ num_conditions += 1
+ if self._is_seg_enabled():
+ num_conditions += 1
+ return num_conditions
+
+ def _is_cfg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self._start * self._num_inference_steps)
+ skip_stop_step = int(self._stop * self._num_inference_steps)
+ is_within_range = skip_start_step <= self._step < skip_stop_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = math.isclose(self.guidance_scale, 0.0)
+ else:
+ is_close = math.isclose(self.guidance_scale, 1.0)
+
+ return is_within_range and not is_close
+
+ def _is_seg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self.seg_guidance_start * self._num_inference_steps)
+ skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps)
+ is_within_range = skip_start_step < self._step < skip_stop_step
+
+ is_zero = math.isclose(self.seg_guidance_scale, 0.0)
+
+ return is_within_range and not is_zero
diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py
new file mode 100644
index 000000000000..76899c6e8494
--- /dev/null
+++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py
@@ -0,0 +1,149 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import register_to_config
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+
+class TangentialClassifierFreeGuidance(BaseGuidance):
+ """
+ Tangential Classifier Free Guidance (TCFG): https://huggingface.co/papers/2503.18137
+
+ Args:
+ guidance_scale (`float`, defaults to `7.5`):
+ The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
+ prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
+ deterioration of image quality.
+ guidance_rescale (`float`, defaults to `0.0`):
+ The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+ overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+ use_original_formulation (`bool`, defaults to `False`):
+ Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+ we use the diffusers-native implementation that has been in the codebase for a long time. See
+ [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+ start (`float`, defaults to `0.0`):
+ The fraction of the total number of denoising steps after which guidance starts.
+ stop (`float`, defaults to `1.0`):
+ The fraction of the total number of denoising steps after which guidance stops.
+ """
+
+ _input_predictions = ["pred_cond", "pred_uncond"]
+
+ @register_to_config
+ def __init__(
+ self,
+ guidance_scale: float = 7.5,
+ guidance_rescale: float = 0.0,
+ use_original_formulation: bool = False,
+ start: float = 0.0,
+ stop: float = 1.0,
+ enabled: bool = True,
+ ):
+ super().__init__(start, stop, enabled)
+
+ self.guidance_scale = guidance_scale
+ self.guidance_rescale = guidance_rescale
+ self.use_original_formulation = use_original_formulation
+
+ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
+ tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+ data_batches = []
+ for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+ data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
+ data_batches.append(data_batch)
+ return data_batches
+
+ def prepare_inputs_from_block_state(
+ self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
+ ) -> List["BlockState"]:
+ tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+ data_batches = []
+ for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+ data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
+ data_batches.append(data_batch)
+ return data_batches
+
+ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
+ pred = None
+
+ if not self._is_tcfg_enabled():
+ pred = pred_cond
+ else:
+ pred = normalized_guidance(pred_cond, pred_uncond, self.guidance_scale, self.use_original_formulation)
+
+ if self.guidance_rescale > 0.0:
+ pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+ return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
+
+ @property
+ def is_conditional(self) -> bool:
+ return self._num_outputs_prepared == 1
+
+ @property
+ def num_conditions(self) -> int:
+ num_conditions = 1
+ if self._is_tcfg_enabled():
+ num_conditions += 1
+ return num_conditions
+
+ def _is_tcfg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self._start * self._num_inference_steps)
+ skip_stop_step = int(self._stop * self._num_inference_steps)
+ is_within_range = skip_start_step <= self._step < skip_stop_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = math.isclose(self.guidance_scale, 0.0)
+ else:
+ is_close = math.isclose(self.guidance_scale, 1.0)
+
+ return is_within_range and not is_close
+
+
+def normalized_guidance(
+ pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False
+) -> torch.Tensor:
+ cond_dtype = pred_cond.dtype
+ preds = torch.stack([pred_cond, pred_uncond], dim=1).float()
+ preds = preds.flatten(2)
+ U, S, Vh = torch.linalg.svd(preds, full_matrices=False)
+ Vh_modified = Vh.clone()
+ Vh_modified[:, 1] = 0
+
+ uncond_flat = pred_uncond.reshape(pred_uncond.size(0), 1, -1).float()
+ x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1))
+ x_Vh_V = torch.matmul(x_Vh, Vh_modified)
+ pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype)
+
+ pred = pred_cond if use_original_formulation else pred_uncond
+ shift = pred_cond - pred_uncond
+ pred = pred + guidance_scale * shift
+
+ return pred
diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py
index 764ceb25b465..eb12b8a52a1e 100644
--- a/src/diffusers/hooks/__init__.py
+++ b/src/diffusers/hooks/__init__.py
@@ -1,9 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from ..utils import is_torch_available
if is_torch_available():
+ from .context_parallel import apply_context_parallel
from .faster_cache import FasterCacheConfig, apply_faster_cache
+ from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
from .group_offloading import apply_group_offloading
from .hooks import HookRegistry, ModelHook
+ from .layer_skip import LayerSkipConfig, apply_layer_skip
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
+ from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
+ from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py
new file mode 100644
index 000000000000..ca7934e5c313
--- /dev/null
+++ b/src/diffusers/hooks/_common.py
@@ -0,0 +1,56 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+
+from ..models.attention import AttentionModuleMixin, FeedForward, LuminaFeedForward
+from ..models.attention_processor import Attention, MochiAttention
+
+
+_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
+_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
+
+_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
+_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
+_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
+
+_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
+ {
+ *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+ *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+ *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+ }
+)
+
+# Layers supported for group offloading and layerwise casting
+_GO_LC_SUPPORTED_PYTORCH_LAYERS = (
+ torch.nn.Conv1d,
+ torch.nn.Conv2d,
+ torch.nn.Conv3d,
+ torch.nn.ConvTranspose1d,
+ torch.nn.ConvTranspose2d,
+ torch.nn.ConvTranspose3d,
+ torch.nn.Linear,
+ # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
+ # because of double invocation of the same norm layer in CogVideoXLayerNorm
+)
+
+
+def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]:
+ for submodule_name, submodule in module.named_modules():
+ if submodule_name == fqn:
+ return submodule
+ return None
diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py
new file mode 100644
index 000000000000..da7313cb4737
--- /dev/null
+++ b/src/diffusers/hooks/_helpers.py
@@ -0,0 +1,361 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Type
+
+
+@dataclass
+class AttentionProcessorMetadata:
+ skip_processor_output_fn: Callable[[Any], Any]
+
+
+@dataclass
+class TransformerBlockMetadata:
+ return_hidden_states_index: int = None
+ return_encoder_hidden_states_index: int = None
+
+ _cls: Type = None
+ _cached_parameter_indices: Dict[str, int] = None
+
+ def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
+ kwargs = kwargs or {}
+ if identifier in kwargs:
+ return kwargs[identifier]
+ if self._cached_parameter_indices is not None:
+ return args[self._cached_parameter_indices[identifier]]
+ if self._cls is None:
+ raise ValueError("Model class is not set for metadata.")
+ parameters = list(inspect.signature(self._cls.forward).parameters.keys())
+ parameters = parameters[1:] # skip `self`
+ self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
+ if identifier not in self._cached_parameter_indices:
+ raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
+ index = self._cached_parameter_indices[identifier]
+ if index >= len(args):
+ raise ValueError(f"Expected {index} arguments but got {len(args)}.")
+ return args[index]
+
+
+class AttentionProcessorRegistry:
+ _registry = {}
+ # TODO(aryan): this is only required for the time being because we need to do the registrations
+ # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
+ # import errors because of the models imported in this file.
+ _is_registered = False
+
+ @classmethod
+ def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
+ cls._register()
+ cls._registry[model_class] = metadata
+
+ @classmethod
+ def get(cls, model_class: Type) -> AttentionProcessorMetadata:
+ cls._register()
+ if model_class not in cls._registry:
+ raise ValueError(f"Model class {model_class} not registered.")
+ return cls._registry[model_class]
+
+ @classmethod
+ def _register(cls):
+ if cls._is_registered:
+ return
+ cls._is_registered = True
+ _register_attention_processors_metadata()
+
+
+class TransformerBlockRegistry:
+ _registry = {}
+ # TODO(aryan): this is only required for the time being because we need to do the registrations
+ # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
+ # import errors because of the models imported in this file.
+ _is_registered = False
+
+ @classmethod
+ def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
+ cls._register()
+ metadata._cls = model_class
+ cls._registry[model_class] = metadata
+
+ @classmethod
+ def get(cls, model_class: Type) -> TransformerBlockMetadata:
+ cls._register()
+ if model_class not in cls._registry:
+ raise ValueError(f"Model class {model_class} not registered.")
+ return cls._registry[model_class]
+
+ @classmethod
+ def _register(cls):
+ if cls._is_registered:
+ return
+ cls._is_registered = True
+ _register_transformer_blocks_metadata()
+
+
+def _register_attention_processors_metadata():
+ from ..models.attention_processor import AttnProcessor2_0
+ from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
+ from ..models.transformers.transformer_flux import FluxAttnProcessor
+ from ..models.transformers.transformer_hunyuanimage import HunyuanImageAttnProcessor
+ from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
+ from ..models.transformers.transformer_wan import WanAttnProcessor2_0
+ from ..models.transformers.transformer_z_image import ZSingleStreamAttnProcessor
+
+ # AttnProcessor2_0
+ AttentionProcessorRegistry.register(
+ model_class=AttnProcessor2_0,
+ metadata=AttentionProcessorMetadata(
+ skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0,
+ ),
+ )
+
+ # CogView4AttnProcessor
+ AttentionProcessorRegistry.register(
+ model_class=CogView4AttnProcessor,
+ metadata=AttentionProcessorMetadata(
+ skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
+ ),
+ )
+
+ # WanAttnProcessor2_0
+ AttentionProcessorRegistry.register(
+ model_class=WanAttnProcessor2_0,
+ metadata=AttentionProcessorMetadata(
+ skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0,
+ ),
+ )
+
+ # FluxAttnProcessor
+ AttentionProcessorRegistry.register(
+ model_class=FluxAttnProcessor,
+ metadata=AttentionProcessorMetadata(skip_processor_output_fn=_skip_proc_output_fn_Attention_FluxAttnProcessor),
+ )
+
+ # QwenDoubleStreamAttnProcessor2
+ AttentionProcessorRegistry.register(
+ model_class=QwenDoubleStreamAttnProcessor2_0,
+ metadata=AttentionProcessorMetadata(
+ skip_processor_output_fn=_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0
+ ),
+ )
+
+ # HunyuanImageAttnProcessor
+ AttentionProcessorRegistry.register(
+ model_class=HunyuanImageAttnProcessor,
+ metadata=AttentionProcessorMetadata(
+ skip_processor_output_fn=_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor,
+ ),
+ )
+
+ # ZSingleStreamAttnProcessor
+ AttentionProcessorRegistry.register(
+ model_class=ZSingleStreamAttnProcessor,
+ metadata=AttentionProcessorMetadata(
+ skip_processor_output_fn=_skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor,
+ ),
+ )
+
+
+def _register_transformer_blocks_metadata():
+ from ..models.attention import BasicTransformerBlock
+ from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
+ from ..models.transformers.transformer_bria import BriaTransformerBlock
+ from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
+ from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
+ from ..models.transformers.transformer_hunyuan_video import (
+ HunyuanVideoSingleTransformerBlock,
+ HunyuanVideoTokenReplaceSingleTransformerBlock,
+ HunyuanVideoTokenReplaceTransformerBlock,
+ HunyuanVideoTransformerBlock,
+ )
+ from ..models.transformers.transformer_hunyuanimage import (
+ HunyuanImageSingleTransformerBlock,
+ HunyuanImageTransformerBlock,
+ )
+ from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
+ from ..models.transformers.transformer_mochi import MochiTransformerBlock
+ from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
+ from ..models.transformers.transformer_wan import WanTransformerBlock
+ from ..models.transformers.transformer_z_image import ZImageTransformerBlock
+
+ # BasicTransformerBlock
+ TransformerBlockRegistry.register(
+ model_class=BasicTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=None,
+ ),
+ )
+ TransformerBlockRegistry.register(
+ model_class=BriaTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=None,
+ ),
+ )
+
+ # CogVideoX
+ TransformerBlockRegistry.register(
+ model_class=CogVideoXBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+
+ # CogView4
+ TransformerBlockRegistry.register(
+ model_class=CogView4TransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+
+ # Flux
+ TransformerBlockRegistry.register(
+ model_class=FluxTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=1,
+ return_encoder_hidden_states_index=0,
+ ),
+ )
+ TransformerBlockRegistry.register(
+ model_class=FluxSingleTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=1,
+ return_encoder_hidden_states_index=0,
+ ),
+ )
+
+ # HunyuanVideo
+ TransformerBlockRegistry.register(
+ model_class=HunyuanVideoTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+ TransformerBlockRegistry.register(
+ model_class=HunyuanVideoSingleTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+ TransformerBlockRegistry.register(
+ model_class=HunyuanVideoTokenReplaceTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+ TransformerBlockRegistry.register(
+ model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+
+ # LTXVideo
+ TransformerBlockRegistry.register(
+ model_class=LTXVideoTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=None,
+ ),
+ )
+
+ # Mochi
+ TransformerBlockRegistry.register(
+ model_class=MochiTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+
+ # Wan
+ TransformerBlockRegistry.register(
+ model_class=WanTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=None,
+ ),
+ )
+
+ # QwenImage
+ TransformerBlockRegistry.register(
+ model_class=QwenImageTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=1,
+ return_encoder_hidden_states_index=0,
+ ),
+ )
+
+ # HunyuanImage2.1
+ TransformerBlockRegistry.register(
+ model_class=HunyuanImageTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+ TransformerBlockRegistry.register(
+ model_class=HunyuanImageSingleTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+
+ # ZImage
+ TransformerBlockRegistry.register(
+ model_class=ZImageTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=None,
+ ),
+ )
+
+
+# fmt: off
+def _skip_attention___ret___hidden_states(self, *args, **kwargs):
+ hidden_states = kwargs.get("hidden_states", None)
+ if hidden_states is None and len(args) > 0:
+ hidden_states = args[0]
+ return hidden_states
+
+
+def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
+ hidden_states = kwargs.get("hidden_states", None)
+ encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
+ if hidden_states is None and len(args) > 0:
+ hidden_states = args[0]
+ if encoder_hidden_states is None and len(args) > 1:
+ encoder_hidden_states = args[1]
+ return hidden_states, encoder_hidden_states
+
+
+_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
+_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
+_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
+# not sure what this is yet.
+_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
+_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
+_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor = _skip_attention___ret___hidden_states
+_skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor = _skip_attention___ret___hidden_states
+# fmt: on
diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py
new file mode 100644
index 000000000000..6491d17b4f46
--- /dev/null
+++ b/src/diffusers/hooks/context_parallel.py
@@ -0,0 +1,302 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from dataclasses import dataclass
+from typing import Dict, List, Type, Union
+
+import torch
+
+
+if torch.distributed.is_available():
+ import torch.distributed._functional_collectives as funcol
+
+from ..models._modeling_parallel import (
+ ContextParallelConfig,
+ ContextParallelInput,
+ ContextParallelModelPlan,
+ ContextParallelOutput,
+)
+from ..utils import get_logger
+from ..utils.torch_utils import unwrap_module
+from .hooks import HookRegistry, ModelHook
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
+_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"
+
+
+# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
+@dataclass
+class ModuleForwardMetadata:
+ cached_parameter_indices: Dict[str, int] = None
+ _cls: Type = None
+
+ def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
+ kwargs = kwargs or {}
+
+ if identifier in kwargs:
+ return kwargs[identifier], True, None
+
+ if self.cached_parameter_indices is not None:
+ index = self.cached_parameter_indices.get(identifier, None)
+ if index is None:
+ raise ValueError(f"Parameter '{identifier}' not found in cached indices.")
+ return args[index], False, index
+
+ if self._cls is None:
+ raise ValueError("Model class is not set for metadata.")
+
+ parameters = list(inspect.signature(self._cls.forward).parameters.keys())
+ parameters = parameters[1:] # skip `self`
+ self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
+
+ if identifier not in self.cached_parameter_indices:
+ raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
+
+ index = self.cached_parameter_indices[identifier]
+
+ if index >= len(args):
+ raise ValueError(f"Expected {index} arguments but got {len(args)}.")
+
+ return args[index], False, index
+
+
+def apply_context_parallel(
+ module: torch.nn.Module,
+ parallel_config: ContextParallelConfig,
+ plan: Dict[str, ContextParallelModelPlan],
+) -> None:
+ """Apply context parallel on a model."""
+ logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}")
+
+ for module_id, cp_model_plan in plan.items():
+ submodule = _get_submodule_by_name(module, module_id)
+ if not isinstance(submodule, list):
+ submodule = [submodule]
+
+ logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules")
+
+ for m in submodule:
+ if isinstance(cp_model_plan, dict):
+ hook = ContextParallelSplitHook(cp_model_plan, parallel_config)
+ hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
+ elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
+ if isinstance(cp_model_plan, ContextParallelOutput):
+ cp_model_plan = [cp_model_plan]
+ if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan):
+ raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}")
+ hook = ContextParallelGatherHook(cp_model_plan, parallel_config)
+ hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
+ else:
+ raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
+ registry = HookRegistry.check_if_exists_or_initialize(m)
+ registry.register_hook(hook, hook_name)
+
+
+def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextParallelModelPlan]) -> None:
+ for module_id, cp_model_plan in plan.items():
+ submodule = _get_submodule_by_name(module, module_id)
+ if not isinstance(submodule, list):
+ submodule = [submodule]
+
+ for m in submodule:
+ registry = HookRegistry.check_if_exists_or_initialize(m)
+ if isinstance(cp_model_plan, dict):
+ hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
+ elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
+ hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
+ else:
+ raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
+ registry.remove_hook(hook_name)
+
+
+class ContextParallelSplitHook(ModelHook):
+ def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
+ super().__init__()
+ self.metadata = metadata
+ self.parallel_config = parallel_config
+ self.module_forward_metadata = None
+
+ def initialize_hook(self, module):
+ cls = unwrap_module(module).__class__
+ self.module_forward_metadata = ModuleForwardMetadata(_cls=cls)
+ return module
+
+ def pre_forward(self, module, *args, **kwargs):
+ args_list = list(args)
+
+ for name, cpm in self.metadata.items():
+ if isinstance(cpm, ContextParallelInput) and cpm.split_output:
+ continue
+
+ # Maybe the parameter was passed as a keyword argument
+ input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs(
+ name, args_list, kwargs
+ )
+
+ if input_val is None:
+ continue
+
+ # The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard
+ # the output instead of input for a particular layer by setting split_output=True
+ if isinstance(input_val, torch.Tensor):
+ input_val = self._prepare_cp_input(input_val, cpm)
+ elif isinstance(input_val, (list, tuple)):
+ if len(input_val) != len(cpm):
+ raise ValueError(
+ f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}."
+ )
+ sharded_input_val = []
+ for i, x in enumerate(input_val):
+ if torch.is_tensor(x) and not cpm[i].split_output:
+ x = self._prepare_cp_input(x, cpm[i])
+ sharded_input_val.append(x)
+ input_val = sharded_input_val
+ else:
+ raise ValueError(f"Unsupported input type: {type(input_val)}")
+
+ if is_kwarg:
+ kwargs[name] = input_val
+ elif index is not None and index < len(args_list):
+ args_list[index] = input_val
+ else:
+ raise ValueError(
+ f"An unexpected error occurred while processing the input '{name}'. Please open an "
+ f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible "
+ f"example along with the full stack trace."
+ )
+
+ return tuple(args_list), kwargs
+
+ def post_forward(self, module, output):
+ is_tensor = isinstance(output, torch.Tensor)
+ is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)
+
+ if not is_tensor and not is_tensor_list:
+ raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
+
+ output = [output] if is_tensor else list(output)
+ for index, cpm in self.metadata.items():
+ if not isinstance(cpm, ContextParallelInput) or not cpm.split_output:
+ continue
+ if index >= len(output):
+ raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
+ current_output = output[index]
+ current_output = self._prepare_cp_input(current_output, cpm)
+ output[index] = current_output
+
+ return output[0] if is_tensor else tuple(output)
+
+ def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
+ if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
+ logger.warning_once(
+ f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions, split will not be applied."
+ )
+ return x
+ else:
+ return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
+
+
+class ContextParallelGatherHook(ModelHook):
+ def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
+ super().__init__()
+ self.metadata = metadata
+ self.parallel_config = parallel_config
+
+ def post_forward(self, module, output):
+ is_tensor = isinstance(output, torch.Tensor)
+
+ if is_tensor:
+ output = [output]
+ elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)):
+ raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
+
+ output = list(output)
+
+ if len(output) != len(self.metadata):
+ raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.")
+
+ for i, cpm in enumerate(self.metadata):
+ if cpm is None:
+ continue
+ output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)
+
+ return output[0] if is_tensor else tuple(output)
+
+
+class AllGatherFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, tensor, dim, group):
+ ctx.dim = dim
+ ctx.group = group
+ ctx.world_size = torch.distributed.get_world_size(group)
+ ctx.rank = torch.distributed.get_rank(group)
+ return funcol.all_gather_tensor(tensor, dim, group=group)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)
+ return grad_chunks[ctx.rank], None, None
+
+
+class EquipartitionSharder:
+ @classmethod
+ def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
+ # NOTE: the following assertion does not have to be true in general. We simply enforce it for now
+ # because the alternate case has not yet been tested/required for any model.
+ assert tensor.size()[dim] % mesh.size() == 0, (
+ "Tensor size along dimension to be sharded must be divisible by mesh size"
+ )
+
+ # The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
+ # return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
+
+ return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())]
+
+ @classmethod
+ def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
+ tensor = tensor.contiguous()
+ tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group())
+ return tensor
+
+
+def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
+ if name.count("*") > 1:
+ raise ValueError("Wildcard '*' can only be used once in the name")
+ return _find_submodule_by_name(model, name)
+
+
+def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
+ if name == "":
+ return model
+ first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
+ if first_atom == "*":
+ if not isinstance(model, torch.nn.ModuleList):
+ raise ValueError("Wildcard '*' can only be used with ModuleList")
+ submodules = []
+ for submodule in model:
+ subsubmodules = _find_submodule_by_name(submodule, remaining_name)
+ if not isinstance(subsubmodules, list):
+ subsubmodules = [subsubmodules]
+ submodules.extend(subsubmodules)
+ return submodules
+ else:
+ if hasattr(model, first_atom):
+ submodule = getattr(model, first_atom)
+ return _find_submodule_by_name(submodule, remaining_name)
+ else:
+ raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")
diff --git a/src/diffusers/hooks/faster_cache.py b/src/diffusers/hooks/faster_cache.py
index 634635346474..a01afeffdb95 100644
--- a/src/diffusers/hooks/faster_cache.py
+++ b/src/diffusers/hooks/faster_cache.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,9 +18,10 @@
import torch
-from ..models.attention_processor import Attention, MochiAttention
+from ..models.attention import AttentionModuleMixin
from ..models.modeling_outputs import Transformer2DModelOutput
from ..utils import logging
+from ._common import _ATTENTION_CLASSES
from .hooks import HookRegistry, ModelHook
@@ -29,7 +30,6 @@
_FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
_FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
-_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
"^blocks.*attn",
"^transformer_blocks.*attn",
@@ -54,11 +54,11 @@ class FasterCacheConfig:
Attributes:
spatial_attention_block_skip_range (`int`, defaults to `2`):
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
- be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention
+ be skipped `N - 1` times (i.e., cached attention states will be reused) before computing the new attention
states again.
temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
- be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention
+ be skipped `N - 1` times (i.e., cached attention states will be reused) before computing the new attention
states again.
spatial_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`):
The timestep range within which the spatial attention computation can be skipped without a significant loss
@@ -90,7 +90,7 @@ class FasterCacheConfig:
from the conditional branch outputs.
unconditional_batch_skip_range (`int`, defaults to `5`):
Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch
- computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be re-used) before
+ computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be reused) before
computing the new unconditional branch states again.
unconditional_batch_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 641)`):
The timestep range within which the unconditional branch computation can be skipped without a significant
@@ -146,7 +146,7 @@ class FasterCacheConfig:
alpha_low_frequency: float = 1.1
alpha_high_frequency: float = 1.1
- # n as described in CFG-Cache explanation in the paper - dependant on the model
+ # n as described in CFG-Cache explanation in the paper - dependent on the model
unconditional_batch_skip_range: int = 5
unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)
@@ -488,9 +488,10 @@ def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> No
Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
Args:
- pipeline (`DiffusionPipeline`):
- The diffusion pipeline to apply FasterCache to.
- config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`):
+ module (`torch.nn.Module`):
+ The pytorch module to apply FasterCache to. Typically, this should be a transformer architecture supported
+ in Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
+ config (`FasterCacheConfig`):
The configuration to use for FasterCache.
Example:
@@ -588,7 +589,7 @@ def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCache
registry.register_hook(hook, _FASTER_CACHE_DENOISER_HOOK)
-def _apply_faster_cache_on_attention_class(name: str, module: Attention, config: FasterCacheConfig) -> None:
+def _apply_faster_cache_on_attention_class(name: str, module: AttentionModuleMixin, config: FasterCacheConfig) -> None:
is_spatial_self_attention = (
any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
and config.spatial_attention_block_skip_range is not None
diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py
new file mode 100644
index 000000000000..862d44059301
--- /dev/null
+++ b/src/diffusers/hooks/first_block_cache.py
@@ -0,0 +1,259 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Tuple, Union
+
+import torch
+
+from ..utils import get_logger
+from ..utils.torch_utils import unwrap_module
+from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
+from ._helpers import TransformerBlockRegistry
+from .hooks import BaseState, HookRegistry, ModelHook, StateManager
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
+_FBC_BLOCK_HOOK = "fbc_block_hook"
+
+
+@dataclass
+class FirstBlockCacheConfig:
+ r"""
+ Configuration for [First Block
+ Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).
+
+ Args:
+ threshold (`float`, defaults to `0.05`):
+ The threshold to determine whether or not a forward pass through all layers of the model is required. A
+ higher threshold usually results in a forward pass through a lower number of layers and faster inference,
+ but might lead to poorer generation quality. A lower threshold may not result in significant generation
+ speedup. The threshold is compared against the absmean difference of the residuals between the current and
+ cached outputs from the first transformer block. If the difference is below the threshold, the forward pass
+ is skipped.
+ """
+
+ threshold: float = 0.05
+
+
+class FBCSharedBlockState(BaseState):
+ def __init__(self) -> None:
+ super().__init__()
+
+ self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
+ self.head_block_residual: torch.Tensor = None
+ self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
+ self.should_compute: bool = True
+
+ def reset(self):
+ self.tail_block_residuals = None
+ self.should_compute = True
+
+
+class FBCHeadBlockHook(ModelHook):
+ _is_stateful = True
+
+ def __init__(self, state_manager: StateManager, threshold: float):
+ self.state_manager = state_manager
+ self.threshold = threshold
+ self._metadata = None
+
+ def initialize_hook(self, module):
+ unwrapped_module = unwrap_module(module)
+ self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
+ return module
+
+ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+ original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
+
+ output = self.fn_ref.original_forward(*args, **kwargs)
+ is_output_tuple = isinstance(output, tuple)
+
+ if is_output_tuple:
+ hidden_states_residual = output[self._metadata.return_hidden_states_index] - original_hidden_states
+ else:
+ hidden_states_residual = output - original_hidden_states
+
+ shared_state: FBCSharedBlockState = self.state_manager.get_state()
+ hidden_states = encoder_hidden_states = None
+ should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
+ shared_state.should_compute = should_compute
+
+ if not should_compute:
+ # Apply caching
+ if is_output_tuple:
+ hidden_states = (
+ shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
+ )
+ else:
+ hidden_states = shared_state.tail_block_residuals[0] + output
+
+ if self._metadata.return_encoder_hidden_states_index is not None:
+ assert is_output_tuple
+ encoder_hidden_states = (
+ shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index]
+ )
+
+ if is_output_tuple:
+ return_output = [None] * len(output)
+ return_output[self._metadata.return_hidden_states_index] = hidden_states
+ return_output[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
+ return_output = tuple(return_output)
+ else:
+ return_output = hidden_states
+ output = return_output
+ else:
+ if is_output_tuple:
+ head_block_output = [None] * len(output)
+ head_block_output[0] = output[self._metadata.return_hidden_states_index]
+ head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
+ else:
+ head_block_output = output
+ shared_state.head_block_output = head_block_output
+ shared_state.head_block_residual = hidden_states_residual
+
+ return output
+
+ def reset_state(self, module):
+ self.state_manager.reset()
+ return module
+
+ @torch.compiler.disable
+ def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
+ shared_state = self.state_manager.get_state()
+ if shared_state.head_block_residual is None:
+ return True
+ prev_hidden_states_residual = shared_state.head_block_residual
+ absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
+ prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
+ diff = (absmean / prev_hidden_states_absmean).item()
+ return diff > self.threshold
+
+
+class FBCBlockHook(ModelHook):
+ def __init__(self, state_manager: StateManager, is_tail: bool = False):
+ super().__init__()
+ self.state_manager = state_manager
+ self.is_tail = is_tail
+ self._metadata = None
+
+ def initialize_hook(self, module):
+ unwrapped_module = unwrap_module(module)
+ self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
+ return module
+
+ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+ original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
+ original_encoder_hidden_states = None
+ if self._metadata.return_encoder_hidden_states_index is not None:
+ original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
+ "encoder_hidden_states", args, kwargs
+ )
+
+ shared_state = self.state_manager.get_state()
+
+ if shared_state.should_compute:
+ output = self.fn_ref.original_forward(*args, **kwargs)
+ if self.is_tail:
+ hidden_states_residual = encoder_hidden_states_residual = None
+ if isinstance(output, tuple):
+ hidden_states_residual = (
+ output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0]
+ )
+ encoder_hidden_states_residual = (
+ output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1]
+ )
+ else:
+ hidden_states_residual = output - shared_state.head_block_output
+ shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
+ return output
+
+ if original_encoder_hidden_states is None:
+ return_output = original_hidden_states
+ else:
+ return_output = [None, None]
+ return_output[self._metadata.return_hidden_states_index] = original_hidden_states
+ return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
+ return_output = tuple(return_output)
+ return return_output
+
+
+def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
+ """
+ Applies [First Block
+ Cache](https://github.com/chengzeyi/ParaAttention/blob/4de137c5b96416489f06e43e19f2c14a772e28fd/README.md#first-block-cache-our-dynamic-caching)
+ to a given module.
+
+ First Block Cache builds on the ideas of [TeaCache](https://huggingface.co/papers/2411.19108). It is much simpler
+ to implement generically for a wide range of models and has been integrated first for experimental purposes.
+
+ Args:
+ module (`torch.nn.Module`):
+ The pytorch module to apply FBCache to. Typically, this should be a transformer architecture supported in
+ Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
+ config (`FirstBlockCacheConfig`):
+ The configuration to use for applying the FBCache method.
+
+ Example:
+ ```python
+ >>> import torch
+ >>> from diffusers import CogView4Pipeline
+ >>> from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig
+
+ >>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))
+
+ >>> prompt = "A photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
+ >>> image.save("output.png")
+ ```
+ """
+
+ state_manager = StateManager(FBCSharedBlockState, (), {})
+ remaining_blocks = []
+
+ for name, submodule in module.named_children():
+ if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
+ continue
+ for index, block in enumerate(submodule):
+ remaining_blocks.append((f"{name}.{index}", block))
+
+ head_block_name, head_block = remaining_blocks.pop(0)
+ tail_block_name, tail_block = remaining_blocks.pop(-1)
+
+ logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
+ _apply_fbc_head_block_hook(head_block, state_manager, config.threshold)
+
+ for name, block in remaining_blocks:
+ logger.debug(f"Applying FBCBlockHook to '{name}'")
+ _apply_fbc_block_hook(block, state_manager)
+
+ logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
+ _apply_fbc_block_hook(tail_block, state_manager, is_tail=True)
+
+
+def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None:
+ registry = HookRegistry.check_if_exists_or_initialize(block)
+ hook = FBCHeadBlockHook(state_manager, threshold)
+ registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)
+
+
+def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None:
+ registry = HookRegistry.check_if_exists_or_initialize(block)
+ hook = FBCBlockHook(state_manager, is_tail)
+ registry.register_hook(hook, _FBC_BLOCK_HOOK)
diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 4c1d354a0f59..47f1f4199615 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,12 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import hashlib
+import os
from contextlib import contextmanager, nullcontext
-from typing import Dict, List, Optional, Set, Tuple
+from dataclasses import dataclass, replace
+from enum import Enum
+from typing import Dict, List, Optional, Set, Tuple, Union
+import safetensors.torch
import torch
from ..utils import get_logger, is_accelerate_available
+from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook
@@ -33,17 +39,31 @@
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
-
-_SUPPORTED_PYTORCH_LAYERS = (
- torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
- torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
- torch.nn.Linear,
- # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
- # because of double invocation of the same norm layer in CogVideoXLayerNorm
-)
+_GROUP_ID_LAZY_LEAF = "lazy_leafs"
# fmt: on
+class GroupOffloadingType(str, Enum):
+ BLOCK_LEVEL = "block_level"
+ LEAF_LEVEL = "leaf_level"
+
+
+@dataclass
+class GroupOffloadingConfig:
+ onload_device: torch.device
+ offload_device: torch.device
+ offload_type: GroupOffloadingType
+ non_blocking: bool
+ record_stream: bool
+ low_cpu_mem_usage: bool
+ num_blocks_per_group: Optional[int] = None
+ offload_to_disk_path: Optional[str] = None
+ stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
+ block_modules: Optional[List[str]] = None
+ exclude_kwargs: Optional[List[str]] = None
+ module_prefix: Optional[str] = ""
+
+
class ModuleGroup:
def __init__(
self,
@@ -55,9 +75,12 @@ def __init__(
parameters: Optional[List[torch.nn.Parameter]] = None,
buffers: Optional[List[torch.Tensor]] = None,
non_blocking: bool = False,
- stream: Optional[torch.cuda.Stream] = None,
- low_cpu_mem_usage=False,
+ stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
+ record_stream: Optional[bool] = False,
+ low_cpu_mem_usage: bool = False,
onload_self: bool = True,
+ offload_to_disk_path: Optional[str] = None,
+ group_id: Optional[Union[int, str]] = None,
) -> None:
self.modules = modules
self.offload_device = offload_device
@@ -68,10 +91,38 @@ def __init__(
self.buffers = buffers or []
self.non_blocking = non_blocking or stream is not None
self.stream = stream
+ self.record_stream = record_stream
self.onload_self = onload_self
self.low_cpu_mem_usage = low_cpu_mem_usage
- self.cpu_param_dict = self._init_cpu_param_dict()
+ self.offload_to_disk_path = offload_to_disk_path
+ self._is_offloaded_to_disk = False
+
+ if self.offload_to_disk_path is not None:
+ # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
+ self.group_id = group_id if group_id is not None else str(id(self))
+ short_hash = _compute_group_hash(self.group_id)
+ self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors")
+
+ all_tensors = []
+ for module in self.modules:
+ all_tensors.extend(list(module.parameters()))
+ all_tensors.extend(list(module.buffers()))
+ all_tensors.extend(self.parameters)
+ all_tensors.extend(self.buffers)
+ all_tensors = list(dict.fromkeys(all_tensors)) # Remove duplicates
+
+ self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
+ self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
+ self.cpu_param_dict = {}
+ else:
+ self.cpu_param_dict = self._init_cpu_param_dict()
+
+ self._torch_accelerator_module = (
+ getattr(torch, torch.accelerator.current_accelerator().type)
+ if hasattr(torch, "accelerator")
+ else torch.cuda
+ )
def _init_cpu_param_dict(self):
cpu_param_dict = {}
@@ -96,58 +147,102 @@ def _init_cpu_param_dict(self):
@contextmanager
def _pinned_memory_tensors(self):
- pinned_dict = {}
try:
- for param, tensor in self.cpu_param_dict.items():
- if not tensor.is_pinned():
- pinned_dict[param] = tensor.pin_memory()
- else:
- pinned_dict[param] = tensor
-
+ pinned_dict = {
+ param: tensor.pin_memory() if not tensor.is_pinned() else tensor
+ for param, tensor in self.cpu_param_dict.items()
+ }
yield pinned_dict
-
finally:
pinned_dict = None
- def onload_(self):
- r"""Onloads the group of modules to the onload_device."""
- context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
+ def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
+ tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+ if self.record_stream:
+ tensor.data.record_stream(default_stream)
+
+ def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
+ for group_module in self.modules:
+ for param in group_module.parameters():
+ source = pinned_memory[param] if pinned_memory else param.data
+ self._transfer_tensor_to_device(param, source, default_stream)
+ for buffer in group_module.buffers():
+ source = pinned_memory[buffer] if pinned_memory else buffer.data
+ self._transfer_tensor_to_device(buffer, source, default_stream)
+
+ for param in self.parameters:
+ source = pinned_memory[param] if pinned_memory else param.data
+ self._transfer_tensor_to_device(param, source, default_stream)
+
+ for buffer in self.buffers:
+ source = pinned_memory[buffer] if pinned_memory else buffer.data
+ self._transfer_tensor_to_device(buffer, source, default_stream)
+
+ def _onload_from_disk(self):
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
- with context:
- if self.stream is not None:
- with self._pinned_memory_tensors() as pinned_memory:
- for group_module in self.modules:
- for param in group_module.parameters():
- param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
- for buffer in group_module.buffers():
- buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
-
- for param in self.parameters:
- param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
+ context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+ current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
- for buffer in self.buffers:
- buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
+ with context:
+ # Load to CPU (if using streams) or directly to target device, pin, and async copy to device
+ device = str(self.onload_device) if self.stream is None else "cpu"
+ loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
+ if self.stream is not None:
+ for key, tensor_obj in self.key_to_tensor.items():
+ pinned_tensor = loaded_tensors[key].pin_memory()
+ tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+ if self.record_stream:
+ tensor_obj.data.record_stream(current_stream)
else:
- for group_module in self.modules:
- for param in group_module.parameters():
- param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
- for buffer in group_module.buffers():
- buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+ onload_device = (
+ self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+ )
+ loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+ for key, tensor_obj in self.key_to_tensor.items():
+ tensor_obj.data = loaded_tensors[key]
- for param in self.parameters:
- param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+ def _onload_from_memory(self):
+ if self.stream is not None:
+ # Wait for previous Host->Device transfer to complete
+ self.stream.synchronize()
- for buffer in self.buffers:
- buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+ context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+ default_stream = self._torch_accelerator_module.current_stream() if self.stream is not None else None
- def offload_(self):
- r"""Offloads the group of modules to the offload_device."""
+ with context:
+ if self.stream is not None:
+ with self._pinned_memory_tensors() as pinned_memory:
+ self._process_tensors_from_modules(pinned_memory, default_stream=default_stream)
+ else:
+ self._process_tensors_from_modules(None)
+
+ def _offload_to_disk(self):
+ # TODO: we can potentially optimize this code path by checking if the _all_ the desired
+ # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
+ # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
+ # we perform a write.
+ # Check if the file has been saved in this session or if it already exists on disk.
+ if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
+ os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
+ tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
+ safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
+
+ # The group is now considered offloaded to disk for the rest of the session.
+ self._is_offloaded_to_disk = True
+
+ # We do this to free up the RAM which is still holding the up tensor data.
+ for tensor_obj in self.tensor_to_key.keys():
+ tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
+
+ def _offload_to_memory(self):
if self.stream is not None:
- torch.cuda.current_stream().synchronize()
+ if not self.record_stream:
+ self._torch_accelerator_module.current_stream().synchronize()
+
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
@@ -155,14 +250,29 @@ def offload_(self):
param.data = self.cpu_param_dict[param]
for buffer in self.buffers:
buffer.data = self.cpu_param_dict[buffer]
-
else:
for group_module in self.modules:
- group_module.to(self.offload_device, non_blocking=self.non_blocking)
+ group_module.to(self.offload_device, non_blocking=False)
for param in self.parameters:
- param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+ param.data = param.data.to(self.offload_device, non_blocking=False)
for buffer in self.buffers:
- buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+ buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
+
+ @torch.compiler.disable()
+ def onload_(self):
+ r"""Onloads the group of parameters to the onload_device."""
+ if self.offload_to_disk_path is not None:
+ self._onload_from_disk()
+ else:
+ self._onload_from_memory()
+
+ @torch.compiler.disable()
+ def offload_(self):
+ r"""Offloads the group of parameters to the offload_device."""
+ if self.offload_to_disk_path:
+ self._offload_to_disk()
+ else:
+ self._offload_to_memory()
class GroupOffloadingHook(ModelHook):
@@ -175,13 +285,10 @@ class GroupOffloadingHook(ModelHook):
_is_stateful = False
- def __init__(
- self,
- group: ModuleGroup,
- next_group: Optional[ModuleGroup] = None,
- ) -> None:
+ def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
self.group = group
- self.next_group = next_group
+ self.next_group: Optional[ModuleGroup] = None
+ self.config = config
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
if self.group.offload_leader == module:
@@ -200,11 +307,39 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
if self.group.onload_leader == module:
if self.group.onload_self:
self.group.onload_()
- if self.next_group is not None and not self.next_group.onload_self:
+
+ should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
+ if should_onload_next_group:
self.next_group.onload_()
+ should_synchronize = (
+ not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
+ )
+ if should_synchronize:
+ # If this group didn't onload itself, it means it was asynchronously onloaded by the
+ # previous group. We need to synchronize the side stream to ensure parameters
+ # are completely loaded to proceed with forward pass. Without this, uninitialized
+ # weights will be used in the computation, leading to incorrect results
+ # Also, we should only do this synchronization if we don't already do it from the sync call in
+ # self.next_group.onload_, hence the `not should_onload_next_group` check.
+ self.group.stream.synchronize()
+
args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
- kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
+
+ # Some Autoencoder models use a feature cache that is passed through submodules
+ # and modified in place. The `send_to_device` call returns a copy of this feature cache object
+ # which breaks the inplace updates. Use `exclude_kwargs` to mark these cache features
+ exclude_kwargs = self.config.exclude_kwargs or []
+ if exclude_kwargs:
+ moved_kwargs = send_to_device(
+ {k: v for k, v in kwargs.items() if k not in exclude_kwargs},
+ self.group.onload_device,
+ non_blocking=self.group.non_blocking,
+ )
+ kwargs.update(moved_kwargs)
+ else:
+ kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
+
return args, kwargs
def post_forward(self, module: torch.nn.Module, output):
@@ -215,7 +350,7 @@ def post_forward(self, module: torch.nn.Module, output):
class LazyPrefetchGroupOffloadingHook(ModelHook):
r"""
- A hook, used in conjuction with GroupOffloadingHook, that applies lazy prefetching to groups of torch.nn.Module.
+ A hook, used in conjunction with GroupOffloadingHook, that applies lazy prefetching to groups of torch.nn.Module.
This hook is used to determine the order in which the layers are executed during the forward pass. Once the layer
invocation order is known, assignments of the next_group attribute for prefetching can be made, which allows
prefetching groups in the correct order.
@@ -230,7 +365,8 @@ def __init__(self):
def initialize_hook(self, module):
def make_execution_order_update_callback(current_name, current_submodule):
def callback():
- logger.debug(f"Adding {current_name} to the execution order")
+ if not torch.compiler.is_compiling():
+ logger.debug(f"Adding {current_name} to the execution order")
self.execution_order.append((current_name, current_submodule))
return callback
@@ -267,12 +403,13 @@ def post_forward(self, module, output):
# if the missing layers end up being executed in the future.
if execution_order_module_names != self._layer_execution_tracker_module_names:
unexecuted_layers = list(self._layer_execution_tracker_module_names - execution_order_module_names)
- logger.warning(
- "It seems like some layers were not executed during the forward pass. This may lead to problems when "
- "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
- "make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
- f"{unexecuted_layers=}"
- )
+ if not torch.compiler.is_compiling():
+ logger.warning(
+ "It seems like some layers were not executed during the forward pass. This may lead to problems when "
+ "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
+ "make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
+ f"{unexecuted_layers=}"
+ )
# Remove the layer execution tracker hooks from the submodules
base_module_registry = module._diffusers_hook
@@ -300,7 +437,8 @@ def post_forward(self, module, output):
for i in range(num_executed - 1):
name1, _ = self.execution_order[i]
name2, _ = self.execution_order[i + 1]
- logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
+ if not torch.compiler.is_compiling():
+ logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
group_offloading_hooks[i].next_group.onload_self = False
@@ -325,13 +463,17 @@ def pre_forward(self, module, *args, **kwargs):
def apply_group_offloading(
module: torch.nn.Module,
- onload_device: torch.device,
- offload_device: torch.device = torch.device("cpu"),
- offload_type: str = "block_level",
+ onload_device: Union[str, torch.device],
+ offload_device: Union[str, torch.device] = torch.device("cpu"),
+ offload_type: Union[str, GroupOffloadingType] = "block_level",
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
+ record_stream: bool = False,
low_cpu_mem_usage: bool = False,
+ offload_to_disk_path: Optional[str] = None,
+ block_modules: Optional[List[str]] = None,
+ exclude_kwargs: Optional[List[str]] = None,
) -> None:
r"""
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -367,9 +509,12 @@ def apply_group_offloading(
The device to which the group of modules are onloaded.
offload_device (`torch.device`, defaults to `torch.device("cpu")`):
The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU.
- offload_type (`str`, defaults to "block_level"):
+ offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
"block_level".
+ offload_to_disk_path (`str`, *optional*, defaults to `None`):
+ The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
+ RAM environment settings where a reasonable speed-memory trade-off is desired.
num_blocks_per_group (`int`, *optional*):
The number of blocks per group when using offload_type="block_level". This is required when using
offload_type="block_level".
@@ -378,10 +523,21 @@ def apply_group_offloading(
use_stream (`bool`, defaults to `False`):
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
overlapping computation and data transfer.
+ record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
+ as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
+ [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
+ details.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
+ block_modules (`List[str]`, *optional*):
+ List of module names that should be treated as blocks for offloading. If provided, only these modules will
+ be considered for block-level offloading. If not provided, the default block detection logic will be used.
+ exclude_kwargs (`List[str]`, *optional*):
+ List of kwarg keys that should not be processed by send_to_device. This is useful for mutable state like
+ caching lists that need to maintain their object identity across forward passes. If not provided, will be
+ inferred from the module's `_skip_keys` attribute if it exists.
Example:
```python
@@ -403,93 +559,124 @@ def apply_group_offloading(
```
"""
+ onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
+ offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
+ offload_type = GroupOffloadingType(offload_type)
+
stream = None
if use_stream:
if torch.cuda.is_available():
stream = torch.cuda.Stream()
+ elif hasattr(torch, "xpu") and torch.xpu.is_available():
+ stream = torch.Stream()
else:
- raise ValueError("Using streams for data transfer requires a CUDA device.")
+ raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
+
+ if not use_stream and record_stream:
+ raise ValueError("`record_stream` cannot be True when `use_stream=False`.")
+ if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None:
+ raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.")
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
- if offload_type == "block_level":
- if num_blocks_per_group is None:
- raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
+ if block_modules is None:
+ block_modules = getattr(module, "_group_offload_block_modules", None)
- _apply_group_offloading_block_level(
- module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
- )
- elif offload_type == "leaf_level":
- _apply_group_offloading_leaf_level(
- module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
- )
+ if exclude_kwargs is None:
+ exclude_kwargs = getattr(module, "_skip_keys", None)
+
+ config = GroupOffloadingConfig(
+ onload_device=onload_device,
+ offload_device=offload_device,
+ offload_type=offload_type,
+ num_blocks_per_group=num_blocks_per_group,
+ non_blocking=non_blocking,
+ stream=stream,
+ record_stream=record_stream,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ offload_to_disk_path=offload_to_disk_path,
+ block_modules=block_modules,
+ exclude_kwargs=exclude_kwargs,
+ )
+ _apply_group_offloading(module, config)
+
+
+def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
+ if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
+ _apply_group_offloading_block_level(module, config)
+ elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:
+ _apply_group_offloading_leaf_level(module, config)
else:
- raise ValueError(f"Unsupported offload_type: {offload_type}")
+ assert False
-def _apply_group_offloading_block_level(
- module: torch.nn.Module,
- num_blocks_per_group: int,
- offload_device: torch.device,
- onload_device: torch.device,
- non_blocking: bool,
- stream: Optional[torch.cuda.Stream] = None,
- low_cpu_mem_usage: bool = False,
-) -> None:
+def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
r"""
- This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
- the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
+ This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks, and explicitly
+ defined block modules. In comparison to the "leaf_level" offloading, which is more fine-grained, this offloading is
+ done at the top-level blocks and modules specified in block_modules.
- Args:
- module (`torch.nn.Module`):
- The module to which group offloading is applied.
- offload_device (`torch.device`):
- The device to which the group of modules are offloaded. This should typically be the CPU.
- onload_device (`torch.device`):
- The device to which the group of modules are onloaded.
- non_blocking (`bool`):
- If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
- and data transfer.
- stream (`torch.cuda.Stream`, *optional*):
- If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
- for overlapping computation and data transfer.
+ When block_modules is provided, only those modules will be treated as blocks for offloading. For each specified
+ module, recursively apply block offloading to it.
"""
+ if config.stream is not None and config.num_blocks_per_group != 1:
+ logger.warning(
+ f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
+ )
+ config.num_blocks_per_group = 1
+
+ block_modules = set(config.block_modules) if config.block_modules is not None else set()
- # Create module groups for ModuleList and Sequential blocks
+ # Create module groups for ModuleList and Sequential blocks, and explicitly defined block modules
modules_with_group_offloading = set()
unmatched_modules = []
matched_module_groups = []
+
for name, submodule in module.named_children():
- if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
- unmatched_modules.append((name, submodule))
+ # Check if this is an explicitly defined block module
+ if name in block_modules:
+ # Track submodule using a prefix to avoid filename collisions during disk offload.
+ # Without this, submodules sharing the same model class would be assigned identical
+ # filenames (derived from the class name).
+ prefix = f"{config.module_prefix}{name}." if config.module_prefix else f"{name}."
+ submodule_config = replace(config, module_prefix=prefix)
+
+ _apply_group_offloading_block_level(submodule, submodule_config)
modules_with_group_offloading.add(name)
- continue
- for i in range(0, len(submodule), num_blocks_per_group):
- current_modules = submodule[i : i + num_blocks_per_group]
- group = ModuleGroup(
- modules=current_modules,
- offload_device=offload_device,
- onload_device=onload_device,
- offload_leader=current_modules[-1],
- onload_leader=current_modules[0],
- non_blocking=non_blocking,
- stream=stream,
- low_cpu_mem_usage=low_cpu_mem_usage,
- onload_self=stream is None,
- )
- matched_module_groups.append(group)
- for j in range(i, i + len(current_modules)):
- modules_with_group_offloading.add(f"{name}.{j}")
+ elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+ # Handle ModuleList and Sequential blocks as before
+ for i in range(0, len(submodule), config.num_blocks_per_group):
+ current_modules = list(submodule[i : i + config.num_blocks_per_group])
+ if len(current_modules) == 0:
+ continue
+
+ group_id = f"{config.module_prefix}{name}_{i}_{i + len(current_modules) - 1}"
+ group = ModuleGroup(
+ modules=current_modules,
+ offload_device=config.offload_device,
+ onload_device=config.onload_device,
+ offload_to_disk_path=config.offload_to_disk_path,
+ offload_leader=current_modules[-1],
+ onload_leader=current_modules[0],
+ non_blocking=config.non_blocking,
+ stream=config.stream,
+ record_stream=config.record_stream,
+ low_cpu_mem_usage=config.low_cpu_mem_usage,
+ onload_self=True,
+ group_id=group_id,
+ )
+ matched_module_groups.append(group)
+ for j in range(i, i + len(current_modules)):
+ modules_with_group_offloading.add(f"{name}.{j}")
+ else:
+ # This is an unmatched module
+ unmatched_modules.append((name, submodule))
# Apply group offloading hooks to the module groups
for i, group in enumerate(matched_module_groups):
- next_group = (
- matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None
- )
-
for group_module in group.modules:
- _apply_group_offloading_hook(group_module, group, next_group)
+ _apply_group_offloading_hook(group_module, group, config=config)
# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
# when the forward pass of this module is called. This is because the top-level module is not
@@ -499,71 +686,58 @@ def _apply_group_offloading_block_level(
parameters = [param for _, param in parameters]
buffers = [buffer for _, buffer in buffers]
- # Create a group for the unmatched submodules of the top-level module so that they are on the correct
- # device when the forward pass is called.
+ # Create a group for the remaining unmatched submodules of the top-level
+ # module so that they are on the correct device when the forward pass is called.
unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
- unmatched_group = ModuleGroup(
- modules=unmatched_modules,
- offload_device=offload_device,
- onload_device=onload_device,
- offload_leader=module,
- onload_leader=module,
- parameters=parameters,
- buffers=buffers,
- non_blocking=False,
- stream=None,
- onload_self=True,
- )
- next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
- _apply_group_offloading_hook(module, unmatched_group, next_group)
+ if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
+ unmatched_group = ModuleGroup(
+ modules=unmatched_modules,
+ offload_device=config.offload_device,
+ onload_device=config.onload_device,
+ offload_to_disk_path=config.offload_to_disk_path,
+ offload_leader=module,
+ onload_leader=module,
+ parameters=parameters,
+ buffers=buffers,
+ non_blocking=False,
+ stream=None,
+ record_stream=False,
+ onload_self=True,
+ group_id=f"{config.module_prefix}{module.__class__.__name__}_unmatched_group",
+ )
+ if config.stream is None:
+ _apply_group_offloading_hook(module, unmatched_group, config=config)
+ else:
+ _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
-def _apply_group_offloading_leaf_level(
- module: torch.nn.Module,
- offload_device: torch.device,
- onload_device: torch.device,
- non_blocking: bool,
- stream: Optional[torch.cuda.Stream] = None,
- low_cpu_mem_usage: bool = False,
-) -> None:
+def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
r"""
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
requirements. However, it can be slower compared to other offloading methods due to the excessive number of device
synchronizations. When using devices that support streams to overlap data transfer and computation, this method can
reduce memory usage without any performance degradation.
-
- Args:
- module (`torch.nn.Module`):
- The module to which group offloading is applied.
- offload_device (`torch.device`):
- The device to which the group of modules are offloaded. This should typically be the CPU.
- onload_device (`torch.device`):
- The device to which the group of modules are onloaded.
- non_blocking (`bool`):
- If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
- and data transfer.
- stream (`torch.cuda.Stream`, *optional*):
- If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
- for overlapping computation and data transfer.
"""
-
# Create module groups for leaf modules and apply group offloading hooks
modules_with_group_offloading = set()
for name, submodule in module.named_modules():
- if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
+ if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
continue
group = ModuleGroup(
modules=[submodule],
- offload_device=offload_device,
- onload_device=onload_device,
+ offload_device=config.offload_device,
+ onload_device=config.onload_device,
+ offload_to_disk_path=config.offload_to_disk_path,
offload_leader=submodule,
onload_leader=submodule,
- non_blocking=non_blocking,
- stream=stream,
- low_cpu_mem_usage=low_cpu_mem_usage,
+ non_blocking=config.non_blocking,
+ stream=config.stream,
+ record_stream=config.record_stream,
+ low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
+ group_id=name,
)
- _apply_group_offloading_hook(submodule, group, None)
+ _apply_group_offloading_hook(submodule, group, config=config)
modules_with_group_offloading.add(name)
# Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -594,67 +768,74 @@ def _apply_group_offloading_leaf_level(
parameters = parent_to_parameters.get(name, [])
buffers = parent_to_buffers.get(name, [])
parent_module = module_dict[name]
- assert getattr(parent_module, "_diffusers_hook", None) is None
group = ModuleGroup(
modules=[],
- offload_device=offload_device,
- onload_device=onload_device,
+ offload_device=config.offload_device,
+ onload_device=config.onload_device,
offload_leader=parent_module,
onload_leader=parent_module,
+ offload_to_disk_path=config.offload_to_disk_path,
parameters=parameters,
buffers=buffers,
- non_blocking=non_blocking,
- stream=stream,
- low_cpu_mem_usage=low_cpu_mem_usage,
+ non_blocking=config.non_blocking,
+ stream=config.stream,
+ record_stream=config.record_stream,
+ low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
+ group_id=name,
)
- _apply_group_offloading_hook(parent_module, group, None)
+ _apply_group_offloading_hook(parent_module, group, config=config)
- if stream is not None:
+ if config.stream is not None:
# When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
# and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
# execution order and apply prefetching in the correct order.
unmatched_group = ModuleGroup(
modules=[],
- offload_device=offload_device,
- onload_device=onload_device,
+ offload_device=config.offload_device,
+ onload_device=config.onload_device,
+ offload_to_disk_path=config.offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=None,
buffers=None,
non_blocking=False,
stream=None,
- low_cpu_mem_usage=low_cpu_mem_usage,
+ record_stream=False,
+ low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
+ group_id=_GROUP_ID_LAZY_LEAF,
)
- _apply_lazy_group_offloading_hook(module, unmatched_group, None)
+ _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
def _apply_group_offloading_hook(
module: torch.nn.Module,
group: ModuleGroup,
- next_group: Optional[ModuleGroup] = None,
+ *,
+ config: GroupOffloadingConfig,
) -> None:
registry = HookRegistry.check_if_exists_or_initialize(module)
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
if registry.get_hook(_GROUP_OFFLOADING) is None:
- hook = GroupOffloadingHook(group, next_group)
+ hook = GroupOffloadingHook(group, config=config)
registry.register_hook(hook, _GROUP_OFFLOADING)
def _apply_lazy_group_offloading_hook(
module: torch.nn.Module,
group: ModuleGroup,
- next_group: Optional[ModuleGroup] = None,
+ *,
+ config: GroupOffloadingConfig,
) -> None:
registry = HookRegistry.check_if_exists_or_initialize(module)
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
if registry.get_hook(_GROUP_OFFLOADING) is None:
- hook = GroupOffloadingHook(group, next_group)
+ hook = GroupOffloadingHook(group, config=config)
registry.register_hook(hook, _GROUP_OFFLOADING)
lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
@@ -721,15 +902,54 @@ def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn
)
-def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
+def _get_top_level_group_offload_hook(module: torch.nn.Module) -> Optional[GroupOffloadingHook]:
for submodule in module.modules():
- if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
- return True
- return False
+ if hasattr(submodule, "_diffusers_hook"):
+ group_offloading_hook = submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING)
+ if group_offloading_hook is not None:
+ return group_offloading_hook
+ return None
+
+
+def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
+ top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+ return top_level_group_offload_hook is not None
def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
- for submodule in module.modules():
- if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
- return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
+ top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+ if top_level_group_offload_hook is not None:
+ return top_level_group_offload_hook.config.onload_device
raise ValueError("Group offloading is not enabled for the provided module.")
+
+
+def _compute_group_hash(group_id):
+ hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
+ # first 16 characters for a reasonably short but unique name
+ return hashed_id[:16]
+
+
+def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
+ r"""
+ Removes the group offloading hook from the module and re-applies it. This is useful when the module has been
+ modified in-place and the group offloading hook references-to-tensors needs to be updated. The in-place
+ modification can happen in a number of ways, for example, fusing QKV or unloading/loading LoRAs on-the-fly.
+
+ In this implementation, we make an assumption that group offloading has only been applied at the top-level module,
+ and therefore all submodules have the same onload and offload devices. If this assumption is not true, say in the
+ case where user has applied group offloading at multiple levels, this function will not work as expected.
+
+ There is some performance penalty associated with doing this when non-default streams are used, because we need to
+ retrace the execution order of the layers with `LazyPrefetchGroupOffloadingHook`.
+ """
+ top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+
+ if top_level_group_offload_hook is None:
+ return
+
+ registry = HookRegistry.check_if_exists_or_initialize(module)
+ registry.remove_hook(_GROUP_OFFLOADING, recurse=True)
+ registry.remove_hook(_LAYER_EXECUTION_TRACKER, recurse=True)
+ registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=True)
+
+ _apply_group_offloading(module, top_level_group_offload_hook.config)
diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py
index 3b2e4ed91c2f..6e097e5882a0 100644
--- a/src/diffusers/hooks/hooks.py
+++ b/src/diffusers/hooks/hooks.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,11 +18,44 @@
import torch
from ..utils.logging import get_logger
+from ..utils.torch_utils import unwrap_module
logger = get_logger(__name__) # pylint: disable=invalid-name
+class BaseState:
+ def reset(self, *args, **kwargs) -> None:
+ raise NotImplementedError(
+ "BaseState::reset is not implemented. Please implement this method in the derived class."
+ )
+
+
+class StateManager:
+ def __init__(self, state_cls: BaseState, init_args=None, init_kwargs=None):
+ self._state_cls = state_cls
+ self._init_args = init_args if init_args is not None else ()
+ self._init_kwargs = init_kwargs if init_kwargs is not None else {}
+ self._state_cache = {}
+ self._current_context = None
+
+ def get_state(self):
+ if self._current_context is None:
+ raise ValueError("No context is set. Please set a context before retrieving the state.")
+ if self._current_context not in self._state_cache.keys():
+ self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs)
+ return self._state_cache[self._current_context]
+
+ def set_context(self, name: str) -> None:
+ self._current_context = name
+
+ def reset(self, *args, **kwargs) -> None:
+ for name, state in list(self._state_cache.items()):
+ state.reset(*args, **kwargs)
+ self._state_cache.pop(name)
+ self._current_context = None
+
+
class ModelHook:
r"""
A hook that contains callbacks to be executed just before and after the forward method of a model.
@@ -45,7 +78,7 @@ def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
r"""
- Hook that is executed when a model is deinitalized.
+ Hook that is executed when a model is deinitialized.
Args:
module (`torch.nn.Module`):
@@ -99,6 +132,14 @@ def reset_state(self, module: torch.nn.Module):
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
return module
+ def _set_context(self, module: torch.nn.Module, name: str) -> None:
+ # Iterate over all attributes of the hook to see if any of them have the type `StateManager`. If so, call `set_context` on them.
+ for attr_name in dir(self):
+ attr = getattr(self, attr_name)
+ if isinstance(attr, StateManager):
+ attr.set_context(name)
+ return module
+
class HookFunctionReference:
def __init__(self) -> None:
@@ -211,9 +252,10 @@ def reset_stateful_hooks(self, recurse: bool = True) -> None:
hook.reset_state(self._module_ref)
if recurse:
- for module_name, module in self._module_ref.named_modules():
+ for module_name, module in unwrap_module(self._module_ref).named_modules():
if module_name == "":
continue
+ module = unwrap_module(module)
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook.reset_stateful_hooks(recurse=False)
@@ -223,6 +265,19 @@ def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry
module._diffusers_hook = cls(module)
return module._diffusers_hook
+ def _set_context(self, name: Optional[str] = None) -> None:
+ for hook_name in reversed(self._hook_order):
+ hook = self.hooks[hook_name]
+ if hook._is_stateful:
+ hook._set_context(self._module_ref, name)
+
+ for module_name, module in unwrap_module(self._module_ref).named_modules():
+ if module_name == "":
+ continue
+ module = unwrap_module(module)
+ if hasattr(module, "_diffusers_hook"):
+ module._diffusers_hook._set_context(name)
+
def __repr__(self) -> str:
registry_repr = ""
for i, hook_name in enumerate(self._hook_order):
diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py
new file mode 100644
index 000000000000..0ce02e987d09
--- /dev/null
+++ b/src/diffusers/hooks/layer_skip.py
@@ -0,0 +1,263 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import asdict, dataclass
+from typing import Callable, List, Optional
+
+import torch
+
+from ..utils import get_logger
+from ..utils.torch_utils import unwrap_module
+from ._common import (
+ _ALL_TRANSFORMER_BLOCK_IDENTIFIERS,
+ _ATTENTION_CLASSES,
+ _FEEDFORWARD_CLASSES,
+ _get_submodule_from_fqn,
+)
+from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
+from .hooks import HookRegistry, ModelHook
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+_LAYER_SKIP_HOOK = "layer_skip_hook"
+
+
+# Aryan/YiYi TODO: we need to make guider class a config mixin so I think this is not needed
+# either remove or make it serializable
+@dataclass
+class LayerSkipConfig:
+ r"""
+ Configuration for skipping internal transformer blocks when executing a transformer model.
+
+ Args:
+ indices (`List[int]`):
+ The indices of the layer to skip. This is typically the first layer in the transformer block.
+ fqn (`str`, defaults to `"auto"`):
+ The fully qualified name identifying the stack of transformer blocks. Typically, this is
+ `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
+ For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must
+ provide the correct fqn.
+ skip_attention (`bool`, defaults to `True`):
+ Whether to skip attention blocks.
+ skip_ff (`bool`, defaults to `True`):
+ Whether to skip feed-forward blocks.
+ skip_attention_scores (`bool`, defaults to `False`):
+ Whether to skip attention score computation in the attention blocks. This is equivalent to using `value`
+ projections as the output of scaled dot product attention.
+ dropout (`float`, defaults to `1.0`):
+ The dropout probability for dropping the outputs of the skipped layers. By default, this is set to `1.0`,
+ meaning that the outputs of the skipped layers are completely ignored. If set to `0.0`, the outputs of the
+ skipped layers are fully retained, which is equivalent to not skipping any layers.
+ """
+
+ indices: List[int]
+ fqn: str = "auto"
+ skip_attention: bool = True
+ skip_attention_scores: bool = False
+ skip_ff: bool = True
+ dropout: float = 1.0
+
+ def __post_init__(self):
+ if not (0 <= self.dropout <= 1):
+ raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.")
+ if not math.isclose(self.dropout, 1.0) and self.skip_attention_scores:
+ raise ValueError(
+ "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
+ )
+
+ def to_dict(self):
+ return asdict(self)
+
+ @staticmethod
+ def from_dict(data: dict) -> "LayerSkipConfig":
+ return LayerSkipConfig(**data)
+
+
+class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
+ def __torch_function__(self, func, types, args=(), kwargs=None):
+ if kwargs is None:
+ kwargs = {}
+ if func is torch.nn.functional.scaled_dot_product_attention:
+ query = kwargs.get("query", None)
+ key = kwargs.get("key", None)
+ value = kwargs.get("value", None)
+ query = query if query is not None else args[0]
+ key = key if key is not None else args[1]
+ value = value if value is not None else args[2]
+ # If the Q sequence length does not match KV sequence length, methods like
+ # Perturbed Attention Guidance cannot be used (because the caller expects
+ # the same sequence length as Q, but if we return V here, it will not match).
+ # When Q.shape[2] != V.shape[2], PAG will essentially not be applied and
+ # the overall effect would that be of normal CFG with a scale of (guidance_scale + perturbed_guidance_scale).
+ if query.shape[2] == value.shape[2]:
+ return value
+ return func(*args, **kwargs)
+
+
+class AttentionProcessorSkipHook(ModelHook):
+ def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0):
+ self.skip_processor_output_fn = skip_processor_output_fn
+ self.skip_attention_scores = skip_attention_scores
+ self.dropout = dropout
+
+ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+ if self.skip_attention_scores:
+ if not math.isclose(self.dropout, 1.0):
+ raise ValueError(
+ "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
+ )
+ with AttentionScoreSkipFunctionMode():
+ output = self.fn_ref.original_forward(*args, **kwargs)
+ else:
+ if math.isclose(self.dropout, 1.0):
+ output = self.skip_processor_output_fn(module, *args, **kwargs)
+ else:
+ output = self.fn_ref.original_forward(*args, **kwargs)
+ output = torch.nn.functional.dropout(output, p=self.dropout)
+ return output
+
+
+class FeedForwardSkipHook(ModelHook):
+ def __init__(self, dropout: float):
+ super().__init__()
+ self.dropout = dropout
+
+ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+ if math.isclose(self.dropout, 1.0):
+ output = kwargs.get("hidden_states", None)
+ if output is None:
+ output = kwargs.get("x", None)
+ if output is None and len(args) > 0:
+ output = args[0]
+ else:
+ output = self.fn_ref.original_forward(*args, **kwargs)
+ output = torch.nn.functional.dropout(output, p=self.dropout)
+ return output
+
+
+class TransformerBlockSkipHook(ModelHook):
+ def __init__(self, dropout: float):
+ super().__init__()
+ self.dropout = dropout
+
+ def initialize_hook(self, module):
+ self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
+ return module
+
+ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+ if math.isclose(self.dropout, 1.0):
+ original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
+ if self._metadata.return_encoder_hidden_states_index is None:
+ output = original_hidden_states
+ else:
+ original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
+ "encoder_hidden_states", args, kwargs
+ )
+ output = (original_hidden_states, original_encoder_hidden_states)
+ else:
+ output = self.fn_ref.original_forward(*args, **kwargs)
+ output = torch.nn.functional.dropout(output, p=self.dropout)
+ return output
+
+
+def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
+ r"""
+ Apply layer skipping to internal layers of a transformer.
+
+ Args:
+ module (`torch.nn.Module`):
+ The transformer model to which the layer skip hook should be applied.
+ config (`LayerSkipConfig`):
+ The configuration for the layer skip hook.
+
+ Example:
+
+ ```python
+ >>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig
+
+ >>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+ >>> config = LayerSkipConfig(layer_index=[10, 20], fqn="transformer_blocks")
+ >>> apply_layer_skip_hook(transformer, config)
+ ```
+ """
+ _apply_layer_skip_hook(module, config)
+
+
+def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
+ name = name or _LAYER_SKIP_HOOK
+
+ if config.skip_attention and config.skip_attention_scores:
+ raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.")
+ if not math.isclose(config.dropout, 1.0) and config.skip_attention_scores:
+ raise ValueError(
+ "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
+ )
+
+ if config.fqn == "auto":
+ for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
+ if hasattr(module, identifier):
+ config.fqn = identifier
+ break
+ else:
+ raise ValueError(
+ "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
+ "`fqn` (fully qualified name) that identifies a stack of transformer blocks."
+ )
+
+ transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
+ if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList):
+ raise ValueError(
+ f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify "
+ f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks."
+ )
+ if len(config.indices) == 0:
+ raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")
+
+ blocks_found = False
+ for i, block in enumerate(transformer_blocks):
+ if i not in config.indices:
+ continue
+
+ blocks_found = True
+
+ if config.skip_attention and config.skip_ff:
+ logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
+ registry = HookRegistry.check_if_exists_or_initialize(block)
+ hook = TransformerBlockSkipHook(config.dropout)
+ registry.register_hook(hook, name)
+
+ elif config.skip_attention or config.skip_attention_scores:
+ for submodule_name, submodule in block.named_modules():
+ if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
+ logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'")
+ output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn
+ registry = HookRegistry.check_if_exists_or_initialize(submodule)
+ hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout)
+ registry.register_hook(hook, name)
+
+ if config.skip_ff:
+ for submodule_name, submodule in block.named_modules():
+ if isinstance(submodule, _FEEDFORWARD_CLASSES):
+ logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
+ registry = HookRegistry.check_if_exists_or_initialize(submodule)
+ hook = FeedForwardSkipHook(config.dropout)
+ registry.register_hook(hook, name)
+
+ if not blocks_found:
+ raise ValueError(
+ f"Could not find any transformer blocks matching the provided indices {config.indices} and "
+ f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
+ )
diff --git a/src/diffusers/hooks/layerwise_casting.py b/src/diffusers/hooks/layerwise_casting.py
index 6f2cfdc3485a..a036ad37dc2f 100644
--- a/src/diffusers/hooks/layerwise_casting.py
+++ b/src/diffusers/hooks/layerwise_casting.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@
import torch
from ..utils import get_logger, is_peft_available, is_peft_version
+from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook
@@ -27,12 +28,6 @@
# fmt: off
_LAYERWISE_CASTING_HOOK = "layerwise_casting"
_PEFT_AUTOCAST_DISABLE_HOOK = "peft_autocast_disable"
-SUPPORTED_PYTORCH_LAYERS = (
- torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
- torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
- torch.nn.Linear,
-)
-
DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
# fmt: on
@@ -62,7 +57,7 @@ def initialize_hook(self, module: torch.nn.Module):
def deinitalize_hook(self, module: torch.nn.Module):
raise NotImplementedError(
- "LayerwiseCastingHook does not support deinitalization. A model once enabled with layerwise casting will "
+ "LayerwiseCastingHook does not support deinitialization. A model once enabled with layerwise casting will "
"have casted its weights to a lower precision dtype for storage. Casting this back to the original dtype "
"will lead to precision loss, which might have an impact on the model's generation quality. The model should "
"be re-initialized and loaded in the original dtype."
@@ -90,7 +85,7 @@ class PeftInputAutocastDisableHook(ModelHook):
that the inputs are casted to the computation dtype correctly always. However, there are two goals we are
hoping to achieve:
1. Making forward implementations independent of device/dtype casting operations as much as possible.
- 2. Peforming inference without losing information from casting to different precisions. With the current
+ 2. Performing inference without losing information from casting to different precisions. With the current
PEFT implementation (as linked in the reference above), and assuming running layerwise casting inference
with storage_dtype=torch.float8_e4m3fn and compute_dtype=torch.bfloat16, inputs are cast to
torch.float8_e4m3fn in the lora layer. We will then upcast back to torch.bfloat16 when we continue the
@@ -186,7 +181,7 @@ def _apply_layerwise_casting(
logger.debug(f'Skipping layerwise casting for layer "{_prefix}"')
return
- if isinstance(module, SUPPORTED_PYTORCH_LAYERS):
+ if isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
logger.debug(f'Applying layerwise casting to layer "{_prefix}"')
apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
return
diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py
index 5d50f4b816c1..12d6aa0616e9 100644
--- a/src/diffusers/hooks/pyramid_attention_broadcast.py
+++ b/src/diffusers/hooks/pyramid_attention_broadcast.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,8 +18,15 @@
import torch
+from ..models.attention import AttentionModuleMixin
from ..models.attention_processor import Attention, MochiAttention
from ..utils import logging
+from ._common import (
+ _ATTENTION_CLASSES,
+ _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+ _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+ _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+)
from .hooks import HookRegistry, ModelHook
@@ -27,10 +34,6 @@
_PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
-_ATTENTION_CLASSES = (Attention, MochiAttention)
-_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
-_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
-_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
@dataclass
@@ -42,15 +45,15 @@ class PyramidAttentionBroadcastConfig:
spatial_attention_block_skip_range (`int`, *optional*, defaults to `None`):
The number of times a specific spatial attention broadcast is skipped before computing the attention states
to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
- old attention states will be re-used) before computing the new attention states again.
+ old attention states will be reused) before computing the new attention states again.
temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
The number of times a specific temporal attention broadcast is skipped before computing the attention
states to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times
- (i.e., old attention states will be re-used) before computing the new attention states again.
+ (i.e., old attention states will be reused) before computing the new attention states again.
cross_attention_block_skip_range (`int`, *optional*, defaults to `None`):
The number of times a specific cross-attention broadcast is skipped before computing the attention states
to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
- old attention states will be re-used) before computing the new attention states again.
+ old attention states will be reused) before computing the new attention states again.
spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the spatial attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
@@ -60,11 +63,11 @@ class PyramidAttentionBroadcastConfig:
cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the cross-attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
- spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
+ spatial_attention_block_identifiers (`Tuple[str, ...]`):
The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
- temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
+ temporal_attention_block_identifiers (`Tuple[str, ...]`):
The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
- cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
+ cross_attention_block_identifiers (`Tuple[str, ...]`):
The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
"""
@@ -76,9 +79,9 @@ class PyramidAttentionBroadcastConfig:
temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
- spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
- temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
- cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
+ spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
+ temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
+ cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
current_timestep_callback: Callable[[], int] = None
@@ -227,7 +230,7 @@ def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAt
config.spatial_attention_block_skip_range = 2
for name, submodule in module.named_modules():
- if not isinstance(submodule, _ATTENTION_CLASSES):
+ if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
# PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
# cannot be applied to this layer. For custom layers, users can extend this functionality and implement
# their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
@@ -302,7 +305,7 @@ def _apply_pyramid_attention_broadcast_hook(
block_skip_range (`int`):
The number of times a specific attention broadcast is skipped before computing the attention states to
re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., old
- attention states will be re-used) before computing the new attention states again.
+ attention states will be reused) before computing the new attention states again.
current_timestep_callback (`Callable[[], int]`):
A callback function that returns the current inference timestep.
"""
diff --git a/src/diffusers/hooks/smoothed_energy_guidance_utils.py b/src/diffusers/hooks/smoothed_energy_guidance_utils.py
new file mode 100644
index 000000000000..622f60764762
--- /dev/null
+++ b/src/diffusers/hooks/smoothed_energy_guidance_utils.py
@@ -0,0 +1,167 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import asdict, dataclass
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+
+from ..utils import get_logger
+from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _get_submodule_from_fqn
+from .hooks import HookRegistry, ModelHook
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+_SMOOTHED_ENERGY_GUIDANCE_HOOK = "smoothed_energy_guidance_hook"
+
+
+@dataclass
+class SmoothedEnergyGuidanceConfig:
+ r"""
+ Configuration for skipping internal transformer blocks when executing a transformer model.
+
+ Args:
+ indices (`List[int]`):
+ The indices of the layer to skip. This is typically the first layer in the transformer block.
+ fqn (`str`, defaults to `"auto"`):
+ The fully qualified name identifying the stack of transformer blocks. Typically, this is
+ `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
+ For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must
+ provide the correct fqn.
+ _query_proj_identifiers (`List[str]`, defaults to `None`):
+ The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`. If
+ `None`, `to_q` is used by default.
+ """
+
+ indices: List[int]
+ fqn: str = "auto"
+ _query_proj_identifiers: List[str] = None
+
+ def to_dict(self):
+ return asdict(self)
+
+ @staticmethod
+ def from_dict(data: dict) -> "SmoothedEnergyGuidanceConfig":
+ return SmoothedEnergyGuidanceConfig(**data)
+
+
+class SmoothedEnergyGuidanceHook(ModelHook):
+ def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None:
+ super().__init__()
+ self.blur_sigma = blur_sigma
+ self.blur_threshold_inf = blur_threshold_inf
+
+ def post_forward(self, module: torch.nn.Module, output: torch.Tensor) -> torch.Tensor:
+ # Copied from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L172C31-L172C102
+ kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2
+ smoothed_output = _gaussian_blur_2d(output, kernel_size, self.blur_sigma, self.blur_threshold_inf)
+ return smoothed_output
+
+
+def _apply_smoothed_energy_guidance_hook(
+ module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None
+) -> None:
+ name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK
+
+ if config.fqn == "auto":
+ for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
+ if hasattr(module, identifier):
+ config.fqn = identifier
+ break
+ else:
+ raise ValueError(
+ "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
+ "`fqn` (fully qualified name) that identifies a stack of transformer blocks."
+ )
+
+ if config._query_proj_identifiers is None:
+ config._query_proj_identifiers = ["to_q"]
+
+ transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
+ blocks_found = False
+ for i, block in enumerate(transformer_blocks):
+ if i not in config.indices:
+ continue
+
+ blocks_found = True
+
+ for submodule_name, submodule in block.named_modules():
+ if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention:
+ continue
+ for identifier in config._query_proj_identifiers:
+ query_proj = getattr(submodule, identifier, None)
+ if query_proj is None or not isinstance(query_proj, torch.nn.Linear):
+ continue
+ logger.debug(
+ f"Registering smoothed energy guidance hook on {config.fqn}.{i}.{submodule_name}.{identifier}"
+ )
+ registry = HookRegistry.check_if_exists_or_initialize(query_proj)
+ hook = SmoothedEnergyGuidanceHook(blur_sigma)
+ registry.register_hook(hook, name)
+
+ if not blocks_found:
+ raise ValueError(
+ f"Could not find any transformer blocks matching the provided indices {config.indices} and "
+ f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
+ )
+
+
+# Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71
+def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor:
+ """
+ This implementation assumes that the input query is for visual (image/videos) tokens to apply the 2D gaussian blur.
+ However, some models use joint text-visual token attention for which this may not be suitable. Additionally, this
+ implementation also assumes that the visual tokens come from a square image/video. In practice, despite these
+ assumptions, applying the 2D square gaussian blur on the query projections generates reasonable results for
+ Smoothed Energy Guidance.
+
+ SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the
+ future without warning or guarantee of reproducibility.
+ """
+ assert query.ndim == 3
+
+ is_inf = sigma > sigma_threshold_inf
+ batch_size, seq_len, embed_dim = query.shape
+
+ seq_len_sqrt = int(math.sqrt(seq_len))
+ num_square_tokens = seq_len_sqrt * seq_len_sqrt
+ query_slice = query[:, :num_square_tokens, :]
+ query_slice = query_slice.permute(0, 2, 1)
+ query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt)
+
+ if is_inf:
+ kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1))
+ kernel_size_half = (kernel_size - 1) / 2
+
+ x = torch.linspace(-kernel_size_half, kernel_size_half, steps=kernel_size)
+ pdf = torch.exp(-0.5 * (x / sigma).pow(2))
+ kernel1d = pdf / pdf.sum()
+ kernel1d = kernel1d.to(query)
+ kernel2d = torch.matmul(kernel1d[:, None], kernel1d[None, :])
+ kernel2d = kernel2d.expand(embed_dim, 1, kernel2d.shape[0], kernel2d.shape[1])
+
+ padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
+ query_slice = F.pad(query_slice, padding, mode="reflect")
+ query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim)
+ else:
+ query_slice[:] = query_slice.mean(dim=(-2, -1), keepdim=True)
+
+ query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens)
+ query_slice = query_slice.permute(0, 2, 1)
+ query[:, :num_square_tokens, :] = query_slice.clone()
+
+ return query
diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py
new file mode 100644
index 000000000000..7cad9f4fa161
--- /dev/null
+++ b/src/diffusers/hooks/taylorseer_cache.py
@@ -0,0 +1,346 @@
+import math
+import re
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from ..utils import logging
+from .hooks import HookRegistry, ModelHook, StateManager
+
+
+logger = logging.get_logger(__name__)
+_TAYLORSEER_CACHE_HOOK = "taylorseer_cache"
+_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
+ "^blocks.*attn",
+ "^transformer_blocks.*attn",
+ "^single_transformer_blocks.*attn",
+)
+_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
+_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
+_BLOCK_IDENTIFIERS = ("^[^.]*block[^.]*\\.[^.]+$",)
+_PROJ_OUT_IDENTIFIERS = ("^proj_out$",)
+
+
+@dataclass
+class TaylorSeerCacheConfig:
+ """
+ Configuration for TaylorSeer cache. See: https://huggingface.co/papers/2503.06923
+
+ Attributes:
+ cache_interval (`int`, defaults to `5`):
+ The interval between full computation steps. After a full computation, the cached (predicted) outputs are
+ reused for this many subsequent denoising steps before refreshing with a new full forward pass.
+
+ disable_cache_before_step (`int`, defaults to `3`):
+ The denoising step index before which caching is disabled, meaning full computation is performed for the
+ initial steps (0 to disable_cache_before_step - 1) to gather data for Taylor series approximations. During
+ these steps, Taylor factors are updated, but caching/predictions are not applied. Caching begins at this
+ step.
+
+ disable_cache_after_step (`int`, *optional*, defaults to `None`):
+ The denoising step index after which caching is disabled. If set, for steps >= this value, all modules run
+ full computations without predictions or state updates, ensuring accuracy in later stages if needed.
+
+ max_order (`int`, defaults to `1`):
+ The highest order in the Taylor series expansion for approximating module outputs. Higher orders provide
+ better approximations but increase computation and memory usage.
+
+ taylor_factors_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
+ Data type used for storing and computing Taylor series factors. Lower precision reduces memory but may
+ affect stability; higher precision improves accuracy at the cost of more memory.
+
+ skip_predict_identifiers (`List[str]`, *optional*, defaults to `None`):
+ Regex patterns (using `re.fullmatch`) for module names to place as "skip" in "cache" mode. In this mode,
+ the module computes fully during initial or refresh steps but returns a zero tensor (matching recorded
+ shape) during prediction steps to skip computation cheaply.
+
+ cache_identifiers (`List[str]`, *optional*, defaults to `None`):
+ Regex patterns (using `re.fullmatch`) for module names to place in Taylor-series caching mode, where
+ outputs are approximated and cached for reuse.
+
+ use_lite_mode (`bool`, *optional*, defaults to `False`):
+ Enables a lightweight TaylorSeer variant that minimizes memory usage by applying predefined patterns for
+ skipping and caching (e.g., skipping blocks and caching projections). This overrides any custom
+ `inactive_identifiers` or `active_identifiers`.
+
+ Notes:
+ - Patterns are matched using `re.fullmatch` on the module name.
+ - If `skip_predict_identifiers` or `cache_identifiers` are provided, only matching modules are hooked.
+ - If neither is provided, all attention-like modules are hooked by default.
+
+ Example of inactive and active usage:
+
+ ```py
+ def forward(x):
+ x = self.module1(x) # inactive module: returns zeros tensor based on shape recorded during full compute
+ x = self.module2(x) # active module: caches output here, avoiding recomputation of prior steps
+ return x
+ ```
+ """
+
+ cache_interval: int = 5
+ disable_cache_before_step: int = 3
+ disable_cache_after_step: Optional[int] = None
+ max_order: int = 1
+ taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16
+ skip_predict_identifiers: Optional[List[str]] = None
+ cache_identifiers: Optional[List[str]] = None
+ use_lite_mode: bool = False
+
+ def __repr__(self) -> str:
+ return (
+ "TaylorSeerCacheConfig("
+ f"cache_interval={self.cache_interval}, "
+ f"disable_cache_before_step={self.disable_cache_before_step}, "
+ f"disable_cache_after_step={self.disable_cache_after_step}, "
+ f"max_order={self.max_order}, "
+ f"taylor_factors_dtype={self.taylor_factors_dtype}, "
+ f"skip_predict_identifiers={self.skip_predict_identifiers}, "
+ f"cache_identifiers={self.cache_identifiers}, "
+ f"use_lite_mode={self.use_lite_mode})"
+ )
+
+
+class TaylorSeerState:
+ def __init__(
+ self,
+ taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16,
+ max_order: int = 1,
+ is_inactive: bool = False,
+ ):
+ self.taylor_factors_dtype = taylor_factors_dtype
+ self.max_order = max_order
+ self.is_inactive = is_inactive
+
+ self.module_dtypes: Tuple[torch.dtype, ...] = ()
+ self.last_update_step: Optional[int] = None
+ self.taylor_factors: Dict[int, Dict[int, torch.Tensor]] = {}
+ self.inactive_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None
+ self.device: Optional[torch.device] = None
+ self.current_step: int = -1
+
+ def reset(self) -> None:
+ self.current_step = -1
+ self.last_update_step = None
+ self.taylor_factors = {}
+ self.inactive_shapes = None
+ self.device = None
+
+ def update(
+ self,
+ outputs: Tuple[torch.Tensor, ...],
+ ) -> None:
+ self.module_dtypes = tuple(output.dtype for output in outputs)
+ self.device = outputs[0].device
+
+ if self.is_inactive:
+ self.inactive_shapes = tuple(output.shape for output in outputs)
+ else:
+ for i, features in enumerate(outputs):
+ new_factors: Dict[int, torch.Tensor] = {0: features}
+ is_first_update = self.last_update_step is None
+ if not is_first_update:
+ delta_step = self.current_step - self.last_update_step
+ if delta_step == 0:
+ raise ValueError("Delta step cannot be zero for TaylorSeer update.")
+
+ # Recursive divided differences up to max_order
+ prev_factors = self.taylor_factors.get(i, {})
+ for j in range(self.max_order):
+ prev = prev_factors.get(j)
+ if prev is None:
+ break
+ new_factors[j + 1] = (new_factors[j] - prev.to(features.dtype)) / delta_step
+ self.taylor_factors[i] = {
+ order: factor.to(self.taylor_factors_dtype) for order, factor in new_factors.items()
+ }
+
+ self.last_update_step = self.current_step
+
+ @torch.compiler.disable
+ def predict(self) -> List[torch.Tensor]:
+ if self.last_update_step is None:
+ raise ValueError("Cannot predict without prior initialization/update.")
+
+ step_offset = self.current_step - self.last_update_step
+
+ outputs = []
+ if self.is_inactive:
+ if self.inactive_shapes is None:
+ raise ValueError("Inactive shapes not set during prediction.")
+ for i in range(len(self.module_dtypes)):
+ outputs.append(
+ torch.zeros(
+ self.inactive_shapes[i],
+ dtype=self.module_dtypes[i],
+ device=self.device,
+ )
+ )
+ else:
+ if not self.taylor_factors:
+ raise ValueError("Taylor factors empty during prediction.")
+ num_outputs = len(self.taylor_factors)
+ num_orders = len(self.taylor_factors[0])
+ for i in range(num_outputs):
+ output_dtype = self.module_dtypes[i]
+ taylor_factors = self.taylor_factors[i]
+ output = torch.zeros_like(taylor_factors[0], dtype=output_dtype)
+ for order in range(num_orders):
+ coeff = (step_offset**order) / math.factorial(order)
+ factor = taylor_factors[order]
+ output = output + factor.to(output_dtype) * coeff
+ outputs.append(output)
+ return outputs
+
+
+class TaylorSeerCacheHook(ModelHook):
+ _is_stateful = True
+
+ def __init__(
+ self,
+ cache_interval: int,
+ disable_cache_before_step: int,
+ taylor_factors_dtype: torch.dtype,
+ state_manager: StateManager,
+ disable_cache_after_step: Optional[int] = None,
+ ):
+ super().__init__()
+ self.cache_interval = cache_interval
+ self.disable_cache_before_step = disable_cache_before_step
+ self.disable_cache_after_step = disable_cache_after_step
+ self.taylor_factors_dtype = taylor_factors_dtype
+ self.state_manager = state_manager
+
+ def initialize_hook(self, module: torch.nn.Module):
+ return module
+
+ def reset_state(self, module: torch.nn.Module) -> None:
+ """
+ Reset state between sampling runs.
+ """
+ self.state_manager.reset()
+
+ @torch.compiler.disable
+ def _measure_should_compute(self) -> bool:
+ state: TaylorSeerState = self.state_manager.get_state()
+ state.current_step += 1
+ current_step = state.current_step
+ is_warmup_phase = current_step < self.disable_cache_before_step
+ is_compute_interval = (current_step - self.disable_cache_before_step - 1) % self.cache_interval == 0
+ is_cooldown_phase = self.disable_cache_after_step is not None and current_step >= self.disable_cache_after_step
+ should_compute = is_warmup_phase or is_compute_interval or is_cooldown_phase
+ return should_compute, state
+
+ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+ should_compute, state = self._measure_should_compute()
+ if should_compute:
+ outputs = self.fn_ref.original_forward(*args, **kwargs)
+ wrapped_outputs = (outputs,) if isinstance(outputs, torch.Tensor) else outputs
+ state.update(wrapped_outputs)
+ return outputs
+
+ outputs_list = state.predict()
+ return outputs_list[0] if len(outputs_list) == 1 else tuple(outputs_list)
+
+
+def _resolve_patterns(config: TaylorSeerCacheConfig) -> Tuple[List[str], List[str]]:
+ """
+ Resolve effective inactive and active pattern lists from config + templates.
+ """
+
+ inactive_patterns = config.skip_predict_identifiers if config.skip_predict_identifiers is not None else None
+ active_patterns = config.cache_identifiers if config.cache_identifiers is not None else None
+
+ return inactive_patterns or [], active_patterns or []
+
+
+def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig):
+ """
+ Applies the TaylorSeer cache to a given pipeline (typically the transformer / UNet).
+
+ This function hooks selected modules in the model to enable caching or skipping based on the provided
+ configuration, reducing redundant computations in diffusion denoising loops.
+
+ Args:
+ module (torch.nn.Module): The model subtree to apply the hooks to.
+ config (TaylorSeerCacheConfig): Configuration for the cache.
+
+ Example:
+ ```python
+ >>> import torch
+ >>> from diffusers import FluxPipeline, TaylorSeerCacheConfig
+
+ >>> pipe = FluxPipeline.from_pretrained(
+ ... "black-forest-labs/FLUX.1-dev",
+ ... torch_dtype=torch.bfloat16,
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> config = TaylorSeerCacheConfig(
+ ... cache_interval=5,
+ ... max_order=1,
+ ... disable_cache_before_step=3,
+ ... taylor_factors_dtype=torch.float32,
+ ... )
+ >>> pipe.transformer.enable_cache(config)
+ ```
+ """
+ inactive_patterns, active_patterns = _resolve_patterns(config)
+
+ active_patterns = active_patterns or _TRANSFORMER_BLOCK_IDENTIFIERS
+
+ if config.use_lite_mode:
+ logger.info("Using TaylorSeer Lite variant for cache.")
+ active_patterns = _PROJ_OUT_IDENTIFIERS
+ inactive_patterns = _BLOCK_IDENTIFIERS
+ if config.skip_predict_identifiers or config.cache_identifiers:
+ logger.warning("Lite mode overrides user patterns.")
+
+ for name, submodule in module.named_modules():
+ matches_inactive = any(re.fullmatch(pattern, name) for pattern in inactive_patterns)
+ matches_active = any(re.fullmatch(pattern, name) for pattern in active_patterns)
+ if not (matches_inactive or matches_active):
+ continue
+ _apply_taylorseer_cache_hook(
+ module=submodule,
+ config=config,
+ is_inactive=matches_inactive,
+ )
+
+
+def _apply_taylorseer_cache_hook(
+ module: nn.Module,
+ config: TaylorSeerCacheConfig,
+ is_inactive: bool,
+):
+ """
+ Registers the TaylorSeer hook on the specified nn.Module.
+
+ Args:
+ name: Name of the module.
+ module: The nn.Module to be hooked.
+ config: Cache configuration.
+ is_inactive: Whether this module should operate in "inactive" mode.
+ """
+ state_manager = StateManager(
+ TaylorSeerState,
+ init_kwargs={
+ "taylor_factors_dtype": config.taylor_factors_dtype,
+ "max_order": config.max_order,
+ "is_inactive": is_inactive,
+ },
+ )
+
+ registry = HookRegistry.check_if_exists_or_initialize(module)
+
+ hook = TaylorSeerCacheHook(
+ cache_interval=config.cache_interval,
+ disable_cache_before_step=config.disable_cache_before_step,
+ taylor_factors_dtype=config.taylor_factors_dtype,
+ disable_cache_after_step=config.disable_cache_after_step,
+ state_manager=state_manager,
+ )
+
+ registry.register_hook(hook, _TAYLORSEER_CACHE_HOOK)
diff --git a/src/diffusers/hooks/utils.py b/src/diffusers/hooks/utils.py
new file mode 100644
index 000000000000..c5260eeebe1f
--- /dev/null
+++ b/src/diffusers/hooks/utils.py
@@ -0,0 +1,43 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES
+
+
+def _get_identifiable_transformer_blocks_in_module(module: torch.nn.Module):
+ module_list_with_transformer_blocks = []
+ for name, submodule in module.named_modules():
+ name_endswith_identifier = any(name.endswith(identifier) for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS)
+ is_modulelist = isinstance(submodule, torch.nn.ModuleList)
+ if name_endswith_identifier and is_modulelist:
+ module_list_with_transformer_blocks.append((name, submodule))
+ return module_list_with_transformer_blocks
+
+
+def _get_identifiable_attention_layers_in_module(module: torch.nn.Module):
+ attention_layers = []
+ for name, submodule in module.named_modules():
+ if isinstance(submodule, _ATTENTION_CLASSES):
+ attention_layers.append((name, submodule))
+ return attention_layers
+
+
+def _get_identifiable_feedforward_layers_in_module(module: torch.nn.Module):
+ feedforward_layers = []
+ for name, submodule in module.named_modules():
+ if isinstance(submodule, _FEEDFORWARD_CLASSES):
+ feedforward_layers.append((name, submodule))
+ return feedforward_layers
diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py
index d6913f045ad2..abd0a25819f5 100644
--- a/src/diffusers/image_processor.py
+++ b/src/diffusers/image_processor.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -116,6 +116,7 @@ def __init__(
vae_scale_factor: int = 8,
vae_latent_channels: int = 4,
resample: str = "lanczos",
+ reducing_gap: int = None,
do_normalize: bool = True,
do_binarize: bool = False,
do_convert_rgb: bool = False,
@@ -408,7 +409,7 @@ def _resize_and_fill(
src_w = width if ratio < src_ratio else image.width * height // image.height
src_h = height if ratio >= src_ratio else image.height * width // image.width
- resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
+ resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample])
res = Image.new("RGB", (width, height))
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
@@ -459,7 +460,7 @@ def _resize_and_crop(
src_w = width if ratio > src_ratio else image.width * height // image.height
src_h = height if ratio <= src_ratio else image.height * width // image.width
- resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
+ resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample])
res = Image.new("RGB", (width, height))
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
return res
@@ -498,7 +499,11 @@ def resize(
raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
if isinstance(image, PIL.Image.Image):
if resize_mode == "default":
- image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
+ image = image.resize(
+ (width, height),
+ resample=PIL_INTERPOLATION[self.config.resample],
+ reducing_gap=self.config.reducing_gap,
+ )
elif resize_mode == "fill":
image = self._resize_and_fill(image, width, height)
elif resize_mode == "crop":
@@ -518,6 +523,7 @@ def resize(
size=(height, width),
)
image = self.pt_to_numpy(image)
+
return image
def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
@@ -833,6 +839,137 @@ def apply_overlay(
return image
+class InpaintProcessor(ConfigMixin):
+ """
+ Image processor for inpainting image and mask.
+ """
+
+ config_name = CONFIG_NAME
+
+ @register_to_config
+ def __init__(
+ self,
+ do_resize: bool = True,
+ vae_scale_factor: int = 8,
+ vae_latent_channels: int = 4,
+ resample: str = "lanczos",
+ reducing_gap: int = None,
+ do_normalize: bool = True,
+ do_binarize: bool = False,
+ do_convert_grayscale: bool = False,
+ mask_do_normalize: bool = False,
+ mask_do_binarize: bool = True,
+ mask_do_convert_grayscale: bool = True,
+ ):
+ super().__init__()
+
+ self._image_processor = VaeImageProcessor(
+ do_resize=do_resize,
+ vae_scale_factor=vae_scale_factor,
+ vae_latent_channels=vae_latent_channels,
+ resample=resample,
+ reducing_gap=reducing_gap,
+ do_normalize=do_normalize,
+ do_binarize=do_binarize,
+ do_convert_grayscale=do_convert_grayscale,
+ )
+ self._mask_processor = VaeImageProcessor(
+ do_resize=do_resize,
+ vae_scale_factor=vae_scale_factor,
+ vae_latent_channels=vae_latent_channels,
+ resample=resample,
+ reducing_gap=reducing_gap,
+ do_normalize=mask_do_normalize,
+ do_binarize=mask_do_binarize,
+ do_convert_grayscale=mask_do_convert_grayscale,
+ )
+
+ def preprocess(
+ self,
+ image: PIL.Image.Image,
+ mask: PIL.Image.Image = None,
+ height: int = None,
+ width: int = None,
+ padding_mask_crop: Optional[int] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Preprocess the image and mask.
+ """
+ if mask is None and padding_mask_crop is not None:
+ raise ValueError("mask must be provided if padding_mask_crop is provided")
+
+ # if mask is None, same behavior as regular image processor
+ if mask is None:
+ return self._image_processor.preprocess(image, height=height, width=width)
+
+ if padding_mask_crop is not None:
+ crops_coords = self._image_processor.get_crop_region(mask, width, height, pad=padding_mask_crop)
+ resize_mode = "fill"
+ else:
+ crops_coords = None
+ resize_mode = "default"
+
+ processed_image = self._image_processor.preprocess(
+ image,
+ height=height,
+ width=width,
+ crops_coords=crops_coords,
+ resize_mode=resize_mode,
+ )
+
+ processed_mask = self._mask_processor.preprocess(
+ mask,
+ height=height,
+ width=width,
+ resize_mode=resize_mode,
+ crops_coords=crops_coords,
+ )
+
+ if crops_coords is not None:
+ postprocessing_kwargs = {
+ "crops_coords": crops_coords,
+ "original_image": image,
+ "original_mask": mask,
+ }
+ else:
+ postprocessing_kwargs = {
+ "crops_coords": None,
+ "original_image": None,
+ "original_mask": None,
+ }
+
+ return processed_image, processed_mask, postprocessing_kwargs
+
+ def postprocess(
+ self,
+ image: torch.Tensor,
+ output_type: str = "pil",
+ original_image: Optional[PIL.Image.Image] = None,
+ original_mask: Optional[PIL.Image.Image] = None,
+ crops_coords: Optional[Tuple[int, int, int, int]] = None,
+ ) -> Tuple[PIL.Image.Image, PIL.Image.Image]:
+ """
+ Postprocess the image, optionally apply mask overlay
+ """
+ image = self._image_processor.postprocess(
+ image,
+ output_type=output_type,
+ )
+ # optionally apply the mask overlay
+ if crops_coords is not None and (original_image is None or original_mask is None):
+ raise ValueError("original_image and original_mask must be provided if crops_coords is provided")
+
+ elif crops_coords is not None and output_type != "pil":
+ raise ValueError("output_type must be 'pil' if crops_coords is provided")
+
+ elif crops_coords is not None:
+ image = [
+ self._image_processor.apply_overlay(original_mask, original_image, i, crops_coords) for i in image
+ ]
+
+ return image
+
+
class VaeImageProcessorLDM3D(VaeImageProcessor):
"""
Image processor for VAE LDM3D.
@@ -908,16 +1045,39 @@ def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) ->
def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
r"""
Convert an RGB-like depth image to a depth map.
+ """
+ # 1. Cast the tensor to a larger integer type (e.g., int32)
+ # to safely perform the multiplication by 256.
+ # 2. Perform the 16-bit combination: High-byte * 256 + Low-byte.
+ # 3. Cast the final result to the desired depth map type (uint16) if needed
+ # before returning, though leaving it as int32/int64 is often safer
+ # for return value from a library function.
+
+ if isinstance(image, torch.Tensor):
+ # Cast to a safe dtype (e.g., int32 or int64) for the calculation
+ original_dtype = image.dtype
+ image_safe = image.to(torch.int32)
+
+ # Calculate the depth map
+ depth_map = image_safe[:, :, 1] * 256 + image_safe[:, :, 2]
+
+ # You may want to cast the final result to uint16, but casting to a
+ # larger int type (like int32) is sufficient to fix the overflow.
+ # depth_map = depth_map.to(torch.uint16) # Uncomment if uint16 is strictly required
+ return depth_map.to(original_dtype)
- Args:
- image (`Union[np.ndarray, torch.Tensor]`):
- The RGB-like depth image to convert.
+ elif isinstance(image, np.ndarray):
+ # NumPy equivalent: Cast to a safe dtype (e.g., np.int32)
+ original_dtype = image.dtype
+ image_safe = image.astype(np.int32)
- Returns:
- `Union[np.ndarray, torch.Tensor]`:
- The corresponding depth map.
- """
- return image[:, :, 1] * 2**8 + image[:, :, 2]
+ # Calculate the depth map
+ depth_map = image_safe[:, :, 1] * 256 + image_safe[:, :, 2]
+
+ # depth_map = depth_map.astype(np.uint16) # Uncomment if uint16 is strictly required
+ return depth_map.astype(original_dtype)
+ else:
+ raise TypeError("Input image must be a torch.Tensor or np.ndarray")
def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
r"""
diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py
index 3ba1bfacf3dd..ace4e8543a1c 100644
--- a/src/diffusers/loaders/__init__.py
+++ b/src/diffusers/loaders/__init__.py
@@ -65,6 +65,7 @@ def text_encoder_attn_modules(text_encoder):
"AmusedLoraLoaderMixin",
"StableDiffusionLoraLoaderMixin",
"SD3LoraLoaderMixin",
+ "AuraFlowLoraLoaderMixin",
"StableDiffusionXLLoraLoaderMixin",
"LTXVideoLoraLoaderMixin",
"LoraLoaderMixin",
@@ -76,12 +77,19 @@ def text_encoder_attn_modules(text_encoder):
"SanaLoraLoaderMixin",
"Lumina2LoraLoaderMixin",
"WanLoraLoaderMixin",
+ "KandinskyLoraLoaderMixin",
+ "HiDreamImageLoraLoaderMixin",
+ "SkyReelsV2LoraLoaderMixin",
+ "QwenImageLoraLoaderMixin",
+ "ZImageLoraLoaderMixin",
+ "Flux2LoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = [
"IPAdapterMixin",
"FluxIPAdapterMixin",
"SD3IPAdapterMixin",
+ "ModularIPAdapterMixin",
]
_import_structure["peft"] = ["PeftAdapterMixin"]
@@ -99,23 +107,31 @@ def text_encoder_attn_modules(text_encoder):
from .ip_adapter import (
FluxIPAdapterMixin,
IPAdapterMixin,
+ ModularIPAdapterMixin,
SD3IPAdapterMixin,
)
from .lora_pipeline import (
AmusedLoraLoaderMixin,
+ AuraFlowLoraLoaderMixin,
CogVideoXLoraLoaderMixin,
CogView4LoraLoaderMixin,
+ Flux2LoraLoaderMixin,
FluxLoraLoaderMixin,
+ HiDreamImageLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin,
+ KandinskyLoraLoaderMixin,
LoraLoaderMixin,
LTXVideoLoraLoaderMixin,
Lumina2LoraLoaderMixin,
Mochi1LoraLoaderMixin,
+ QwenImageLoraLoaderMixin,
SanaLoraLoaderMixin,
SD3LoraLoaderMixin,
+ SkyReelsV2LoraLoaderMixin,
StableDiffusionLoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
WanLoraLoaderMixin,
+ ZImageLoraLoaderMixin,
)
from .single_file import FromSingleFileMixin
from .textual_inversion import TextualInversionLoaderMixin
diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py
index 21a1a70ff79b..dca4758ba038 100644
--- a/src/diffusers/loaders/ip_adapter.py
+++ b/src/diffusers/loaders/ip_adapter.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -40,8 +40,6 @@
from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
- FluxAttnProcessor2_0,
- FluxIPAdapterJointAttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
@@ -159,10 +157,7 @@ def load_ip_adapter(
" `low_cpu_mem_usage=False`."
)
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dicts = []
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
pretrained_model_name_or_path_or_dict, weight_name, subfolder
@@ -295,8 +290,7 @@ def set_ip_adapter_scale(self, scale):
):
if len(scale_configs) != len(attn_processor.scale):
raise ValueError(
- f"Cannot assign {len(scale_configs)} scale_configs to "
- f"{len(attn_processor.scale)} IP-Adapter."
+ f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
)
elif len(scale_configs) == 1:
scale_configs = scale_configs * len(attn_processor.scale)
@@ -358,6 +352,256 @@ def unload_ip_adapter(self):
self.unet.set_attn_processor(attn_procs)
+class ModularIPAdapterMixin:
+ """Mixin for handling IP Adapters."""
+
+ @validate_hf_hub_args
+ def load_ip_adapter(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
+ subfolder: Union[str, List[str]],
+ weight_name: Union[str, List[str]],
+ **kwargs,
+ ):
+ """
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+ subfolder (`str` or `List[str]`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally. If a
+ list is passed, it should have the same length as `weight_name`.
+ weight_name (`str` or `List[str]`):
+ The name of the weight file to load. If a list is passed, it should have the same length as
+ `subfolder`.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ """
+
+ # handle the list inputs for multiple IP Adapters
+ if not isinstance(weight_name, list):
+ weight_name = [weight_name]
+
+ if not isinstance(pretrained_model_name_or_path_or_dict, list):
+ pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
+ if len(pretrained_model_name_or_path_or_dict) == 1:
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
+
+ if not isinstance(subfolder, list):
+ subfolder = [subfolder]
+ if len(subfolder) == 1:
+ subfolder = subfolder * len(weight_name)
+
+ if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
+ raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
+
+ if len(weight_name) != len(subfolder):
+ raise ValueError("`weight_name` and `subfolder` must have the same length.")
+
+ # Load the main state dict first.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+ if low_cpu_mem_usage and not is_accelerate_available():
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+ state_dicts = []
+ for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
+ pretrained_model_name_or_path_or_dict, weight_name, subfolder
+ ):
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ if weight_name.endswith(".safetensors"):
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
+ with safe_open(model_file, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ if key.startswith("image_proj."):
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+ elif key.startswith("ip_adapter."):
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+ else:
+ state_dict = load_state_dict(model_file)
+ else:
+ state_dict = pretrained_model_name_or_path_or_dict
+
+ keys = list(state_dict.keys())
+ if "image_proj" not in keys and "ip_adapter" not in keys:
+ raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
+
+ state_dicts.append(state_dict)
+
+ unet_name = getattr(self, "unet_name", "unet")
+ unet = getattr(self, unet_name)
+ unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+
+ extra_loras = unet._load_ip_adapter_loras(state_dicts)
+ if extra_loras != {}:
+ if not USE_PEFT_BACKEND:
+ logger.warning("PEFT backend is required to load these weights.")
+ else:
+ # apply the IP Adapter Face ID LoRA weights
+ peft_config = getattr(unet, "peft_config", {})
+ for k, lora in extra_loras.items():
+ if f"faceid_{k}" not in peft_config:
+ self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
+ self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
+
+ def set_ip_adapter_scale(self, scale):
+ """
+ Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
+ granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
+
+ Example:
+
+ ```py
+ # To use original IP-Adapter
+ scale = 1.0
+ pipeline.set_ip_adapter_scale(scale)
+
+ # To use style block only
+ scale = {
+ "up": {"block_0": [0.0, 1.0, 0.0]},
+ }
+ pipeline.set_ip_adapter_scale(scale)
+
+ # To use style+layout blocks
+ scale = {
+ "down": {"block_2": [0.0, 1.0]},
+ "up": {"block_0": [0.0, 1.0, 0.0]},
+ }
+ pipeline.set_ip_adapter_scale(scale)
+
+ # To use style and layout from 2 reference images
+ scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
+ pipeline.set_ip_adapter_scale(scales)
+ ```
+ """
+ unet_name = getattr(self, "unet_name", "unet")
+ unet = getattr(self, unet_name)
+ if not isinstance(scale, list):
+ scale = [scale]
+ scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
+
+ for attn_name, attn_processor in unet.attn_processors.items():
+ if isinstance(
+ attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
+ ):
+ if len(scale_configs) != len(attn_processor.scale):
+ raise ValueError(
+ f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
+ )
+ elif len(scale_configs) == 1:
+ scale_configs = scale_configs * len(attn_processor.scale)
+ for i, scale_config in enumerate(scale_configs):
+ if isinstance(scale_config, dict):
+ for k, s in scale_config.items():
+ if attn_name.startswith(k):
+ attn_processor.scale[i] = s
+ else:
+ attn_processor.scale[i] = scale_config
+
+ def unload_ip_adapter(self):
+ """
+ Unloads the IP Adapter weights
+
+ Examples:
+
+ ```python
+ >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
+ >>> pipeline.unload_ip_adapter()
+ >>> ...
+ ```
+ """
+
+ # remove hidden encoder
+ if self.unet is None:
+ return
+
+ self.unet.encoder_hid_proj = None
+ self.unet.config.encoder_hid_dim_type = None
+
+ # Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
+ if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
+ self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
+ self.unet.text_encoder_hid_proj = None
+ self.unet.config.encoder_hid_dim_type = "text_proj"
+
+ # restore original Unet attention processors layers
+ attn_procs = {}
+ for name, value in self.unet.attn_processors.items():
+ attn_processor_class = (
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
+ )
+ attn_procs[name] = (
+ attn_processor_class
+ if isinstance(
+ value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
+ )
+ else value.__class__()
+ )
+ self.unet.set_attn_processor(attn_procs)
+
+
class FluxIPAdapterMixin:
"""Mixin for handling Flux IP Adapters."""
@@ -466,10 +710,7 @@ def load_ip_adapter(
" `low_cpu_mem_usage=False`."
)
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dicts = []
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
pretrained_model_name_or_path_or_dict, weight_name, subfolder
@@ -527,7 +768,7 @@ def load_ip_adapter(
low_cpu_mem_usage=low_cpu_mem_usage,
cache_dir=cache_dir,
local_files_only=local_files_only,
- dtype=image_encoder_dtype,
+ torch_dtype=image_encoder_dtype,
)
.to(self.device)
.eval()
@@ -624,6 +865,9 @@ def unload_ip_adapter(self):
>>> ...
```
"""
+ # TODO: once the 1.0.0 deprecations are in, we can move the imports to top-level
+ from ..models.transformers.transformer_flux import FluxAttnProcessor, FluxIPAdapterAttnProcessor
+
# remove CLIP image encoder
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
self.image_encoder = None
@@ -643,9 +887,9 @@ def unload_ip_adapter(self):
# restore original Transformer attention processors layers
attn_procs = {}
for name, value in self.transformer.attn_processors.items():
- attn_processor_class = FluxAttnProcessor2_0()
+ attn_processor_class = FluxAttnProcessor()
attn_procs[name] = (
- attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
+ attn_processor_class if isinstance(value, FluxIPAdapterAttnProcessor) else value.__class__()
)
self.transformer.set_attn_processor(attn_procs)
@@ -751,10 +995,7 @@ def load_ip_adapter(
" `low_cpu_mem_usage=False`."
)
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index 17ed8c5444fc..3d75a7d875a4 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
import copy
import inspect
+import json
import os
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
@@ -33,7 +34,6 @@
delete_adapter_layers,
deprecate,
get_adapter_name,
- get_peft_kwargs,
is_accelerate_available,
is_peft_available,
is_peft_version,
@@ -45,13 +45,13 @@
set_adapter_layers,
set_weights_and_activate_adapters,
)
+from ..utils.peft_utils import _create_lora_config
+from ..utils.state_dict_utils import _load_sft_state_dict_metadata
if is_transformers_available():
from transformers import PreTrainedModel
- from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
-
if is_peft_available():
from peft.tuners.tuners_utils import BaseTunerLayer
@@ -62,6 +62,7 @@
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
@@ -206,6 +207,7 @@ def _fetch_state_dict(
subfolder,
user_agent,
allow_pickle,
+ metadata=None,
):
model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
@@ -236,11 +238,14 @@ def _fetch_state_dict(
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
+ metadata = _load_sft_state_dict_metadata(model_file)
+
except (IOError, safetensors.SafetensorError) as e:
if not allow_pickle:
raise e
# try loading non-safetensors weights
model_file = None
+ metadata = None
pass
if model_file is None:
@@ -261,10 +266,11 @@ def _fetch_state_dict(
user_agent=user_agent,
)
state_dict = load_state_dict(model_file)
+ metadata = None
else:
state_dict = pretrained_model_name_or_path_or_dict
- return state_dict
+ return state_dict, metadata
def _best_guess_weight_name(
@@ -299,13 +305,18 @@ def _best_guess_weight_name(
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
if len(targeted_files) > 1:
- raise ValueError(
- f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
+ logger.warning(
+ f"Provided path contains more than one weights file in the {file_extension} format. `{targeted_files[0]}` is going to be loaded, for precise control, specify a `weight_name` in `load_lora_weights`."
)
weight_name = targeted_files[0]
return weight_name
+def _pack_dict_with_prefix(state_dict, prefix):
+ sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()}
+ return sd_with_prefix
+
+
def _load_lora_into_text_encoder(
state_dict,
network_alphas,
@@ -316,10 +327,17 @@ def _load_lora_into_text_encoder(
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
):
+ from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
+
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
+ if network_alphas and metadata:
+ raise ValueError("`network_alphas` and `metadata` cannot be specified both at the same time.")
+
peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
@@ -334,16 +352,20 @@ def _load_lora_into_text_encoder(
)
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
- from peft import LoraConfig
-
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
# their prefixes.
prefix = text_encoder_name if prefix is None else prefix
+ # Safe prefix to check with.
+ if hotswap and any(text_encoder_name in key for key in state_dict.keys()):
+ raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.")
+
# Load the layers corresponding to text encoder and make necessary adjustments.
if prefix is not None:
- state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+ state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+ if metadata is not None:
+ metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}
if len(state_dict) > 0:
logger.info(f"Loading {prefix}.")
@@ -353,54 +375,27 @@ def _load_lora_into_text_encoder(
# convert state dict
state_dict = convert_state_dict_to_peft(state_dict)
- for name, _ in text_encoder_attn_modules(text_encoder):
- for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
- rank_key = f"{name}.{module}.lora_B.weight"
- if rank_key not in state_dict:
- continue
- rank[rank_key] = state_dict[rank_key].shape[1]
-
- for name, _ in text_encoder_mlp_modules(text_encoder):
- for module in ("fc1", "fc2"):
- rank_key = f"{name}.{module}.lora_B.weight"
- if rank_key not in state_dict:
- continue
- rank[rank_key] = state_dict[rank_key].shape[1]
+ for name, _ in text_encoder.named_modules():
+ if name.endswith((".q_proj", ".k_proj", ".v_proj", ".out_proj", ".fc1", ".fc2")):
+ rank_key = f"{name}.lora_B.weight"
+ if rank_key in state_dict:
+ rank[rank_key] = state_dict[rank_key].shape[1]
if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
- network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
-
- lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
-
- if "use_dora" in lora_config_kwargs:
- if lora_config_kwargs["use_dora"]:
- if is_peft_version("<", "0.9.0"):
- raise ValueError(
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- if is_peft_version("<", "0.9.0"):
- lora_config_kwargs.pop("use_dora")
-
- if "lora_bias" in lora_config_kwargs:
- if lora_config_kwargs["lora_bias"]:
- if is_peft_version("<=", "0.13.2"):
- raise ValueError(
- "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- if is_peft_version("<=", "0.13.2"):
- lora_config_kwargs.pop("lora_bias")
+ network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}
- lora_config = LoraConfig(**lora_config_kwargs)
+ # create `LoraConfig`
+ lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank, is_unet=False)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
- is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
-
+ #
if prefix is not None and not state_dict:
+ model_class_name = text_encoder.__class__.__name__
logger.warning(
- f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
+ f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
"This is safe to ignore if LoRA state dict didn't originally have any "
- f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
+ f"{model_class_name} related params. You can also try specifying `prefix=None` "
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
"https://github.com/huggingface/diffusers/issues/new"
)
def _func_optionally_disable_offloading(_pipeline):
+ """
+ Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
+
+ Args:
+ _pipeline (`DiffusionPipeline`):
+ The pipeline to disable offloading for.
+
+ Returns:
+ tuple:
+ A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True.
+ """
+ from ..hooks.group_offloading import _is_group_offload_enabled
+
is_model_cpu_offload = False
is_sequential_cpu_offload = False
+ is_group_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
- if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
- if not is_model_cpu_offload:
- is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
- if not is_sequential_cpu_offload:
- is_sequential_cpu_offload = (
- isinstance(component._hf_hook, AlignDevicesHook)
- or hasattr(component._hf_hook, "hooks")
- and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
- )
+ if not isinstance(component, nn.Module):
+ continue
+ is_group_offload = is_group_offload or _is_group_offload_enabled(component)
+ if not hasattr(component, "_hf_hook"):
+ continue
+ is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload)
+ is_sequential_cpu_offload = is_sequential_cpu_offload or (
+ isinstance(component._hf_hook, AlignDevicesHook)
+ or hasattr(component._hf_hook, "hooks")
+ and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+ )
- logger.info(
- "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
- )
+ if is_sequential_cpu_offload or is_model_cpu_offload:
+ logger.info(
+ "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+ )
+ for _, component in _pipeline.components.items():
+ if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
+ continue
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
- return (is_model_cpu_offload, is_sequential_cpu_offload)
+ return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
class LoraBaseMixin:
"""Utility class for handling LoRAs."""
_lora_loadable_modules = []
- num_fused_loras = 0
+ _merged_adapters = set()
+
+ @property
+ def lora_scale(self) -> float:
+ """
+ Returns the lora scale which can be set at run time by the pipeline. # if `_lora_scale` has not been set,
+ return 1.
+ """
+ return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
+
+ @property
+ def num_fused_loras(self):
+ """Returns the number of LoRAs that have been fused."""
+ return len(self._merged_adapters)
+
+ @property
+ def fused_loras(self):
+ """Returns names of the LoRAs that have been fused."""
+ return self._merged_adapters
def load_lora_weights(self, **kwargs):
raise NotImplementedError("`load_lora_weights()` is not implemented.")
@@ -473,33 +510,6 @@ def save_lora_weights(cls, **kwargs):
def lora_state_dict(cls, **kwargs):
raise NotImplementedError("`lora_state_dict()` is not implemented.")
- @classmethod
- def _optionally_disable_offloading(cls, _pipeline):
- """
- Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
-
- Args:
- _pipeline (`DiffusionPipeline`):
- The pipeline to disable offloading for.
-
- Returns:
- tuple:
- A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
- """
- return _func_optionally_disable_offloading(_pipeline=_pipeline)
-
- @classmethod
- def _fetch_state_dict(cls, *args, **kwargs):
- deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
- deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
- return _fetch_state_dict(*args, **kwargs)
-
- @classmethod
- def _best_guess_weight_name(cls, *args, **kwargs):
- deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
- deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
- return _best_guess_weight_name(*args, **kwargs)
-
def unload_lora_weights(self):
"""
Unloads the LoRA parameters.
@@ -534,11 +544,7 @@ def fuse_lora(
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
- This is an experimental API.
-
-
+ > [!WARNING] > This is an experimental API.
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
@@ -587,6 +593,9 @@ def fuse_lora(
if len(components) == 0:
raise ValueError("`components` cannot be an empty list.")
+ # Need to retrieve the names as `adapter_names` can be None. So we cannot directly use it
+ # in `self._merged_adapters = self._merged_adapters | merged_adapter_names`.
+ merged_adapter_names = set()
for fuse_component in components:
if fuse_component not in self._lora_loadable_modules:
raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
@@ -596,24 +605,26 @@ def fuse_lora(
# check if diffusers model
if issubclass(model.__class__, ModelMixin):
model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
+ for module in model.modules():
+ if isinstance(module, BaseTunerLayer):
+ merged_adapter_names.update(set(module.merged_adapters))
# handle transformers models.
if issubclass(model.__class__, PreTrainedModel):
fuse_text_encoder_lora(
model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
+ for module in model.modules():
+ if isinstance(module, BaseTunerLayer):
+ merged_adapter_names.update(set(module.merged_adapters))
- self.num_fused_loras += 1
+ self._merged_adapters = self._merged_adapters | merged_adapter_names
def unfuse_lora(self, components: List[str] = [], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
- This is an experimental API.
-
-
+ > [!WARNING] > This is an experimental API.
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
@@ -656,15 +667,42 @@ def unfuse_lora(self, components: List[str] = [], **kwargs):
if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
for module in model.modules():
if isinstance(module, BaseTunerLayer):
+ for adapter in set(module.merged_adapters):
+ if adapter and adapter in self._merged_adapters:
+ self._merged_adapters = self._merged_adapters - {adapter}
module.unmerge()
- self.num_fused_loras -= 1
-
def set_adapters(
self,
adapter_names: Union[List[str], str],
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
):
+ """
+ Set the currently active adapters for use in the pipeline.
+
+ Args:
+ adapter_names (`List[str]` or `str`):
+ The names of the adapters to use.
+ adapter_weights (`Union[List[float], float]`, *optional*):
+ The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
+ adapters.
+
+ Example:
+
+ ```py
+ from diffusers import AutoPipelineForText2Image
+ import torch
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights(
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+ )
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
+ ```
+ """
if isinstance(adapter_weights, dict):
components_passed = set(adapter_weights.keys())
lora_components = set(self._lora_loadable_modules)
@@ -708,7 +746,11 @@ def set_adapters(
# Decompose weights into weights for denoiser and text encoders.
_component_adapter_weights = {}
for component in self._lora_loadable_modules:
- model = getattr(self, component)
+ model = getattr(self, component, None)
+ # To guard for cases like Wan. In Wan2.1 and WanVace, we have a single denoiser.
+ # Whereas in Wan 2.2, we have two denoisers.
+ if model is None:
+ continue
for adapter_name, weights in zip(adapter_names, adapter_weights):
if isinstance(weights, dict):
@@ -734,6 +776,24 @@ def set_adapters(
set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
def disable_lora(self):
+ """
+ Disables the active LoRA layers of the pipeline.
+
+ Example:
+
+ ```py
+ from diffusers import AutoPipelineForText2Image
+ import torch
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights(
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+ )
+ pipeline.disable_lora()
+ ```
+ """
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -746,6 +806,24 @@ def disable_lora(self):
disable_lora_for_text_encoder(model)
def enable_lora(self):
+ """
+ Enables the active LoRA layers of the pipeline.
+
+ Example:
+
+ ```py
+ from diffusers import AutoPipelineForText2Image
+ import torch
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights(
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+ )
+ pipeline.enable_lora()
+ ```
+ """
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -759,10 +837,26 @@ def enable_lora(self):
def delete_adapters(self, adapter_names: Union[List[str], str]):
"""
+ Delete an adapter's LoRA layers from the pipeline.
+
Args:
- Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
adapter_names (`Union[List[str], str]`):
- The names of the adapter to delete. Can be a single string or a list of strings
+ The names of the adapters to delete.
+
+ Example:
+
+ ```py
+ from diffusers import AutoPipelineForText2Image
+ import torch
+
+ pipeline = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights(
+ "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
+ )
+ pipeline.delete_adapters("cinematic")
+ ```
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -839,6 +933,27 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
you want to load multiple adapters and free some GPU memory.
+ After offloading the LoRA adapters to CPU, as long as the rest of the model is still on GPU, the LoRA adapters
+ can no longer be used for inference, as that would cause a device mismatch. Remember to set the device back to
+ GPU before using those LoRA adapters for inference.
+
+ ```python
+ >>> pipe.load_lora_weights(path_1, adapter_name="adapter-1")
+ >>> pipe.load_lora_weights(path_2, adapter_name="adapter-2")
+ >>> pipe.set_adapters("adapter-1")
+ >>> image_1 = pipe(**kwargs)
+ >>> # switch to adapter-2, offload adapter-1
+ >>> pipeline.set_lora_device(adapter_names=["adapter-1"], device="cpu")
+ >>> pipeline.set_lora_device(adapter_names=["adapter-2"], device="cuda:0")
+ >>> pipe.set_adapters("adapter-2")
+ >>> image_2 = pipe(**kwargs)
+ >>> # switch back to adapter-1, offload adapter-2
+ >>> pipeline.set_lora_device(adapter_names=["adapter-2"], device="cpu")
+ >>> pipeline.set_lora_device(adapter_names=["adapter-1"], device="cuda:0")
+ >>> pipe.set_adapters("adapter-1")
+ >>> ...
+ ```
+
Args:
adapter_names (`List[str]`):
List of adapters to send device to.
@@ -854,6 +969,10 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
for module in model.modules():
if isinstance(module, BaseTunerLayer):
for adapter_name in adapter_names:
+ if adapter_name not in module.lora_A:
+ # it is sufficient to check lora_A
+ continue
+
module.lora_A[adapter_name].to(device)
module.lora_B[adapter_name].to(device)
# this is a param, not a module, so device placement is not in-place -> re-assign
@@ -863,11 +982,28 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
adapter_name
].to(device)
+ def enable_lora_hotswap(self, **kwargs) -> None:
+ """
+ Hotswap adapters without triggering recompilation of a model or if the ranks of the loaded adapters are
+ different.
+
+ Args:
+ target_rank (`int`):
+ The highest rank among all the adapters that will be loaded.
+ check_compiled (`str`, *optional*, defaults to `"error"`):
+ How to handle a model that is already compiled. The check can return the following messages:
+ - "error" (default): raise an error
+ - "warn": issue a warning
+ - "ignore": do nothing
+ """
+ for key, component in self.components.items():
+ if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
+ component.enable_lora_hotswap(**kwargs)
+
@staticmethod
def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
- layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
- return layers_state_dict
+ return _pack_dict_with_prefix(layers_weights, prefix)
@staticmethod
def write_lora_layers(
@@ -877,16 +1013,33 @@ def write_lora_layers(
weight_name: str,
save_function: Callable,
safe_serialization: bool,
+ lora_adapter_metadata: Optional[dict] = None,
):
+ """Writes the state dict of the LoRA layers (optionally with metadata) to disk."""
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
+ if lora_adapter_metadata and not safe_serialization:
+ raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
+ if lora_adapter_metadata and not isinstance(lora_adapter_metadata, dict):
+ raise TypeError("`lora_adapter_metadata` must be of type `dict`.")
+
if save_function is None:
if safe_serialization:
def save_function(weights, filename):
- return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+ # Inject framework format.
+ metadata = {"format": "pt"}
+ if lora_adapter_metadata:
+ for key, value in lora_adapter_metadata.items():
+ if isinstance(value, set):
+ lora_adapter_metadata[key] = list(value)
+ metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(
+ lora_adapter_metadata, indent=2, sort_keys=True
+ )
+
+ return safetensors.torch.save_file(weights, filename, metadata=metadata)
else:
save_function = torch.save
@@ -903,8 +1056,41 @@ def save_function(weights, filename):
save_function(state_dict, save_path)
logger.info(f"Model weights saved in {save_path}")
- @property
- def lora_scale(self) -> float:
- # property function that returns the lora scale which can be set at run time by the pipeline.
- # if _lora_scale has not been set, return 1
- return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
+ @classmethod
+ def _save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ lora_layers: Dict[str, Dict[str, Union[torch.nn.Module, torch.Tensor]]],
+ lora_metadata: Dict[str, Optional[dict]],
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ """
+ Helper method to pack and save LoRA weights and metadata. This method centralizes the saving logic for all
+ pipeline types.
+ """
+ state_dict = {}
+ final_lora_adapter_metadata = {}
+
+ for prefix, layers in lora_layers.items():
+ state_dict.update(cls.pack_weights(layers, prefix))
+
+ for prefix, metadata in lora_metadata.items():
+ if metadata:
+ final_lora_adapter_metadata.update(_pack_dict_with_prefix(metadata, prefix))
+
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ lora_adapter_metadata=final_lora_adapter_metadata if final_lora_adapter_metadata else None,
+ )
+
+ @classmethod
+ def _optionally_disable_offloading(cls, _pipeline):
+ return _func_optionally_disable_offloading(_pipeline=_pipeline)
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index 20fcb61f3b80..2e87f757c352 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,19 +13,44 @@
# limitations under the License.
import re
+from typing import List
import torch
-from ..utils import is_peft_version, logging
+from ..utils import is_peft_version, logging, state_dict_all_zero
logger = logging.get_logger(__name__)
+def swap_scale_shift(weight):
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ return new_weight
+
+
def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
# 1. get all state_dict_keys
all_keys = list(state_dict.keys())
sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
+ not_sgm_patterns = ["down_blocks", "mid_block", "up_blocks"]
+
+ # check if state_dict contains both patterns
+ contains_sgm_patterns = False
+ contains_not_sgm_patterns = False
+ for key in all_keys:
+ if any(p in key for p in sgm_patterns):
+ contains_sgm_patterns = True
+ elif any(p in key for p in not_sgm_patterns):
+ contains_not_sgm_patterns = True
+
+ # if state_dict contains both patterns, remove sgm
+ # we can then return state_dict immediately
+ if contains_sgm_patterns and contains_not_sgm_patterns:
+ for key in all_keys:
+ if any(p in key for p in sgm_patterns):
+ state_dict.pop(key)
+ return state_dict
# 2. check if needs remapping, if not return original dict
is_in_sgm_format = False
@@ -119,7 +144,7 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
)
new_state_dict[new_key] = state_dict.pop(key)
- if len(state_dict) > 0:
+ if state_dict:
raise ValueError("At this point all state dict entries have to be converted.")
return new_state_dict
@@ -177,9 +202,9 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
# Store DoRA scale if present.
if dora_present_in_unet:
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
- unet_state_dict[
- diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
- ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+ unet_state_dict[diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")] = (
+ state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+ )
# Handle text encoder LoRAs.
elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
@@ -199,13 +224,13 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
)
if lora_name.startswith(("lora_te_", "lora_te1_")):
- te_state_dict[
- diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
- ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+ te_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
+ state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+ )
elif lora_name.startswith("lora_te2_"):
- te2_state_dict[
- diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
- ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+ te2_state_dict[diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")] = (
+ state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
+ )
# Store alpha if present.
if lora_name_alpha in state_dict:
@@ -313,6 +338,7 @@ def _convert_text_encoder_lora_key(key, lora_name):
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
+
return diffusers_name
@@ -331,8 +357,7 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
-# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
-# All credits go to `kohya-ss`.
+# are adapted from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
def _convert_kohya_flux_lora_to_diffusers(state_dict):
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
if sds_key + ".lora_down.weight" not in sds_sd:
@@ -341,7 +366,8 @@ def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
# scale weight by alpha and dim
rank = down_weight.shape[0]
- alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar
+ default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
+ alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item() # alpha is scalar
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
# calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
@@ -362,7 +388,10 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
sd_lora_rank = down_weight.shape[0]
# scale weight by alpha and dim
- alpha = sds_sd.pop(sds_key + ".alpha")
+ default_alpha = torch.tensor(
+ sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
+ )
+ alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
scale = alpha / sd_lora_rank
# calculate scale_down and scale_up
@@ -404,7 +433,7 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
if not is_sparse:
# down_weight is copied to each split
- ait_sd.update({k: down_weight for k in ait_down_keys})
+ ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
# up_weight is split to each split
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
@@ -516,10 +545,95 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
f"transformer.single_transformer_blocks.{i}.norm.linear",
)
+ # TODO: alphas.
+ def assign_remaining_weights(assignments, source):
+ for lora_key in ["lora_A", "lora_B"]:
+ orig_lora_key = "lora_down" if lora_key == "lora_A" else "lora_up"
+ for target_fmt, source_fmt, transform in assignments:
+ target_key = target_fmt.format(lora_key=lora_key)
+ source_key = source_fmt.format(orig_lora_key=orig_lora_key)
+ value = source.pop(source_key)
+ if transform:
+ value = transform(value)
+ ait_sd[target_key] = value
+
+ if any("guidance_in" in k for k in sds_sd):
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ "lora_unet_guidance_in_in_layer",
+ "time_text_embed.guidance_embedder.linear_1",
+ )
+
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ "lora_unet_guidance_in_out_layer",
+ "time_text_embed.guidance_embedder.linear_2",
+ )
+
+ if any("img_in" in k for k in sds_sd):
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ "lora_unet_img_in",
+ "x_embedder",
+ )
+
+ if any("txt_in" in k for k in sds_sd):
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ "lora_unet_txt_in",
+ "context_embedder",
+ )
+
+ if any("time_in" in k for k in sds_sd):
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ "lora_unet_time_in_in_layer",
+ "time_text_embed.timestep_embedder.linear_1",
+ )
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ "lora_unet_time_in_out_layer",
+ "time_text_embed.timestep_embedder.linear_2",
+ )
+
+ if any("vector_in" in k for k in sds_sd):
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ "lora_unet_vector_in_in_layer",
+ "time_text_embed.text_embedder.linear_1",
+ )
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ "lora_unet_vector_in_out_layer",
+ "time_text_embed.text_embedder.linear_2",
+ )
+
+ if any("final_layer" in k for k in sds_sd):
+ # Notice the swap in processing for "final_layer".
+ assign_remaining_weights(
+ [
+ (
+ "norm_out.linear.{lora_key}.weight",
+ "lora_unet_final_layer_adaLN_modulation_1.{orig_lora_key}.weight",
+ swap_scale_shift,
+ ),
+ ("proj_out.{lora_key}.weight", "lora_unet_final_layer_linear.{orig_lora_key}.weight", None),
+ ],
+ sds_sd,
+ )
+
remaining_keys = list(sds_sd.keys())
te_state_dict = {}
if remaining_keys:
- if not all(k.startswith("lora_te") for k in remaining_keys):
+ if not all(k.startswith(("lora_te", "lora_te1")) for k in remaining_keys):
raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
for key in remaining_keys:
if not key.endswith("lora_down.weight"):
@@ -605,8 +719,25 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
elif k.startswith("lora_te1_"):
has_te_keys = True
continue
+ elif k.startswith("lora_transformer_context_embedder"):
+ diffusers_key = "context_embedder"
+ elif k.startswith("lora_transformer_norm_out_linear"):
+ diffusers_key = "norm_out.linear"
+ elif k.startswith("lora_transformer_proj_out"):
+ diffusers_key = "proj_out"
+ elif k.startswith("lora_transformer_x_embedder"):
+ diffusers_key = "x_embedder"
+ elif k.startswith("lora_transformer_time_text_embed_guidance_embedder_linear_"):
+ i = int(k.split("lora_transformer_time_text_embed_guidance_embedder_linear_")[-1])
+ diffusers_key = f"time_text_embed.guidance_embedder.linear_{i}"
+ elif k.startswith("lora_transformer_time_text_embed_text_embedder_linear_"):
+ i = int(k.split("lora_transformer_time_text_embed_text_embedder_linear_")[-1])
+ diffusers_key = f"time_text_embed.text_embedder.linear_{i}"
+ elif k.startswith("lora_transformer_time_text_embed_timestep_embedder_linear_"):
+ i = int(k.split("lora_transformer_time_text_embed_timestep_embedder_linear_")[-1])
+ diffusers_key = f"time_text_embed.timestep_embedder.linear_{i}"
else:
- raise NotImplementedError
+ raise NotImplementedError(f"Handling for key ({k}) is not implemented.")
if "attn_" in k:
if "_to_out_0" in k:
@@ -678,12 +809,104 @@ def _convert(original_key, diffusers_key, state_dict, new_state_dict):
# has both `peft` and non-peft state dict.
has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
if has_peft_state_dict:
- state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
+ state_dict = {
+ k.replace("lora_down.weight", "lora_A.weight").replace("lora_up.weight", "lora_B.weight"): v
+ for k, v in state_dict.items()
+ if k.startswith("transformer.")
+ }
return state_dict
+
# Another weird one.
has_mixture = any(
k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
)
+
+ # ComfyUI.
+ if not has_mixture:
+ state_dict = {k.replace("diffusion_model.", "lora_unet_"): v for k, v in state_dict.items()}
+ state_dict = {k.replace("text_encoders.clip_l.transformer.", "lora_te_"): v for k, v in state_dict.items()}
+
+ has_position_embedding = any("position_embedding" in k for k in state_dict)
+ if has_position_embedding:
+ zero_status_pe = state_dict_all_zero(state_dict, "position_embedding")
+ if zero_status_pe:
+ logger.info(
+ "The `position_embedding` LoRA params are all zeros which make them ineffective. "
+ "So, we will purge them out of the current state dict to make loading possible."
+ )
+
+ else:
+ logger.info(
+ "The state_dict has position_embedding LoRA params and we currently do not support them. "
+ "Open an issue if you need this supported - https://github.com/huggingface/diffusers/issues/new."
+ )
+ state_dict = {k: v for k, v in state_dict.items() if "position_embedding" not in k}
+
+ has_t5xxl = any(k.startswith("text_encoders.t5xxl.transformer.") for k in state_dict)
+ if has_t5xxl:
+ zero_status_t5 = state_dict_all_zero(state_dict, "text_encoders.t5xxl")
+ if zero_status_t5:
+ logger.info(
+ "The `t5xxl` LoRA params are all zeros which make them ineffective. "
+ "So, we will purge them out of the current state dict to make loading possible."
+ )
+ else:
+ logger.info(
+ "T5-xxl keys found in the state dict, which are currently unsupported. We will filter them out."
+ "Open an issue if this is a problem - https://github.com/huggingface/diffusers/issues/new."
+ )
+ state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
+
+ has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
+ if has_diffb:
+ zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
+ if zero_status_diff_b:
+ logger.info(
+ "The `diff_b` LoRA params are all zeros which make them ineffective. "
+ "So, we will purge them out of the current state dict to make loading possible."
+ )
+ else:
+ logger.info(
+ "`diff_b` keys found in the state dict which are currently unsupported. "
+ "So, we will filter out those keys. Open an issue if this is a problem - "
+ "https://github.com/huggingface/diffusers/issues/new."
+ )
+ state_dict = {k: v for k, v in state_dict.items() if ".diff_b" not in k}
+
+ has_norm_diff = any(".norm" in k and ".diff" in k for k in state_dict)
+ if has_norm_diff:
+ zero_status_diff = state_dict_all_zero(state_dict, ".diff")
+ if zero_status_diff:
+ logger.info(
+ "The `diff` LoRA params are all zeros which make them ineffective. "
+ "So, we will purge them out of the current state dict to make loading possible."
+ )
+ else:
+ logger.info(
+ "Normalization diff keys found in the state dict which are currently unsupported. "
+ "So, we will filter out those keys. Open an issue if this is a problem - "
+ "https://github.com/huggingface/diffusers/issues/new."
+ )
+ state_dict = {k: v for k, v in state_dict.items() if ".norm" not in k and ".diff" not in k}
+
+ limit_substrings = ["lora_down", "lora_up"]
+ if any("alpha" in k for k in state_dict):
+ limit_substrings.append("alpha")
+
+ state_dict = {
+ _custom_replace(k, limit_substrings): v
+ for k, v in state_dict.items()
+ if k.startswith(("lora_unet_", "lora_te_"))
+ }
+
+ if any("text_projection" in k for k in state_dict):
+ logger.info(
+ "`text_projection` keys found in the `state_dict` which are unexpected. "
+ "So, we will filter out those keys. Open an issue if this is a problem - "
+ "https://github.com/huggingface/diffusers/issues/new."
+ )
+ state_dict = {k: v for k, v in state_dict.items() if "text_projection" not in k}
+
if has_mixture:
return _convert_mixture_state_dict_to_diffusers(state_dict)
@@ -713,7 +936,7 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
# down_weight is copied to each split
- ait_sd.update({k: down_weight for k in ait_down_keys})
+ ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
# up_weight is split to each split
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
@@ -798,6 +1021,26 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
return new_state_dict
+def _custom_replace(key: str, substrings: List[str]) -> str:
+ # Replaces the "."s with "_"s upto the `substrings`.
+ # Example:
+ # lora_unet.foo.bar.lora_A.weight -> lora_unet_foo_bar.lora_A.weight
+ pattern = "(" + "|".join(re.escape(sub) for sub in substrings) + ")"
+
+ match = re.search(pattern, key)
+ if match:
+ start_sub = match.start()
+ if start_sub > 0 and key[start_sub - 1] == ".":
+ boundary = start_sub - 1
+ else:
+ boundary = start_sub
+ left = key[:boundary].replace(".", "_")
+ right = key[boundary:]
+ return left + right
+ else:
+ return key.replace(".", "_")
+
+
def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
converted_state_dict = {}
original_state_dict_keys = list(original_state_dict.keys())
@@ -806,28 +1049,23 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
inner_dim = 3072
mlp_ratio = 4.0
- def swap_scale_shift(weight):
- shift, scale = weight.chunk(2, dim=0)
- new_weight = torch.cat([scale, shift], dim=0)
- return new_weight
-
for lora_key in ["lora_A", "lora_B"]:
## time_text_embed.timestep_embedder <- time_in
- converted_state_dict[
- f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"
- ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
+ converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"] = (
+ original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
+ )
if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
- converted_state_dict[
- f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"
- ] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
+ converted_state_dict[f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"] = (
+ original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
+ )
- converted_state_dict[
- f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"
- ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
+ converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"] = (
+ original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
+ )
if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
- converted_state_dict[
- f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"
- ] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
+ converted_state_dict[f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"] = (
+ original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
+ )
## time_text_embed.text_embedder <- vector_in
converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
@@ -849,21 +1087,21 @@ def swap_scale_shift(weight):
# guidance
has_guidance = any("guidance" in k for k in original_state_dict)
if has_guidance:
- converted_state_dict[
- f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"
- ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
+ converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"] = (
+ original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
+ )
if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
- converted_state_dict[
- f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"
- ] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
+ converted_state_dict[f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"] = (
+ original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
+ )
- converted_state_dict[
- f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"
- ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
+ converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"] = (
+ original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
+ )
if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
- converted_state_dict[
- f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"
- ] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
+ converted_state_dict[f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"] = (
+ original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
+ )
# context_embedder
converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
@@ -1012,7 +1250,7 @@ def swap_scale_shift(weight):
f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
)
- # single transfomer blocks
+ # single transformer blocks
for i in range(num_single_layers):
block_prefix = f"single_transformer_blocks.{i}."
@@ -1104,6 +1342,228 @@ def swap_scale_shift(weight):
return converted_state_dict
+def _convert_fal_kontext_lora_to_diffusers(original_state_dict):
+ converted_state_dict = {}
+ original_state_dict_keys = list(original_state_dict.keys())
+ num_layers = 19
+ num_single_layers = 38
+ inner_dim = 3072
+ mlp_ratio = 4.0
+
+ # double transformer blocks
+ for i in range(num_layers):
+ block_prefix = f"transformer_blocks.{i}."
+ original_block_prefix = "base_model.model."
+
+ for lora_key in ["lora_A", "lora_B"]:
+ # norms
+ converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.weight"
+ )
+ if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias"
+ )
+
+ converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
+ )
+
+ # Q, K, V
+ if lora_key == "lora_A":
+ sample_lora_weight = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+
+ context_lora_weight = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
+ [context_lora_weight]
+ )
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
+ [context_lora_weight]
+ )
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
+ [context_lora_weight]
+ )
+ else:
+ sample_q, sample_k, sample_v = torch.chunk(
+ original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
+ ),
+ 3,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])
+
+ context_q, context_k, context_v = torch.chunk(
+ original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
+ ),
+ 3,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])
+
+ if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+ sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+ original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias"),
+ 3,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])
+
+ if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+ context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+ original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"),
+ 3,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])
+
+ # ff img_mlp
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias"
+ )
+
+ converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias"
+ )
+
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
+ )
+
+ converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
+ )
+
+ # output projections.
+ converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
+ )
+
+ # single transformer blocks
+ for i in range(num_single_layers):
+ block_prefix = f"single_transformer_blocks.{i}."
+
+ for lora_key in ["lora_A", "lora_B"]:
+ # norm.linear <- single_blocks.0.modulation.lin
+ converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias"
+ )
+
+ # Q, K, V, mlp
+ mlp_hidden_dim = int(inner_dim * mlp_ratio)
+ split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+
+ if lora_key == "lora_A":
+ lora_weight = original_state_dict.pop(
+ f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])
+
+ if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+ lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias")
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
+ else:
+ q, k, v, mlp = torch.split(
+ original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"),
+ split_size,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])
+
+ if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(
+ original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias"),
+ split_size,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])
+
+ # output projections.
+ converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias"
+ )
+
+ for lora_key in ["lora_A", "lora_B"]:
+ converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}final_layer.linear.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}final_layer.linear.{lora_key}.bias"
+ )
+
+ if len(original_state_dict) > 0:
+ raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
+
+ for key in list(converted_state_dict.keys()):
+ converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+ return converted_state_dict
+
+
def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
@@ -1354,50 +1814,669 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
converted_state_dict = {}
original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
- num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict})
+ block_numbers = {int(k.split(".")[1]) for k in original_state_dict if k.startswith("blocks.")}
+ min_block = min(block_numbers)
+ max_block = max(block_numbers)
+
is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
+ lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
+ lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
+ has_time_projection_weight = any(
+ k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
+ )
- for i in range(num_blocks):
+ def get_alpha_scales(down_weight, alpha_key):
+ rank = down_weight.shape[0]
+ alpha = original_state_dict.pop(alpha_key).item()
+ scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+ scale_down = scale
+ scale_up = 1.0
+ while scale_down * 2 < scale_up:
+ scale_down *= 2
+ scale_up /= 2
+ return scale_down, scale_up
+
+ for key in list(original_state_dict.keys()):
+ if key.endswith((".diff", ".diff_b")) and "norm" in key:
+ # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
+ # in future if needed and they are not zeroed.
+ original_state_dict.pop(key)
+ logger.debug(f"Removing {key} key from the state dict as it is a norm diff key. This is unsupported.")
+
+ if "time_projection" in key and not has_time_projection_weight:
+ # AccVideo lora has diff bias keys but not the weight keys. This causes a weird problem where
+ # our lora config adds the time proj lora layers, but we don't have the weights for them.
+ # CausVid lora has the weight keys and the bias keys.
+ original_state_dict.pop(key)
+
+ # For the `diff_b` keys, we treat them as lora_bias.
+ # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
+
+ for i in range(min_block, max_block + 1):
# Self-attention
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
- converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
- f"blocks.{i}.self_attn.{o}.lora_A.weight"
- )
- converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
- f"blocks.{i}.self_attn.{o}.lora_B.weight"
- )
+ alpha_key = f"blocks.{i}.self_attn.{o}.alpha"
+ has_alpha = alpha_key in original_state_dict
+ original_key_A = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
+ converted_key_A = f"blocks.{i}.attn1.{c}.lora_A.weight"
+
+ original_key_B = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
+ converted_key_B = f"blocks.{i}.attn1.{c}.lora_B.weight"
+
+ if has_alpha:
+ down_weight = original_state_dict.pop(original_key_A)
+ up_weight = original_state_dict.pop(original_key_B)
+ scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+ converted_state_dict[converted_key_A] = down_weight * scale_down
+ converted_state_dict[converted_key_B] = up_weight * scale_up
+
+ else:
+ if original_key_A in original_state_dict:
+ converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
+ if original_key_B in original_state_dict:
+ converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)
+
+ original_key = f"blocks.{i}.self_attn.{o}.diff_b"
+ converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias"
+ if original_key in original_state_dict:
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
# Cross-attention
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
- converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
- f"blocks.{i}.cross_attn.{o}.lora_A.weight"
- )
- converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
- f"blocks.{i}.cross_attn.{o}.lora_B.weight"
- )
+ alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+ has_alpha = alpha_key in original_state_dict
+ original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+ converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
+
+ original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+ converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+ if original_key_A in original_state_dict:
+ down_weight = original_state_dict.pop(original_key_A)
+ converted_state_dict[converted_key_A] = down_weight
+ if original_key_B in original_state_dict:
+ up_weight = original_state_dict.pop(original_key_B)
+ converted_state_dict[converted_key_B] = up_weight
+ if has_alpha:
+ scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+ converted_state_dict[converted_key_A] *= scale_down
+ converted_state_dict[converted_key_B] *= scale_up
+
+ original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
+ converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
+ if original_key in original_state_dict:
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
if is_i2v_lora:
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
- converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
- f"blocks.{i}.cross_attn.{o}.lora_A.weight"
+ alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+ has_alpha = alpha_key in original_state_dict
+ original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+ converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
+
+ original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+ converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+ if original_key_A in original_state_dict:
+ down_weight = original_state_dict.pop(original_key_A)
+ converted_state_dict[converted_key_A] = down_weight
+ if original_key_B in original_state_dict:
+ up_weight = original_state_dict.pop(original_key_B)
+ converted_state_dict[converted_key_B] = up_weight
+ if has_alpha:
+ scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+ converted_state_dict[converted_key_A] *= scale_down
+ converted_state_dict[converted_key_B] *= scale_up
+
+ original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
+ converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
+ if original_key in original_state_dict:
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+ # FFN
+ for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
+ alpha_key = f"blocks.{i}.{o}.alpha"
+ has_alpha = alpha_key in original_state_dict
+ original_key_A = f"blocks.{i}.{o}.{lora_down_key}.weight"
+ converted_key_A = f"blocks.{i}.ffn.{c}.lora_A.weight"
+
+ original_key_B = f"blocks.{i}.{o}.{lora_up_key}.weight"
+ converted_key_B = f"blocks.{i}.ffn.{c}.lora_B.weight"
+
+ if original_key_A in original_state_dict:
+ down_weight = original_state_dict.pop(original_key_A)
+ converted_state_dict[converted_key_A] = down_weight
+ if original_key_B in original_state_dict:
+ up_weight = original_state_dict.pop(original_key_B)
+ converted_state_dict[converted_key_B] = up_weight
+ if has_alpha:
+ scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+ converted_state_dict[converted_key_A] *= scale_down
+ converted_state_dict[converted_key_B] *= scale_up
+
+ original_key = f"blocks.{i}.{o}.diff_b"
+ converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"
+ if original_key in original_state_dict:
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+ # Remaining.
+ if original_state_dict:
+ if any("time_projection" in k for k in original_state_dict):
+ original_key = f"time_projection.1.{lora_down_key}.weight"
+ converted_key = "condition_embedder.time_proj.lora_A.weight"
+ if original_key in original_state_dict:
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+ original_key = f"time_projection.1.{lora_up_key}.weight"
+ converted_key = "condition_embedder.time_proj.lora_B.weight"
+ if original_key in original_state_dict:
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+ if "time_projection.1.diff_b" in original_state_dict:
+ converted_state_dict["condition_embedder.time_proj.lora_B.bias"] = original_state_dict.pop(
+ "time_projection.1.diff_b"
+ )
+
+ if any("head.head" in k for k in original_state_dict):
+ if any(f"head.head.{lora_down_key}.weight" in k for k in state_dict):
+ converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
+ f"head.head.{lora_down_key}.weight"
)
- converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
- f"blocks.{i}.cross_attn.{o}.lora_B.weight"
+ if any(f"head.head.{lora_up_key}.weight" in k for k in state_dict):
+ converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(
+ f"head.head.{lora_up_key}.weight"
)
+ if "head.head.diff_b" in original_state_dict:
+ converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b")
+
+ # Notes: https://huggingface.co/lightx2v/Wan2.2-Distill-Loras
+ # This is my (sayakpaul) assumption that this particular key belongs to the down matrix.
+ # Since for this particular LoRA, we don't have the corresponding up matrix, I will use
+ # an identity.
+ if any("head.head" in k and k.endswith(".diff") for k in state_dict):
+ if f"head.head.{lora_down_key}.weight" in state_dict:
+ logger.info(
+ f"The state dict seems to be have both `head.head.diff` and `head.head.{lora_down_key}.weight` keys, which is unexpected."
+ )
+ converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop("head.head.diff")
+ down_matrix_head = converted_state_dict["proj_out.lora_A.weight"]
+ up_matrix_shape = (down_matrix_head.shape[0], converted_state_dict["proj_out.lora_B.bias"].shape[0])
+ converted_state_dict["proj_out.lora_B.weight"] = torch.eye(
+ *up_matrix_shape, dtype=down_matrix_head.dtype, device=down_matrix_head.device
+ ).T
+
+ for text_time in ["text_embedding", "time_embedding"]:
+ if any(text_time in k for k in original_state_dict):
+ for b_n in [0, 2]:
+ diffusers_b_n = 1 if b_n == 0 else 2
+ diffusers_name = (
+ "condition_embedder.text_embedder"
+ if text_time == "text_embedding"
+ else "condition_embedder.time_embedder"
+ )
+ if any(f"{text_time}.{b_n}" in k for k in original_state_dict):
+ converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_A.weight"] = (
+ original_state_dict.pop(f"{text_time}.{b_n}.{lora_down_key}.weight")
+ )
+ converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.weight"] = (
+ original_state_dict.pop(f"{text_time}.{b_n}.{lora_up_key}.weight")
+ )
+ if f"{text_time}.{b_n}.diff_b" in original_state_dict:
+ converted_state_dict[f"{diffusers_name}.linear_{diffusers_b_n}.lora_B.bias"] = (
+ original_state_dict.pop(f"{text_time}.{b_n}.diff_b")
+ )
+
+ for img_ours, img_theirs in [
+ ("ff.net.0.proj", "img_emb.proj.1"),
+ ("ff.net.2", "img_emb.proj.3"),
+ ]:
+ original_key = f"{img_theirs}.{lora_down_key}.weight"
+ converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_A.weight"
+ if original_key in original_state_dict:
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+ original_key = f"{img_theirs}.{lora_up_key}.weight"
+ converted_key = f"condition_embedder.image_embedder.{img_ours}.lora_B.weight"
+ if original_key in original_state_dict:
+ converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+ bias_key_theirs = original_key.removesuffix(f".{lora_up_key}.weight") + ".diff_b"
+ if bias_key_theirs in original_state_dict:
+ bias_key = converted_key.removesuffix(".weight") + ".bias"
+ converted_state_dict[bias_key] = original_state_dict.pop(bias_key_theirs)
+
+ if len(original_state_dict) > 0:
+ diff = all(".diff" in k for k in original_state_dict)
+ if diff:
+ diff_keys = {k for k in original_state_dict if k.endswith(".diff")}
+ if not all("lora" not in k for k in diff_keys):
+ raise ValueError
+ logger.info(
+ "The remaining `state_dict` contains `diff` keys which we do not handle yet. If you see performance issues, please file an issue: "
+ "https://github.com/huggingface/diffusers//issues/new"
+ )
+ else:
+ raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
+
+ for key in list(converted_state_dict.keys()):
+ converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+ return converted_state_dict
+
+
+def _convert_musubi_wan_lora_to_diffusers(state_dict):
+ # https://github.com/kohya-ss/musubi-tuner
+ converted_state_dict = {}
+ original_state_dict = {k[len("lora_unet_") :]: v for k, v in state_dict.items()}
+
+ num_blocks = len({k.split("blocks_")[1].split("_")[0] for k in original_state_dict})
+ is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
+
+ def get_alpha_scales(down_weight, key):
+ rank = down_weight.shape[0]
+ alpha = original_state_dict.pop(key + ".alpha").item()
+ scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+ scale_down = scale
+ scale_up = 1.0
+ while scale_down * 2 < scale_up:
+ scale_down *= 2
+ scale_up /= 2
+ return scale_down, scale_up
+
+ for i in range(num_blocks):
+ # Self-attention
+ for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
+ down_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_down.weight")
+ up_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_up.weight")
+ scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_self_attn_{o}")
+ converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = down_weight * scale_down
+ converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = up_weight * scale_up
+
+ # Cross-attention
+ for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
+ down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
+ up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
+ scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
+ converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
+ converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
+
+ if is_i2v_lora:
+ for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+ down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
+ up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
+ scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
+ converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
+ converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
# FFN
- for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
- converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = original_state_dict.pop(
- f"blocks.{i}.{o}.lora_A.weight"
+ for o, c in zip(["ffn_0", "ffn_2"], ["net.0.proj", "net.2"]):
+ down_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_down.weight")
+ up_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_up.weight")
+ scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_{o}")
+ converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = down_weight * scale_down
+ converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = up_weight * scale_up
+
+ if len(original_state_dict) > 0:
+ raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
+
+ for key in list(converted_state_dict.keys()):
+ converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+ return converted_state_dict
+
+
+def _convert_non_diffusers_hidream_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
+ if not all(k.startswith(non_diffusers_prefix) for k in state_dict):
+ raise ValueError("Invalid LoRA state dict for HiDream.")
+ converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
+ converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
+ return converted_state_dict
+
+
+def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
+ if not all(k.startswith(f"{non_diffusers_prefix}.") for k in state_dict):
+ raise ValueError("Invalid LoRA state dict for LTX-Video.")
+ converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
+ converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
+ return converted_state_dict
+
+
+def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
+ has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
+ if has_diffusion_model:
+ state_dict = {k.removeprefix("diffusion_model."): v for k, v in state_dict.items()}
+
+ has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
+ if has_lora_unet:
+ state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
+
+ def convert_key(key: str) -> str:
+ prefix = "transformer_blocks"
+ if "." in key:
+ base, suffix = key.rsplit(".", 1)
+ else:
+ base, suffix = key, ""
+
+ start = f"{prefix}_"
+ rest = base[len(start) :]
+
+ if "." in rest:
+ head, tail = rest.split(".", 1)
+ tail = "." + tail
+ else:
+ head, tail = rest, ""
+
+ # Protected n-grams that must keep their internal underscores
+ protected = {
+ # pairs
+ ("to", "q"),
+ ("to", "k"),
+ ("to", "v"),
+ ("to", "out"),
+ ("add", "q"),
+ ("add", "k"),
+ ("add", "v"),
+ ("txt", "mlp"),
+ ("img", "mlp"),
+ ("txt", "mod"),
+ ("img", "mod"),
+ # triplets
+ ("add", "q", "proj"),
+ ("add", "k", "proj"),
+ ("add", "v", "proj"),
+ ("to", "add", "out"),
+ }
+
+ prot_by_len = {}
+ for ng in protected:
+ prot_by_len.setdefault(len(ng), set()).add(ng)
+
+ parts = head.split("_")
+ merged = []
+ i = 0
+ lengths_desc = sorted(prot_by_len.keys(), reverse=True)
+
+ while i < len(parts):
+ matched = False
+ for L in lengths_desc:
+ if i + L <= len(parts) and tuple(parts[i : i + L]) in prot_by_len[L]:
+ merged.append("_".join(parts[i : i + L]))
+ i += L
+ matched = True
+ break
+ if not matched:
+ merged.append(parts[i])
+ i += 1
+
+ head_converted = ".".join(merged)
+ converted_base = f"{prefix}.{head_converted}{tail}"
+ return converted_base + (("." + suffix) if suffix else "")
+
+ state_dict = {convert_key(k): v for k, v in state_dict.items()}
+
+ has_default = any("default." in k for k in state_dict)
+ if has_default:
+ state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}
+
+ converted_state_dict = {}
+ all_keys = list(state_dict.keys())
+ down_key = ".lora_down.weight"
+ up_key = ".lora_up.weight"
+ a_key = ".lora_A.weight"
+ b_key = ".lora_B.weight"
+
+ has_non_diffusers_lora_id = any(down_key in k or up_key in k for k in all_keys)
+ has_diffusers_lora_id = any(a_key in k or b_key in k for k in all_keys)
+
+ if has_non_diffusers_lora_id:
+
+ def get_alpha_scales(down_weight, alpha_key):
+ rank = down_weight.shape[0]
+ alpha = state_dict.pop(alpha_key).item()
+ scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+ scale_down = scale
+ scale_up = 1.0
+ while scale_down * 2 < scale_up:
+ scale_down *= 2
+ scale_up /= 2
+ return scale_down, scale_up
+
+ for k in all_keys:
+ if k.endswith(down_key):
+ diffusers_down_key = k.replace(down_key, ".lora_A.weight")
+ diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
+ alpha_key = k.replace(down_key, ".alpha")
+
+ down_weight = state_dict.pop(k)
+ up_weight = state_dict.pop(k.replace(down_key, up_key))
+ scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+ converted_state_dict[diffusers_down_key] = down_weight * scale_down
+ converted_state_dict[diffusers_up_key] = up_weight * scale_up
+
+ # Already in diffusers format (lora_A/lora_B), just pop
+ elif has_diffusers_lora_id:
+ for k in all_keys:
+ if a_key in k or b_key in k:
+ converted_state_dict[k] = state_dict.pop(k)
+ elif ".alpha" in k:
+ state_dict.pop(k)
+
+ if len(state_dict) > 0:
+ raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
+
+ converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
+ return converted_state_dict
+
+
+def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
+ converted_state_dict = {}
+
+ prefix = "diffusion_model."
+ original_state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
+
+ num_double_layers = 8
+ num_single_layers = 48
+ lora_keys = ("lora_A", "lora_B")
+ attn_types = ("img_attn", "txt_attn")
+
+ for sl in range(num_single_layers):
+ single_block_prefix = f"single_blocks.{sl}"
+ attn_prefix = f"single_transformer_blocks.{sl}.attn"
+
+ for lora_key in lora_keys:
+ converted_state_dict[f"{attn_prefix}.to_qkv_mlp_proj.{lora_key}.weight"] = original_state_dict.pop(
+ f"{single_block_prefix}.linear1.{lora_key}.weight"
)
- converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = original_state_dict.pop(
- f"blocks.{i}.{o}.lora_B.weight"
+
+ converted_state_dict[f"{attn_prefix}.to_out.{lora_key}.weight"] = original_state_dict.pop(
+ f"{single_block_prefix}.linear2.{lora_key}.weight"
)
+ for dl in range(num_double_layers):
+ transformer_block_prefix = f"transformer_blocks.{dl}"
+
+ for lora_key in lora_keys:
+ for attn_type in attn_types:
+ attn_prefix = f"{transformer_block_prefix}.attn"
+ qkv_key = f"double_blocks.{dl}.{attn_type}.qkv.{lora_key}.weight"
+ fused_qkv_weight = original_state_dict.pop(qkv_key)
+
+ if lora_key == "lora_A":
+ diff_attn_proj_keys = (
+ ["to_q", "to_k", "to_v"]
+ if attn_type == "img_attn"
+ else ["add_q_proj", "add_k_proj", "add_v_proj"]
+ )
+ for proj_key in diff_attn_proj_keys:
+ converted_state_dict[f"{attn_prefix}.{proj_key}.{lora_key}.weight"] = torch.cat(
+ [fused_qkv_weight]
+ )
+ else:
+ sample_q, sample_k, sample_v = torch.chunk(fused_qkv_weight, 3, dim=0)
+
+ if attn_type == "img_attn":
+ converted_state_dict[f"{attn_prefix}.to_q.{lora_key}.weight"] = torch.cat([sample_q])
+ converted_state_dict[f"{attn_prefix}.to_k.{lora_key}.weight"] = torch.cat([sample_k])
+ converted_state_dict[f"{attn_prefix}.to_v.{lora_key}.weight"] = torch.cat([sample_v])
+ else:
+ converted_state_dict[f"{attn_prefix}.add_q_proj.{lora_key}.weight"] = torch.cat([sample_q])
+ converted_state_dict[f"{attn_prefix}.add_k_proj.{lora_key}.weight"] = torch.cat([sample_k])
+ converted_state_dict[f"{attn_prefix}.add_v_proj.{lora_key}.weight"] = torch.cat([sample_v])
+
+ proj_mappings = [
+ ("img_attn.proj", "attn.to_out.0"),
+ ("txt_attn.proj", "attn.to_add_out"),
+ ]
+ for org_proj, diff_proj in proj_mappings:
+ for lora_key in lora_keys:
+ original_key = f"double_blocks.{dl}.{org_proj}.{lora_key}.weight"
+ diffusers_key = f"{transformer_block_prefix}.{diff_proj}.{lora_key}.weight"
+ converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
+
+ mlp_mappings = [
+ ("img_mlp.0", "ff.linear_in"),
+ ("img_mlp.2", "ff.linear_out"),
+ ("txt_mlp.0", "ff_context.linear_in"),
+ ("txt_mlp.2", "ff_context.linear_out"),
+ ]
+ for org_mlp, diff_mlp in mlp_mappings:
+ for lora_key in lora_keys:
+ original_key = f"double_blocks.{dl}.{org_mlp}.{lora_key}.weight"
+ diffusers_key = f"{transformer_block_prefix}.{diff_mlp}.{lora_key}.weight"
+ converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
+
if len(original_state_dict) > 0:
- raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
+ raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
for key in list(converted_state_dict.keys()):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
return converted_state_dict
+
+
+def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
+ """
+ Convert non-diffusers ZImage LoRA state dict to diffusers format.
+
+ Handles:
+ - `diffusion_model.` prefix removal
+ - `lora_unet_` prefix conversion with key mapping
+ - `default.` prefix removal
+ - `.lora_down.weight`/`.lora_up.weight` → `.lora_A.weight`/`.lora_B.weight` conversion with alpha scaling
+ """
+ has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
+ if has_diffusion_model:
+ state_dict = {k.removeprefix("diffusion_model."): v for k, v in state_dict.items()}
+
+ has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
+ if has_lora_unet:
+ state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
+
+ def convert_key(key: str) -> str:
+ # ZImage has: layers, noise_refiner, context_refiner blocks
+ # Keys may be like: layers_0_attention_to_q.lora_down.weight
+
+ if "." in key:
+ base, suffix = key.rsplit(".", 1)
+ else:
+ base, suffix = key, ""
+
+ # Protected n-grams that must keep their internal underscores
+ protected = {
+ # pairs for attention
+ ("to", "q"),
+ ("to", "k"),
+ ("to", "v"),
+ ("to", "out"),
+ # feed_forward
+ ("feed", "forward"),
+ }
+
+ prot_by_len = {}
+ for ng in protected:
+ prot_by_len.setdefault(len(ng), set()).add(ng)
+
+ parts = base.split("_")
+ merged = []
+ i = 0
+ lengths_desc = sorted(prot_by_len.keys(), reverse=True)
+
+ while i < len(parts):
+ matched = False
+ for L in lengths_desc:
+ if i + L <= len(parts) and tuple(parts[i : i + L]) in prot_by_len[L]:
+ merged.append("_".join(parts[i : i + L]))
+ i += L
+ matched = True
+ break
+ if not matched:
+ merged.append(parts[i])
+ i += 1
+
+ converted_base = ".".join(merged)
+ return converted_base + (("." + suffix) if suffix else "")
+
+ state_dict = {convert_key(k): v for k, v in state_dict.items()}
+
+ def normalize_out_key(k: str) -> str:
+ if ".to_out" in k:
+ return k
+ return re.sub(
+ r"\.out(?=\.(?:lora_down|lora_up)\.weight$|\.alpha$)",
+ ".to_out.0",
+ k,
+ )
+
+ state_dict = {normalize_out_key(k): v for k, v in state_dict.items()}
+
+ has_default = any("default." in k for k in state_dict)
+ if has_default:
+ state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}
+
+ converted_state_dict = {}
+ all_keys = list(state_dict.keys())
+ down_key = ".lora_down.weight"
+ up_key = ".lora_up.weight"
+ a_key = ".lora_A.weight"
+ b_key = ".lora_B.weight"
+
+ has_non_diffusers_lora_id = any(down_key in k or up_key in k for k in all_keys)
+ has_diffusers_lora_id = any(a_key in k or b_key in k for k in all_keys)
+
+ if has_non_diffusers_lora_id:
+
+ def get_alpha_scales(down_weight, alpha_key):
+ rank = down_weight.shape[0]
+ alpha = state_dict.pop(alpha_key).item()
+ scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+ scale_down = scale
+ scale_up = 1.0
+ while scale_down * 2 < scale_up:
+ scale_down *= 2
+ scale_up /= 2
+ return scale_down, scale_up
+
+ for k in all_keys:
+ if k.endswith(down_key):
+ diffusers_down_key = k.replace(down_key, ".lora_A.weight")
+ diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
+ alpha_key = k.replace(down_key, ".alpha")
+
+ down_weight = state_dict.pop(k)
+ up_weight = state_dict.pop(k.replace(down_key, up_key))
+ scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+ converted_state_dict[diffusers_down_key] = down_weight * scale_down
+ converted_state_dict[diffusers_up_key] = up_weight * scale_up
+
+ # Already in diffusers format (lora_A/lora_B), just pop
+ elif has_diffusers_lora_id:
+ for k in all_keys:
+ if a_key in k or b_key in k:
+ converted_state_dict[k] = state_dict.pop(k)
+ elif ".alpha" in k:
+ state_dict.pop(k)
+
+ if len(state_dict) > 0:
+ raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
+
+ converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
+ return converted_state_dict
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index e522778deeed..bcbe54649f89 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,6 +22,8 @@
USE_PEFT_BACKEND,
deprecate,
get_submodule_by_name,
+ is_bitsandbytes_available,
+ is_gguf_available,
is_peft_available,
is_peft_version,
is_torch_version,
@@ -35,14 +37,22 @@
LoraBaseMixin,
_fetch_state_dict,
_load_lora_into_text_encoder,
+ _pack_dict_with_prefix,
)
from .lora_conversion_utils import (
_convert_bfl_flux_control_lora_to_diffusers,
+ _convert_fal_kontext_lora_to_diffusers,
_convert_hunyuan_video_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers,
+ _convert_musubi_wan_lora_to_diffusers,
+ _convert_non_diffusers_flux2_lora_to_diffusers,
+ _convert_non_diffusers_hidream_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers,
+ _convert_non_diffusers_ltxv_lora_to_diffusers,
_convert_non_diffusers_lumina2_lora_to_diffusers,
+ _convert_non_diffusers_qwen_lora_to_diffusers,
_convert_non_diffusers_wan_lora_to_diffusers,
+ _convert_non_diffusers_z_image_lora_to_diffusers,
_convert_xlabs_flux_lora_to_diffusers,
_maybe_map_sgm_blocks_to_diffusers,
)
@@ -68,6 +78,55 @@
_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
+def _maybe_dequantize_weight_for_expanded_lora(model, module):
+ if is_bitsandbytes_available():
+ from ..quantizers.bitsandbytes import dequantize_bnb_weight
+
+ if is_gguf_available():
+ from ..quantizers.gguf.utils import dequantize_gguf_tensor
+
+ is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit"
+ is_bnb_8bit_quantized = module.weight.__class__.__name__ == "Int8Params"
+ is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter"
+
+ if is_bnb_4bit_quantized and not is_bitsandbytes_available():
+ raise ValueError(
+ "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints."
+ )
+ if is_bnb_8bit_quantized and not is_bitsandbytes_available():
+ raise ValueError(
+ "The checkpoint seems to have been quantized with `bitsandbytes` (8bits). Install `bitsandbytes` to load quantized checkpoints."
+ )
+ if is_gguf_quantized and not is_gguf_available():
+ raise ValueError(
+ "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints."
+ )
+
+ weight_on_cpu = False
+ if module.weight.device.type == "cpu":
+ weight_on_cpu = True
+
+ device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
+ if is_bnb_4bit_quantized or is_bnb_8bit_quantized:
+ module_weight = dequantize_bnb_weight(
+ module.weight.to(device) if weight_on_cpu else module.weight,
+ state=module.weight.quant_state if is_bnb_4bit_quantized else module.state,
+ dtype=model.dtype,
+ ).data
+ elif is_gguf_quantized:
+ module_weight = dequantize_gguf_tensor(
+ module.weight.to(device) if weight_on_cpu else module.weight,
+ )
+ module_weight = module_weight.to(model.dtype)
+ else:
+ module_weight = module.weight.data
+
+ if weight_on_cpu:
+ module_weight = module_weight.cpu()
+
+ return module_weight
+
+
class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
@@ -79,10 +138,13 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
text_encoder_name = TEXT_ENCODER_NAME
def load_lora_weights(
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
):
- """
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
+ """Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
`self.text_encoder`.
All kwargs are forwarded to `self.lora_state_dict`.
@@ -105,6 +167,29 @@ def load_lora_weights(
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
+ hotswap (`bool`, *optional*):
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+ to call an additional method before loading the adapter:
+
+ ```py
+ pipeline = ... # load diffusers pipeline
+ max_rank = ... # the highest rank among all LoRAs that you want to load
+ # call *before* compiling and loading the LoRA adapter
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
+ pipeline.load_lora_weights(file_name)
+ # optionally compile the model now
+ ```
+
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+ limitations to this technique, which are documented here:
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
@@ -122,7 +207,8 @@ def load_lora_weights(
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+ kwargs["return_lora_metadata"] = True
+ state_dict, network_alphas, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
@@ -133,8 +219,10 @@ def load_lora_weights(
network_alphas=network_alphas,
unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
self.load_lora_into_text_encoder(
state_dict,
@@ -145,7 +233,9 @@ def load_lora_weights(
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
+ metadata=metadata,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
@@ -158,13 +248,8 @@ def lora_state_dict(
r"""
Return state dict for lora weights and the network alphas.
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
+ > [!WARNING] > We support loading A1111 formatted LoRA checkpoints in a limited capacity. > > This function is
+ experimental and might change in the future.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
@@ -200,6 +285,8 @@ def lora_state_dict(
The subfolder location of a model file within a larger model repository on the Hub or locally.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
+ return_lora_metadata (`bool`, *optional*, defaults to False):
+ When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
"""
# Load the main state dict first which has the LoRA layers for either of
# UNet and text encoder or both.
@@ -213,18 +300,16 @@ def lora_state_dict(
weight_name = kwargs.pop("weight_name", None)
unet_config = kwargs.pop("unet_config", None)
use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
- state_dict = _fetch_state_dict(
+ state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -261,11 +346,20 @@ def lora_state_dict(
state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
- return state_dict, network_alphas
+ out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas)
+ return out
@classmethod
def load_lora_into_unet(
- cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ cls,
+ state_dict,
+ network_alphas,
+ unet,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -287,6 +381,11 @@ def load_lora_into_unet(
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading only loading the pretrained LoRA weights and not initializing the random
weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ metadata (`dict`):
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+ from the state dict.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -305,8 +404,10 @@ def load_lora_into_unet(
prefix=cls.unet_name,
network_alphas=network_alphas,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
@@ -320,6 +421,8 @@ def load_lora_into_text_encoder(
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -345,6 +448,11 @@ def load_lora_into_text_encoder(
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ metadata (`dict`):
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+ from the state dict.
"""
_load_lora_into_text_encoder(
state_dict=state_dict,
@@ -354,8 +462,10 @@ def load_lora_into_text_encoder(
prefix=prefix,
text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
@@ -368,6 +478,8 @@ def save_lora_weights(
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
+ unet_lora_adapter_metadata=None,
+ text_encoder_lora_adapter_metadata=None,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -390,22 +502,29 @@ def save_lora_weights(
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ unet_lora_adapter_metadata:
+ LoRA adapter metadata associated with the unet to be serialized with the state dict.
+ text_encoder_lora_adapter_metadata:
+ LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
"""
- state_dict = {}
-
- if not (unet_lora_layers or text_encoder_lora_layers):
- raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")
+ lora_layers = {}
+ lora_metadata = {}
if unet_lora_layers:
- state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
+ lora_layers[cls.unet_name] = unet_lora_layers
+ lora_metadata[cls.unet_name] = unet_lora_adapter_metadata
if text_encoder_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+ lora_layers[cls.text_encoder_name] = text_encoder_lora_layers
+ lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `unet_lora_layers` or `text_encoder_lora_layers`.")
+
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
@@ -423,11 +542,7 @@ def fuse_lora(
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
- This is an experimental API.
-
-
+ > [!WARNING] > This is an experimental API.
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
@@ -464,11 +579,7 @@ def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
- This is an experimental API.
-
-
+ > [!WARNING] > This is an experimental API.
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
@@ -495,34 +606,11 @@ def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
+ hotswap: bool = False,
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
- `self.text_encoder`.
-
- All kwargs are forwarded to `self.lora_state_dict`.
-
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
- loaded.
-
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
- loaded into `self.unet`.
-
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
- dict is loaded into `self.text_encoder`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -542,7 +630,8 @@ def load_lora_weights(
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict, network_alphas = self.lora_state_dict(
+ kwargs["return_lora_metadata"] = True
+ state_dict, network_alphas, metadata = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
**kwargs,
@@ -557,8 +646,10 @@ def load_lora_weights(
network_alphas=network_alphas,
unet=self.unet,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
self.load_lora_into_text_encoder(
state_dict,
@@ -567,8 +658,10 @@ def load_lora_weights(
prefix=self.text_encoder_name,
lora_scale=self.lora_scale,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
self.load_lora_into_text_encoder(
state_dict,
@@ -577,8 +670,10 @@ def load_lora_weights(
prefix=f"{self.text_encoder_name}_2",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
@@ -592,13 +687,8 @@ def lora_state_dict(
r"""
Return state dict for lora weights and the network alphas.
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
+ > [!WARNING] > We support loading A1111 formatted LoRA checkpoints in a limited capacity. > > This function is
+ experimental and might change in the future.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
@@ -634,6 +724,8 @@ def lora_state_dict(
The subfolder location of a model file within a larger model repository on the Hub or locally.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
+ return_lora_metadata (`bool`, *optional*, defaults to False):
+ When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
"""
# Load the main state dict first which has the LoRA layers for either of
# UNet and text encoder or both.
@@ -647,18 +739,16 @@ def lora_state_dict(
weight_name = kwargs.pop("weight_name", None)
unet_config = kwargs.pop("unet_config", None)
use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
- state_dict = _fetch_state_dict(
+ state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -695,12 +785,21 @@ def lora_state_dict(
state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
- return state_dict, network_alphas
+ out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas)
+ return out
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
def load_lora_into_unet(
- cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ cls,
+ state_dict,
+ network_alphas,
+ unet,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -722,6 +821,11 @@ def load_lora_into_unet(
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading only loading the pretrained LoRA weights and not initializing the random
weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ metadata (`dict`):
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+ from the state dict.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -740,8 +844,10 @@ def load_lora_into_unet(
prefix=cls.unet_name,
network_alphas=network_alphas,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
@@ -756,6 +862,8 @@ def load_lora_into_text_encoder(
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -781,6 +889,11 @@ def load_lora_into_text_encoder(
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ metadata (`dict`):
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+ from the state dict.
"""
_load_lora_into_text_encoder(
state_dict=state_dict,
@@ -790,8 +903,10 @@ def load_lora_into_text_encoder(
prefix=prefix,
text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
@@ -805,51 +920,37 @@ def save_lora_weights(
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
+ unet_lora_adapter_metadata=None,
+ text_encoder_lora_adapter_metadata=None,
+ text_encoder_2_lora_adapter_metadata=None,
):
r"""
- Save the LoRA parameters corresponding to the UNet and text encoder.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `unet`.
- text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
- encoder LoRA state dict because it comes from 🤗 Transformers.
- text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
- encoder LoRA state dict because it comes from 🤗 Transformers.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
-
- if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
- raise ValueError(
- "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
- )
+ lora_layers = {}
+ lora_metadata = {}
if unet_lora_layers:
- state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
+ lora_layers[cls.unet_name] = unet_lora_layers
+ lora_metadata[cls.unet_name] = unet_lora_adapter_metadata
if text_encoder_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
+ lora_layers["text_encoder"] = text_encoder_lora_layers
+ lora_metadata["text_encoder"] = text_encoder_lora_adapter_metadata
if text_encoder_2_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+ lora_layers["text_encoder_2"] = text_encoder_2_lora_layers
+ lora_metadata["text_encoder_2"] = text_encoder_2_lora_adapter_metadata
- cls.write_lora_layers(
- state_dict=state_dict,
+ if not lora_layers:
+ raise ValueError(
+ "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, or `text_encoder_2_lora_layers`."
+ )
+
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
@@ -865,35 +966,7 @@ def fuse_lora(
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -905,21 +978,7 @@ def fuse_lora(
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
- unfuse_text_encoder (`bool`, defaults to `True`):
- Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
- LoRA parameters then it won't have any effect.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -945,49 +1004,7 @@ def lora_state_dict(
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
-
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -1000,18 +1017,16 @@ def lora_state_dict(
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
- state_dict = _fetch_state_dict(
+ state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -1032,34 +1047,18 @@ def lora_state_dict(
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
- return state_dict
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
+ return out
def load_lora_weights(
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name=None,
+ hotswap: bool = False,
+ **kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
- `self.text_encoder`.
-
- All kwargs are forwarded to `self.lora_state_dict`.
-
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
- loaded.
-
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1075,7 +1074,8 @@ def load_lora_weights(
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+ kwargs["return_lora_metadata"] = True
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
@@ -1085,8 +1085,10 @@ def load_lora_weights(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
self.load_lora_into_text_encoder(
state_dict,
@@ -1095,8 +1097,10 @@ def load_lora_weights(
prefix=self.text_encoder_name,
lora_scale=self.lora_scale,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
self.load_lora_into_text_encoder(
state_dict,
@@ -1105,30 +1109,25 @@ def load_lora_weights(
prefix=f"{self.text_encoder_name}_2",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
def load_lora_into_transformer(
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ cls,
+ state_dict,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`SD3Transformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -1141,8 +1140,10 @@ def load_lora_into_transformer(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
@@ -1157,6 +1158,8 @@ def load_lora_into_text_encoder(
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1182,6 +1185,11 @@ def load_lora_into_text_encoder(
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ metadata (`dict`):
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+ from the state dict.
"""
_load_lora_into_text_encoder(
state_dict=state_dict,
@@ -1191,8 +1199,10 @@ def load_lora_into_text_encoder(
prefix=prefix,
text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
@@ -1207,51 +1217,37 @@ def save_lora_weights(
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
+ transformer_lora_adapter_metadata=None,
+ text_encoder_lora_adapter_metadata=None,
+ text_encoder_2_lora_adapter_metadata=None,
):
r"""
- Save the LoRA parameters corresponding to the UNet and text encoder.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
- encoder LoRA state dict because it comes from 🤗 Transformers.
- text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
- encoder LoRA state dict because it comes from 🤗 Transformers.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
-
- if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
- raise ValueError(
- "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
- )
+ lora_layers = {}
+ lora_metadata = {}
if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if text_encoder_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
+ lora_layers["text_encoder"] = text_encoder_lora_layers
+ lora_metadata["text_encoder"] = text_encoder_lora_adapter_metadata
if text_encoder_2_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+ lora_layers["text_encoder_2"] = text_encoder_2_lora_layers
+ lora_metadata["text_encoder_2"] = text_encoder_2_lora_adapter_metadata
- cls.write_lora_layers(
- state_dict=state_dict,
+ if not lora_layers:
+ raise ValueError(
+ "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, or `text_encoder_2_lora_layers`."
+ )
+
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
@@ -1268,35 +1264,207 @@ def fuse_lora(
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
-
+ # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
+ """
+ super().unfuse_lora(components=components, **kwargs)
- This is an experimental API.
-
+class AuraFlowLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`AuraFlowTransformer2DModel`] Specific to [`AuraFlowPipeline`].
+ """
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
- Example:
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
- ```py
- from diffusers import DiffusionPipeline
- import torch
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
+
+ state_dict, metadata = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
+ return out
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ kwargs["return_lora_metadata"] = True
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->AuraFlowTransformer2DModel
+ def load_lora_into_transformer(
+ cls,
+ state_dict,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
+ ):
+ """
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ transformer_lora_adapter_metadata: Optional[dict] = None,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
+ """
+ lora_layers = {}
+ lora_metadata = {}
+
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
+
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
+
+ cls._save_lora_weights(
+ save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -1306,24 +1474,10 @@ def fuse_lora(
**kwargs,
)
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
- def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
- unfuse_text_encoder (`bool`, defaults to `True`):
- Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
- LoRA parameters then it won't have any effect.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -1350,49 +1504,7 @@ def lora_state_dict(
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
-
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -1405,18 +1517,16 @@ def lora_state_dict(
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
- state_dict = _fetch_state_dict(
+ state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -1441,18 +1551,47 @@ def lora_state_dict(
if is_kohya:
state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
# Kohya already takes care of scaling the LoRA parameters with alpha.
- return (state_dict, None) if return_alphas else state_dict
+ return cls._prepare_outputs(
+ state_dict,
+ metadata=metadata,
+ alphas=None,
+ return_alphas=return_alphas,
+ return_metadata=return_lora_metadata,
+ )
is_xlabs = any("processor" in k for k in state_dict)
if is_xlabs:
state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
# xlabs doesn't use `alpha`.
- return (state_dict, None) if return_alphas else state_dict
+ return cls._prepare_outputs(
+ state_dict,
+ metadata=metadata,
+ alphas=None,
+ return_alphas=return_alphas,
+ return_metadata=return_lora_metadata,
+ )
is_bfl_control = any("query_norm.scale" in k for k in state_dict)
if is_bfl_control:
state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
- return (state_dict, None) if return_alphas else state_dict
+ return cls._prepare_outputs(
+ state_dict,
+ metadata=metadata,
+ alphas=None,
+ return_alphas=return_alphas,
+ return_metadata=return_lora_metadata,
+ )
+
+ is_fal_kontext = any("base_model" in k for k in state_dict)
+ if is_fal_kontext:
+ state_dict = _convert_fal_kontext_lora_to_diffusers(state_dict)
+ return cls._prepare_outputs(
+ state_dict,
+ metadata=metadata,
+ alphas=None,
+ return_alphas=return_alphas,
+ return_metadata=return_lora_metadata,
+ )
# For state dicts like
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
@@ -1470,13 +1609,23 @@ def lora_state_dict(
f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue."
)
- if return_alphas:
- return state_dict, network_alphas
+ if return_alphas or return_lora_metadata:
+ return cls._prepare_outputs(
+ state_dict,
+ metadata=metadata,
+ alphas=network_alphas,
+ return_alphas=return_alphas,
+ return_metadata=return_lora_metadata,
+ )
else:
return state_dict
def load_lora_weights(
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
@@ -1493,14 +1642,16 @@ def load_lora_weights(
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
`Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1516,7 +1667,8 @@ def load_lora_weights(
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict, network_alphas = self.lora_state_dict(
+ kwargs["return_lora_metadata"] = True
+ state_dict, network_alphas, metadata = self.lora_state_dict(
pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
)
@@ -1567,8 +1719,10 @@ def load_lora_weights(
network_alphas=network_alphas,
transformer=transformer,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
if len(transformer_norm_state_dict) > 0:
@@ -1585,34 +1739,26 @@ def load_lora_weights(
prefix=self.text_encoder_name,
lora_scale=self.lora_scale,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
def load_lora_into_transformer(
- cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ cls,
+ state_dict,
+ network_alphas,
+ transformer,
+ adapter_name=None,
+ metadata=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- network_alphas (`Dict[str, float]`):
- The value of the network alpha used for stable learning and preventing underflow. This value has the
- same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
- link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
- transformer (`FluxTransformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
@@ -1625,8 +1771,10 @@ def load_lora_into_transformer(
state_dict,
network_alphas=network_alphas,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
@@ -1641,7 +1789,7 @@ def _load_norm_into_transformer(
prefix = prefix or cls.transformer_name
for key in list(state_dict.keys()):
if key.split(".")[0] == prefix:
- state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+ state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)
# Find invalid keys
transformer_state_dict = transformer.state_dict()
@@ -1695,6 +1843,8 @@ def load_lora_into_text_encoder(
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1720,6 +1870,11 @@ def load_lora_into_text_encoder(
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ metadata (`dict`):
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+ from the state dict.
"""
_load_lora_into_text_encoder(
state_dict=state_dict,
@@ -1729,8 +1884,10 @@ def load_lora_into_text_encoder(
prefix=prefix,
text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
@@ -1744,6 +1901,8 @@ def save_lora_weights(
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
+ transformer_lora_adapter_metadata=None,
+ text_encoder_lora_adapter_metadata=None,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -1766,22 +1925,29 @@ def save_lora_weights(
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ transformer_lora_adapter_metadata:
+ LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ text_encoder_lora_adapter_metadata:
+ LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
"""
- state_dict = {}
-
- if not (transformer_lora_layers or text_encoder_lora_layers):
- raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
+ lora_layers = {}
+ lora_metadata = {}
if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if text_encoder_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+ lora_layers[cls.text_encoder_name] = text_encoder_lora_layers
+ lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
+
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
@@ -1797,35 +1963,7 @@ def fuse_lora(
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
@@ -1853,11 +1991,7 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
- This is an experimental API.
-
-
+ > [!WARNING] > This is an experimental API.
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
@@ -1947,7 +2081,7 @@ def _maybe_expand_transformer_param_shape_or_error_(
) -> bool:
"""
Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and
- generalizes things a bit so that any parameter that needs expansion receives appropriate treatement.
+ generalizes things a bit so that any parameter that needs expansion receives appropriate treatment.
"""
state_dict = {}
if lora_state_dict is not None:
@@ -1959,13 +2093,14 @@ def _maybe_expand_transformer_param_shape_or_error_(
prefix = prefix or cls.transformer_name
for key in list(state_dict.keys()):
if key.split(".")[0] == prefix:
- state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
+ state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)
# Expand transformer parameter shapes if they don't match lora
has_param_with_shape_update = False
overwritten_params = {}
is_peft_loaded = getattr(transformer, "peft_config", None) is not None
+ is_quantized = hasattr(transformer, "hf_quantizer")
for name, module in transformer.named_modules():
if isinstance(module, torch.nn.Linear):
module_weight = module.weight.data
@@ -1990,9 +2125,7 @@ def _maybe_expand_transformer_param_shape_or_error_(
if tuple(module_weight_shape) == (out_features, in_features):
continue
- # TODO (sayakpaul): We still need to consider if the module we're expanding is
- # quantized and handle it accordingly if that is the case.
- module_out_features, module_in_features = module_weight.shape
+ module_out_features, module_in_features = module_weight_shape
debug_message = ""
if in_features > module_in_features:
debug_message += (
@@ -2015,6 +2148,10 @@ def _maybe_expand_transformer_param_shape_or_error_(
parent_module_name, _, current_module_name = name.rpartition(".")
parent_module = transformer.get_submodule(parent_module_name)
+ if is_quantized:
+ module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module)
+
+ # TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True.
with torch.device("meta"):
expanded_module = torch.nn.Linear(
in_features, out_features, bias=bias, dtype=module_weight.dtype
@@ -2026,7 +2163,7 @@ def _maybe_expand_transformer_param_shape_or_error_(
new_weight = torch.zeros_like(
expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
)
- slices = tuple(slice(0, dim) for dim in module_weight.shape)
+ slices = tuple(slice(0, dim) for dim in module_weight_shape)
new_weight[slices] = module_weight
tmp_state_dict = {"weight": new_weight}
if module_bias is not None:
@@ -2074,14 +2211,13 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
if unexpected_modules:
logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.")
- is_peft_loaded = getattr(transformer, "peft_config", None) is not None
for k in lora_module_names:
if k in unexpected_modules:
continue
base_param_name = (
f"{k.replace(prefix, '')}.base_layer.weight"
- if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
+ if f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
else f"{k.replace(prefix, '')}.weight"
)
base_weight_param = transformer_state_dict[base_param_name]
@@ -2115,7 +2251,12 @@ def _calculate_module_shape(
base_weight_param_name: str = None,
) -> "torch.Size":
def _get_weight_shape(weight: torch.Tensor):
- return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape
+ if weight.__class__.__name__ == "Params4bit":
+ return weight.quant_state.shape
+ elif weight.__class__.__name__ == "GGUFParameter":
+ return weight.quant_shape
+ else:
+ return weight.shape
if base_module is not None:
return _get_weight_shape(base_module.weight)
@@ -2130,6 +2271,15 @@ def _get_weight_shape(weight: torch.Tensor):
raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
+ @staticmethod
+ def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, return_metadata=False):
+ outputs = [state_dict]
+ if return_alphas:
+ outputs.append(alphas)
+ if return_metadata:
+ outputs.append(metadata)
+ return tuple(outputs) if (return_alphas or return_metadata) else state_dict
+
# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
@@ -2141,28 +2291,18 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
@classmethod
# Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
def load_lora_into_transformer(
- cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ cls,
+ state_dict,
+ network_alphas,
+ transformer,
+ adapter_name=None,
+ metadata=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- network_alphas (`Dict[str, float]`):
- The value of the network alpha used for stable learning and preventing underflow. This value has the
- same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
- link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
- transformer (`UVit2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
@@ -2175,8 +2315,10 @@ def load_lora_into_transformer(
state_dict,
network_alphas=network_alphas,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
@@ -2191,6 +2333,8 @@ def load_lora_into_text_encoder(
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -2216,6 +2360,11 @@ def load_lora_into_text_encoder(
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ metadata (`dict`):
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+ from the state dict.
"""
_load_lora_into_text_encoder(
state_dict=state_dict,
@@ -2225,8 +2374,10 @@ def load_lora_into_text_encoder(
prefix=prefix,
text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
@@ -2301,49 +2452,7 @@ def lora_state_dict(
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
-
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -2356,18 +2465,16 @@ def lora_state_dict(
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
- state_dict = _fetch_state_dict(
+ state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -2388,29 +2495,18 @@ def lora_state_dict(
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
- return state_dict
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
+ return out
def load_lora_weights(
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -2426,7 +2522,8 @@ def load_lora_weights(
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+ kwargs["return_lora_metadata"] = True
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
@@ -2436,31 +2533,26 @@ def load_lora_weights(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
def load_lora_into_transformer(
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ cls,
+ state_dict,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`CogVideoXTransformer3DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -2473,12 +2565,13 @@ def load_lora_into_transformer(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
- # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
@@ -2487,38 +2580,25 @@ def save_lora_weights(
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
+ transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the UNet and text encoder.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
-
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ lora_layers = {}
+ lora_metadata = {}
if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
+
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
@@ -2534,35 +2614,7 @@ def fuse_lora(
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -2574,18 +2626,7 @@ def fuse_lora(
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -2607,117 +2648,62 @@ def lora_state_dict(
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
-
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
- This function is experimental and might change in the future.
+ state_dict, metadata = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
-
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
-
- """
- # Load the main state dict first which has the LoRA layers for either of
- # transformer and text encoder or both.
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- local_files_only = kwargs.pop("local_files_only", None)
- token = kwargs.pop("token", None)
- revision = kwargs.pop("revision", None)
- subfolder = kwargs.pop("subfolder", None)
- weight_name = kwargs.pop("weight_name", None)
- use_safetensors = kwargs.pop("use_safetensors", None)
-
- allow_pickle = False
- if use_safetensors is None:
- use_safetensors = True
- allow_pickle = True
-
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
-
- state_dict = _fetch_state_dict(
- pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
- weight_name=weight_name,
- use_safetensors=use_safetensors,
- local_files_only=local_files_only,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- token=token,
- revision=revision,
- subfolder=subfolder,
- user_agent=user_agent,
- allow_pickle=allow_pickle,
- )
-
- is_dora_scale_present = any("dora_scale" in k for k in state_dict)
- if is_dora_scale_present:
- warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
- logger.warning(warn_msg)
- state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
-
- return state_dict
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
+ return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -2733,7 +2719,8 @@ def load_lora_weights(
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+ kwargs["return_lora_metadata"] = True
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
@@ -2743,31 +2730,26 @@ def load_lora_weights(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel
def load_lora_into_transformer(
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ cls,
+ state_dict,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`MochiTransformer3DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -2780,8 +2762,10 @@ def load_lora_into_transformer(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
@@ -2794,38 +2778,25 @@ def save_lora_weights(
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
+ transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the UNet and text encoder.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
-
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ lora_layers = {}
+ lora_metadata = {}
if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
+
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
@@ -2842,35 +2813,7 @@ def fuse_lora(
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -2883,18 +2826,7 @@ def fuse_lora(
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -2909,56 +2841,13 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
@classmethod
@validate_hf_hub_args
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
-
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -2971,18 +2860,16 @@ def lora_state_dict(
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
- state_dict = _fetch_state_dict(
+ state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -3003,30 +2890,23 @@ def lora_state_dict(
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
- return state_dict
+ is_non_diffusers_format = any(k.startswith("diffusion_model.") for k in state_dict)
+ if is_non_diffusers_format:
+ state_dict = _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict)
+
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
+ return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -3042,7 +2922,8 @@ def load_lora_weights(
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+ kwargs["return_lora_metadata"] = True
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
@@ -3052,31 +2933,26 @@ def load_lora_weights(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel
def load_lora_into_transformer(
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ cls,
+ state_dict,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`LTXVideoTransformer3DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -3089,8 +2965,10 @@ def load_lora_into_transformer(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
@@ -3103,38 +2981,25 @@ def save_lora_weights(
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
+ transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the UNet and text encoder.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- """
- state_dict = {}
-
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
+ """
+ lora_layers = {}
+ lora_metadata = {}
if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
+
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
@@ -3151,35 +3016,207 @@ def fuse_lora(
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
-
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
+ """
+ super().unfuse_lora(components=components, **kwargs)
- This is an experimental API.
-
+class SanaLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].
+ """
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
- Example:
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
- ```py
- from diffusers import DiffusionPipeline
- import torch
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
+
+ state_dict, metadata = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
+ return out
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ kwargs["return_lora_metadata"] = True
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
+ def load_lora_into_transformer(
+ cls,
+ state_dict,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
+ ):
+ """
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ transformer_lora_adapter_metadata: Optional[dict] = None,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
+ """
+ lora_layers = {}
+ lora_metadata = {}
+
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
+
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
+
+ cls._save_lora_weights(
+ save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -3192,25 +3229,1257 @@ def fuse_lora(
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
-
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
- This is an experimental API.
+ state_dict, metadata = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
-
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict)
+ if is_original_hunyuan_video:
+ state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
+
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
+ return out
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ kwargs["return_lora_metadata"] = True
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
+ def load_lora_into_transformer(
+ cls,
+ state_dict,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
+ ):
+ """
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ transformer_lora_adapter_metadata: Optional[dict] = None,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
+ """
+ lora_layers = {}
+ lora_metadata = {}
+
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
+
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
+
+ cls._save_lora_weights(
+ save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class Lumina2LoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`Lumina2Transformer2DModel`]. Specific to [`Lumina2Text2ImgPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
+
+ state_dict, metadata = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ # conversion.
+ non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
+ if non_diffusers:
+ state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)
+
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
+ return out
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ kwargs["return_lora_metadata"] = True
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel
+ def load_lora_into_transformer(
+ cls,
+ state_dict,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
+ ):
+ """
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ transformer_lora_adapter_metadata: Optional[dict] = None,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
+ """
+ lora_layers = {}
+ lora_metadata = {}
+
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
+
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
+
+ cls._save_lora_weights(
+ save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class KandinskyLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`Kandinsky5Transformer3DModel`],
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+ - A string, the *model id* of a pretrained model hosted on the Hub.
+ - A path to a *directory* containing the model weights.
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository.
+ weight_name (`str`, *optional*, defaults to None):
+ Name of the serialized state dict file.
+ use_safetensors (`bool`, *optional*):
+ Whether to use safetensors for loading.
+ return_lora_metadata (`bool`, *optional*, defaults to False):
+ When enabled, additionally return the LoRA adapter metadata.
+ """
+ # Load the main state dict first which has the LoRA layers
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
+
+ state_dict, metadata = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
+ return out
+
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model.
+ hotswap (`bool`, *optional*):
+ Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ kwargs (`dict`, *optional*):
+ See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ kwargs["return_lora_metadata"] = True
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ # Load LoRA into transformer
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ def load_lora_into_transformer(
+ cls,
+ state_dict,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
+ ):
+ """
+ Load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters.
+ transformer (`Kandinsky5Transformer3DModel`):
+ The transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`].
+ metadata (`dict`):
+ Optional LoRA adapter metadata.
+ """
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ transformer_lora_adapter_metadata=None,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer and text encoders.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process.
+ save_function (`Callable`):
+ The function to use to save the state dictionary.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way.
+ transformer_lora_adapter_metadata:
+ LoRA adapter metadata associated with the transformer.
+ """
+ lora_layers = {}
+ lora_metadata = {}
+
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
+
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers`")
+
+ cls._save_lora_weights(
+ save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ Args:
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing.
+
+ Example:
+ ```py
+ from diffusers import Kandinsky5T2VPipeline
+
+ pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
+ pipeline.load_lora_weights("path/to/lora.safetensors")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ Reverses the effect of [`pipe.fuse_lora()`].
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class WanLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
+ """
+
+ _lora_loadable_modules = ["transformer", "transformer_2"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
+
+ state_dict, metadata = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+ if any(k.startswith("diffusion_model.") for k in state_dict):
+ state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
+ elif any(k.startswith("lora_unet_") for k in state_dict):
+ state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict)
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
+ return out
+
+ @classmethod
+ def _maybe_expand_t2v_lora_for_i2v(
+ cls,
+ transformer: torch.nn.Module,
+ state_dict,
+ ):
+ if transformer.config.image_dim is None:
+ return state_dict
+
+ target_device = transformer.device
+
+ if any(k.startswith("transformer.blocks.") for k in state_dict):
+ num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
+ is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
+ has_bias = any(".lora_B.bias" in k for k in state_dict)
+
+ if is_i2v_lora:
+ return state_dict
+
+ for i in range(num_blocks):
+ for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+ # These keys should exist if the block `i` was part of the T2V LoRA.
+ ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
+ ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
+
+ if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict:
+ continue
+
+ state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
+ state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
+ )
+ state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
+ state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
+ )
+
+ # If the original LoRA had biases (indicated by has_bias)
+ # AND the specific reference bias key exists for this block.
+
+ ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias"
+ if has_bias and ref_key_lora_B_bias in state_dict:
+ ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias]
+ state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like(
+ ref_lora_B_bias_tensor,
+ device=target_device,
+ )
+
+ return state_dict
+
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ kwargs["return_lora_metadata"] = True
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+ # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
+ state_dict = self._maybe_expand_t2v_lora_for_i2v(
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ state_dict=state_dict,
+ )
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
+ if load_into_transformer_2:
+ if not hasattr(self, "transformer_2"):
+ raise AttributeError(
+ f"'{type(self).__name__}' object has no attribute transformer_2"
+ "Note that Wan2.1 models do not have a transformer_2 component."
+ "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
+ )
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=self.transformer_2,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+ else:
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name)
+ if not hasattr(self, "transformer")
+ else self.transformer,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
+ def load_lora_into_transformer(
+ cls,
+ state_dict,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
+ ):
+ """
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ transformer_lora_adapter_metadata: Optional[dict] = None,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
+ """
+ lora_layers = {}
+ lora_metadata = {}
+
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
+
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
+
+ cls._save_lora_weights(
+ save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
+class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`SkyReelsV2Transformer3DModel`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
+
+ state_dict, metadata = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+ if any(k.startswith("diffusion_model.") for k in state_dict):
+ state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
+ elif any(k.startswith("lora_unet_") for k in state_dict):
+ state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict)
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
+ return out
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin._maybe_expand_t2v_lora_for_i2v
+ def _maybe_expand_t2v_lora_for_i2v(
+ cls,
+ transformer: torch.nn.Module,
+ state_dict,
+ ):
+ if transformer.config.image_dim is None:
+ return state_dict
+
+ target_device = transformer.device
+
+ if any(k.startswith("transformer.blocks.") for k in state_dict):
+ num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
+ is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
+ has_bias = any(".lora_B.bias" in k for k in state_dict)
+
+ if is_i2v_lora:
+ return state_dict
+
+ for i in range(num_blocks):
+ for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+ # These keys should exist if the block `i` was part of the T2V LoRA.
+ ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
+ ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
+
+ if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict:
+ continue
+
+ state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
+ state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
+ )
+ state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
+ state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
+ )
+
+ # If the original LoRA had biases (indicated by has_bias)
+ # AND the specific reference bias key exists for this block.
+
+ ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias"
+ if has_bias and ref_key_lora_B_bias in state_dict:
+ ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias]
+ state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like(
+ ref_lora_B_bias_tensor,
+ device=target_device,
+ )
+
+ return state_dict
+
+ # Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ kwargs["return_lora_metadata"] = True
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+ # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
+ state_dict = self._maybe_expand_t2v_lora_for_i2v(
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ state_dict=state_dict,
+ )
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
+ if load_into_transformer_2:
+ if not hasattr(self, "transformer_2"):
+ raise AttributeError(
+ f"'{type(self).__name__}' object has no attribute transformer_2"
+ "Note that Wan2.1 models do not have a transformer_2 component."
+ "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
+ )
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=self.transformer_2,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+ else:
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name)
+ if not hasattr(self, "transformer")
+ else self.transformer,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SkyReelsV2Transformer3DModel
+ def load_lora_into_transformer(
+ cls,
+ state_dict,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
+ ):
+ """
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ transformer_lora_adapter_metadata: Optional[dict] = None,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
+ """
+ lora_layers = {}
+ lora_metadata = {}
+
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
+
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
+
+ cls._save_lora_weights(
+ save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
-class SanaLoraLoaderMixin(LoraBaseMixin):
+class CogView4LoraLoaderMixin(LoraBaseMixin):
r"""
- Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].
+ Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`CogView4Pipeline`].
"""
_lora_loadable_modules = ["transformer"]
@@ -3218,56 +4487,14 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
@classmethod
@validate_hf_hub_args
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
-
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -3280,18 +4507,16 @@ def lora_state_dict(
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
- state_dict = _fetch_state_dict(
+ state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -3312,30 +4537,19 @@ def lora_state_dict(
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
- return state_dict
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
+ return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -3351,7 +4565,8 @@ def load_lora_weights(
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+ kwargs["return_lora_metadata"] = True
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
@@ -3361,31 +4576,26 @@ def load_lora_weights(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
def load_lora_into_transformer(
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ cls,
+ state_dict,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`SanaTransformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -3398,8 +4608,10 @@ def load_lora_into_transformer(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
@@ -3412,38 +4624,25 @@ def save_lora_weights(
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
+ transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the UNet and text encoder.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
-
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ lora_layers = {}
+ lora_metadata = {}
if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
+
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
@@ -3460,35 +4659,7 @@ def fuse_lora(
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -3501,25 +4672,14 @@ def fuse_lora(
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
-class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
+class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
r"""
- Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
+ Load LoRA layers into [`HiDreamImageTransformer2DModel`]. Specific to [`HiDreamImagePipeline`].
"""
_lora_loadable_modules = ["transformer"]
@@ -3533,49 +4693,7 @@ def lora_state_dict(
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading original format HunyuanVideo LoRA checkpoints.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
-
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -3588,18 +4706,16 @@ def lora_state_dict(
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
- state_dict = _fetch_state_dict(
+ state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -3620,34 +4736,23 @@ def lora_state_dict(
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
- is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict)
- if is_original_hunyuan_video:
- state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
+ is_non_diffusers_format = any("diffusion_model" in k for k in state_dict)
+ if is_non_diffusers_format:
+ state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict)
- return state_dict
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
+ return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -3663,7 +4768,8 @@ def load_lora_weights(
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+ kwargs["return_lora_metadata"] = True
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
@@ -3673,31 +4779,26 @@ def load_lora_weights(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel
def load_lora_into_transformer(
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ cls,
+ state_dict,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`HunyuanVideoTransformer3DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -3710,8 +4811,10 @@ def load_lora_into_transformer(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
@@ -3724,45 +4827,32 @@ def save_lora_weights(
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
+ transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the UNet and text encoder.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
-
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ lora_layers = {}
+ lora_metadata = {}
if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
+
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: List[str] = ["transformer"],
@@ -3772,35 +4862,7 @@ def fuse_lora(
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -3810,28 +4872,17 @@ def fuse_lora(
**kwargs,
)
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+ # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
-class Lumina2LoraLoaderMixin(LoraBaseMixin):
+class QwenImageLoraLoaderMixin(LoraBaseMixin):
r"""
- Load LoRA layers into [`Lumina2Transformer2DModel`]. Specific to [`Lumina2Text2ImgPipeline`].
+ Load LoRA layers into [`QwenImageTransformer2DModel`]. Specific to [`QwenImagePipeline`].
"""
_lora_loadable_modules = ["transformer"]
@@ -3845,49 +4896,7 @@ def lora_state_dict(
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
-
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -3900,18 +4909,16 @@ def lora_state_dict(
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
- state_dict = _fetch_state_dict(
+ state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -3932,35 +4939,26 @@ def lora_state_dict(
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
- # conversion.
- non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
- if non_diffusers:
- state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)
+ has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
+ has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
+ has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
+ has_default = any("default." in k for k in state_dict)
+ if has_alphas_in_sd or has_lora_unet or has_diffusion_model or has_default:
+ state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
- return state_dict
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
+ return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -3976,7 +4974,8 @@ def load_lora_weights(
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+ kwargs["return_lora_metadata"] = True
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
@@ -3986,31 +4985,26 @@ def load_lora_weights(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->QwenImageTransformer2DModel
def load_lora_into_transformer(
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ cls,
+ state_dict,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`Lumina2Transformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -4023,8 +5017,10 @@ def load_lora_into_transformer(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
@@ -4037,45 +5033,32 @@ def save_lora_weights(
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
+ transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the UNet and text encoder.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
-
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ lora_layers = {}
+ lora_metadata = {}
if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
+
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
- # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: List[str] = ["transformer"],
@@ -4085,35 +5068,7 @@ def fuse_lora(
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -4123,28 +5078,17 @@ def fuse_lora(
**kwargs,
)
- # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
-class WanLoraLoaderMixin(LoraBaseMixin):
+class ZImageLoraLoaderMixin(LoraBaseMixin):
r"""
- Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
+ Load LoRA layers into [`ZImageTransformer2DModel`]. Specific to [`ZImagePipeline`].
"""
_lora_loadable_modules = ["transformer"]
@@ -4158,49 +5102,7 @@ def lora_state_dict(
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
-
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -4213,18 +5115,16 @@ def lora_state_dict(
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
- state_dict = _fetch_state_dict(
+ state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -4238,65 +5138,33 @@ def lora_state_dict(
user_agent=user_agent,
allow_pickle=allow_pickle,
)
- if any(k.startswith("diffusion_model.") for k in state_dict):
- state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
- if is_dora_scale_present:
- warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
- logger.warning(warn_msg)
- state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
-
- return state_dict
-
- @classmethod
- def _maybe_expand_t2v_lora_for_i2v(
- cls,
- transformer: torch.nn.Module,
- state_dict,
- ):
- if transformer.config.image_dim is None:
- return state_dict
-
- if any(k.startswith("transformer.blocks.") for k in state_dict):
- num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict})
- is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
-
- if is_i2v_lora:
- return state_dict
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
- for i in range(num_blocks):
- for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
- state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
- state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"]
- )
- state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
- state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"]
- )
+ has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
+ has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
+ has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
+ has_default = any("default." in k for k in state_dict)
+ if has_alphas_in_sd or has_lora_unet or has_diffusion_model or has_default:
+ state_dict = _convert_non_diffusers_z_image_lora_to_diffusers(state_dict)
- return state_dict
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
+ return out
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -4312,12 +5180,9 @@ def load_lora_weights(
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
- # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
- state_dict = self._maybe_expand_t2v_lora_for_i2v(
- transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
- state_dict=state_dict,
- )
+ kwargs["return_lora_metadata"] = True
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
@@ -4326,31 +5191,26 @@ def load_lora_weights(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->ZImageTransformer2DModel
def load_lora_into_transformer(
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ cls,
+ state_dict,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`WanTransformer3DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -4363,8 +5223,10 @@ def load_lora_into_transformer(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
@@ -4377,38 +5239,25 @@ def save_lora_weights(
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
+ transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the UNet and text encoder.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
-
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ lora_layers = {}
+ lora_metadata = {}
if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
+
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
@@ -4425,35 +5274,7 @@ def fuse_lora(
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -4466,25 +5287,14 @@ def fuse_lora(
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
-class CogView4LoraLoaderMixin(LoraBaseMixin):
+class Flux2LoraLoaderMixin(LoraBaseMixin):
r"""
- Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`CogView4Pipeline`].
+ Load LoRA layers into [`Flux2Transformer2DModel`]. Specific to [`Flux2Pipeline`].
"""
_lora_loadable_modules = ["transformer"]
@@ -4492,56 +5302,13 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
@classmethod
@validate_hf_hub_args
- # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
-
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -4554,18 +5321,16 @@ def lora_state_dict(
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
- state_dict = _fetch_state_dict(
+ state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -4586,30 +5351,23 @@ def lora_state_dict(
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
- return state_dict
+ is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict)
+ if is_ai_toolkit:
+ state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict)
+
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
+ return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
- self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -4625,7 +5383,8 @@ def load_lora_weights(
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
- state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+ kwargs["return_lora_metadata"] = True
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
@@ -4635,31 +5394,26 @@ def load_lora_weights(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
def load_lora_into_transformer(
- cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+ cls,
+ state_dict,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`CogView4Transformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -4672,8 +5426,10 @@ def load_lora_into_transformer(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
+ metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
)
@classmethod
@@ -4686,38 +5442,25 @@ def save_lora_weights(
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
+ transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the UNet and text encoder.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
-
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ lora_layers = {}
+ lora_metadata = {}
if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
+
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
@@ -4734,35 +5477,7 @@ def fuse_lora(
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -4775,18 +5490,7 @@ def fuse_lora(
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 8b52cf63456c..3f8519bbfa32 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -13,14 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
+import json
import os
from functools import partial
from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Literal, Optional, Union
import safetensors
import torch
+from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from ..utils import (
MIN_PEFT_VERSION,
USE_PEFT_BACKEND,
@@ -28,13 +30,13 @@
convert_unet_state_dict_to_peft,
delete_adapter_layers,
get_adapter_name,
- get_peft_kwargs,
is_peft_available,
is_peft_version,
logging,
set_adapter_layers,
set_weights_and_activate_adapters,
)
+from ..utils.peft_utils import _create_lora_config, _maybe_warn_for_unhandled_keys
from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
from .unet_loader_utils import _maybe_expand_lora_scales
@@ -52,67 +54,20 @@
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
"SanaTransformer2DModel": lambda model_cls, weights: weights,
+ "AuraFlowTransformer2DModel": lambda model_cls, weights: weights,
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
"WanTransformer3DModel": lambda model_cls, weights: weights,
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
+ "HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
+ "HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
+ "WanVACETransformer3DModel": lambda model_cls, weights: weights,
+ "ChromaTransformer2DModel": lambda model_cls, weights: weights,
+ "QwenImageTransformer2DModel": lambda model_cls, weights: weights,
+ "Flux2Transformer2DModel": lambda model_cls, weights: weights,
+ "ZImageTransformer2DModel": lambda model_cls, weights: weights,
}
-def _maybe_adjust_config(config):
- """
- We may run into some ambiguous configuration values when a model has module names, sharing a common prefix
- (`proj_out.weight` and `blocks.transformer.proj_out.weight`, for example) and they have different LoRA ranks. This
- method removes the ambiguity by following what is described here:
- https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
- """
- # Track keys that have been explicitly removed to prevent re-adding them.
- deleted_keys = set()
-
- rank_pattern = config["rank_pattern"].copy()
- target_modules = config["target_modules"]
- original_r = config["r"]
-
- for key in list(rank_pattern.keys()):
- key_rank = rank_pattern[key]
-
- # try to detect ambiguity
- # `target_modules` can also be a str, in which case this loop would loop
- # over the chars of the str. The technically correct way to match LoRA keys
- # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
- # But this cuts it for now.
- exact_matches = [mod for mod in target_modules if mod == key]
- substring_matches = [mod for mod in target_modules if key in mod and mod != key]
- ambiguous_key = key
-
- if exact_matches and substring_matches:
- # if ambiguous, update the rank associated with the ambiguous key (`proj_out`, for example)
- config["r"] = key_rank
- # remove the ambiguous key from `rank_pattern` and record it as deleted
- del config["rank_pattern"][key]
- deleted_keys.add(key)
- # For substring matches, add them with the original rank only if they haven't been assigned already
- for mod in substring_matches:
- if mod not in config["rank_pattern"] and mod not in deleted_keys:
- config["rank_pattern"][mod] = original_r
-
- # Update the rest of the target modules with the original rank if not already set and not deleted
- for mod in target_modules:
- if mod != ambiguous_key and mod not in config["rank_pattern"] and mod not in deleted_keys:
- config["rank_pattern"][mod] = original_r
-
- # Handle alphas to deal with cases like:
- # https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777
- has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"]
- if has_different_ranks:
- config["lora_alpha"] = config["r"]
- alpha_pattern = {}
- for module_name, rank in config["rank_pattern"].items():
- alpha_pattern[module_name] = rank
- config["alpha_pattern"] = alpha_pattern
-
- return config
-
-
class PeftAdapterMixin:
"""
A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
@@ -128,24 +83,17 @@ class PeftAdapterMixin:
"""
_hf_peft_config_loaded = False
+ # kwargs for prepare_model_for_compiled_hotswap, if required
+ _prepare_lora_hotswap_kwargs: Optional[dict] = None
@classmethod
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
def _optionally_disable_offloading(cls, _pipeline):
- """
- Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
-
- Args:
- _pipeline (`DiffusionPipeline`):
- The pipeline to disable offloading for.
-
- Returns:
- tuple:
- A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
- """
return _func_optionally_disable_offloading(_pipeline=_pipeline)
- def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
+ def load_lora_adapter(
+ self, pretrained_model_name_or_path_or_dict, prefix="transformer", hotswap: bool = False, **kwargs
+ ):
r"""
Loads a LoRA adapter into the underlying model.
@@ -189,10 +137,38 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
+ hotswap : (`bool`, *optional*)
+ Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
+ in-place. This means that, instead of loading an additional adapter, this will take the existing
+ adapter weights and replace them with the weights of the new adapter. This can be faster and more
+ memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
+ torch.compile, loading the new adapter does not require recompilation of the model. When using
+ hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
+
+ If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+ to call an additional method before loading the adapter:
+
+ ```py
+ pipeline = ... # load diffusers pipeline
+ max_rank = ... # the highest rank among all LoRAs that you want to load
+ # call *before* compiling and loading the LoRA adapter
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
+ pipeline.load_lora_weights(file_name)
+ # optionally compile the model now
+ ```
+
+ Note that hotswapping adapters of the text encoder is not yet supported. There are some further
+ limitations to this technique, which are documented here:
+ https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+ metadata:
+ LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to
+ initialize `LoraConfig`.
"""
- from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+ from peft import inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer
+ from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
+
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
@@ -206,6 +182,7 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
network_alphas = kwargs.pop("network_alphas", None)
_pipeline = kwargs.pop("_pipeline", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
+ metadata = kwargs.pop("metadata", None)
allow_pickle = False
if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
@@ -213,12 +190,8 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
-
- state_dict = _fetch_state_dict(
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
+ state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -231,18 +204,28 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
+ metadata=metadata,
)
if network_alphas is not None and prefix is None:
raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
+ if network_alphas and metadata:
+ raise ValueError("Both `network_alphas` and `metadata` cannot be specified.")
if prefix is not None:
- state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+ state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+ if metadata is not None:
+ metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}
if len(state_dict) > 0:
- if adapter_name in getattr(self, "peft_config", {}):
+ if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
raise ValueError(
f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
)
+ elif adapter_name not in getattr(self, "peft_config", {}) and hotswap:
+ raise ValueError(
+ f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name. "
+ "Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping."
+ )
# check with first key if is not in peft format
first_key = next(iter(state_dict.keys()))
@@ -251,62 +234,112 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
rank = {}
for key, val in state_dict.items():
- # Cannot figure out rank from lora layers that don't have atleast 2 dimensions.
+ # Cannot figure out rank from lora layers that don't have at least 2 dimensions.
# Bias layers in LoRA only have a single dimension
if "lora_B" in key and val.ndim > 1:
- # TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
- rank[key] = val.shape[1]
+ # Check out https://github.com/huggingface/peft/pull/2419 for the `^` symbol.
+ # We may run into some ambiguous configuration values when a model has module
+ # names, sharing a common prefix (`proj_out.weight` and `blocks.transformer.proj_out.weight`,
+ # for example) and they have different LoRA ranks.
+ rank[f"^{key}"] = val.shape[1]
if network_alphas is not None and len(network_alphas) >= 1:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
- network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
-
- lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
- # TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
- lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
-
- if "use_dora" in lora_config_kwargs:
- if lora_config_kwargs["use_dora"]:
- if is_peft_version("<", "0.9.0"):
- raise ValueError(
- "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- if is_peft_version("<", "0.9.0"):
- lora_config_kwargs.pop("use_dora")
-
- if "lora_bias" in lora_config_kwargs:
- if lora_config_kwargs["lora_bias"]:
- if is_peft_version("<=", "0.13.2"):
- raise ValueError(
- "You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`."
- )
- else:
- if is_peft_version("<=", "0.13.2"):
- lora_config_kwargs.pop("lora_bias")
+ network_alphas = {
+ k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
+ }
- lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(self)
+ # create LoraConfig
+ lora_config = _create_lora_config(
+ state_dict,
+ network_alphas,
+ metadata,
+ rank,
+ model_state_dict=self.state_dict(),
+ adapter_name=adapter_name,
+ )
+
# =", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
- # To handle scenarios where we cannot successfully set state dict. If it's unsucessful,
+ if hotswap or (self._prepare_lora_hotswap_kwargs is not None):
+ if is_peft_version(">", "0.14.0"):
+ from peft.utils.hotswap import (
+ check_hotswap_configs_compatible,
+ hotswap_adapter_from_state_dict,
+ prepare_model_for_compiled_hotswap,
+ )
+ else:
+ msg = (
+ "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it "
+ "from source."
+ )
+ raise ImportError(msg)
+
+ if hotswap:
+
+ def map_state_dict_for_hotswap(sd):
+ # For hotswapping, we need the adapter name to be present in the state dict keys
+ new_sd = {}
+ for k, v in sd.items():
+ if k.endswith("lora_A.weight") or k.endswith("lora_B.weight"):
+ k = k[: -len(".weight")] + f".{adapter_name}.weight"
+ elif k.endswith("lora_B.bias"): # lora_bias=True option
+ k = k[: -len(".bias")] + f".{adapter_name}.bias"
+ new_sd[k] = v
+ return new_sd
+
+ # To handle scenarios where we cannot successfully set state dict. If it's unsuccessful,
# we should also delete the `peft_config` associated to the `adapter_name`.
try:
- inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
- incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
+ if hotswap:
+ state_dict = map_state_dict_for_hotswap(state_dict)
+ check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config)
+ try:
+ hotswap_adapter_from_state_dict(
+ model=self,
+ state_dict=state_dict,
+ adapter_name=adapter_name,
+ config=lora_config,
+ )
+ except Exception as e:
+ logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error: \n{e}")
+ raise
+ # the hotswap function raises if there are incompatible keys, so if we reach this point we can set
+ # it to None
+ incompatible_keys = None
+ else:
+ inject_adapter_in_model(
+ lora_config, self, adapter_name=adapter_name, state_dict=state_dict, **peft_kwargs
+ )
+ incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
+
+ if self._prepare_lora_hotswap_kwargs is not None:
+ # For hotswapping of compiled models or adapters with different ranks.
+ # If the user called enable_lora_hotswap, we need to ensure it is called:
+ # - after the first adapter was loaded
+ # - before the model is compiled and the 2nd adapter is being hotswapped in
+ # Therefore, it needs to be called here
+ prepare_model_for_compiled_hotswap(
+ self, config=lora_config, **self._prepare_lora_hotswap_kwargs
+ )
+ # We only want to call prepare_model_for_compiled_hotswap once
+ self._prepare_lora_hotswap_kwargs = None
+
# Set peft config loaded flag to True if module has been successfully injected and incompatible keys retrieved
if not self._hf_peft_config_loaded:
self._hf_peft_config_loaded = True
@@ -321,46 +354,28 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans
module.delete_adapter(adapter_name)
self.peft_config.pop(adapter_name)
- logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}")
+ logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
raise
- warn_msg = ""
- if incompatible_keys is not None:
- # Check only for unexpected keys.
- unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
- if unexpected_keys:
- lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
- if lora_unexpected_keys:
- warn_msg = (
- f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
- f" {', '.join(lora_unexpected_keys)}. "
- )
-
- # Filter missing keys specific to the current adapter.
- missing_keys = getattr(incompatible_keys, "missing_keys", None)
- if missing_keys:
- lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
- if lora_missing_keys:
- warn_msg += (
- f"Loading adapter weights from state_dict led to missing keys in the model:"
- f" {', '.join(lora_missing_keys)}."
- )
-
- if warn_msg:
- logger.warning(warn_msg)
+ _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
+ elif is_group_offload:
+ for component in _pipeline.components.values():
+ if isinstance(component, torch.nn.Module):
+ _maybe_remove_and_reapply_group_offloading(component)
# Unsafe code />
if prefix is not None and not state_dict:
+ model_class_name = self.__class__.__name__
logger.warning(
- f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. "
+ f"No LoRA keys associated to {model_class_name} found with the {prefix=}. "
"This is safe to ignore if LoRA state dict didn't originally have any "
- f"{self.__class__.__name__} related params. You can also try specifying `prefix=None` "
+ f"{model_class_name} related params. You can also try specifying `prefix=None` "
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
"https://github.com/huggingface/diffusers/issues/new"
)
@@ -383,17 +398,13 @@ def save_lora_adapter(
underlying model has multiple adapters loaded.
upcast_before_saving (`bool`, defaults to `False`):
Whether to cast the underlying model to `torch.float32` before serialization.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
"""
from peft.utils import get_peft_model_state_dict
- from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
+ from .lora_base import LORA_ADAPTER_METADATA_KEY, LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
if adapter_name is None:
adapter_name = get_adapter_name(self)
@@ -401,6 +412,8 @@ def save_lora_adapter(
if adapter_name not in getattr(self, "peft_config", {}):
raise ValueError(f"Adapter name {adapter_name} not found in the model.")
+ lora_adapter_metadata = self.peft_config[adapter_name].to_dict()
+
lora_layers_to_save = get_peft_model_state_dict(
self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
)
@@ -410,7 +423,15 @@ def save_lora_adapter(
if safe_serialization:
def save_function(weights, filename):
- return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+ # Inject framework format.
+ metadata = {"format": "pt"}
+ if lora_adapter_metadata is not None:
+ for key, value in lora_adapter_metadata.items():
+ if isinstance(value, set):
+ lora_adapter_metadata[key] = list(value)
+ metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
+
+ return safetensors.torch.save_file(weights, filename, metadata=metadata)
else:
save_function = torch.save
@@ -423,7 +444,6 @@ def save_function(weights, filename):
else:
weight_name = LORA_WEIGHT_NAME
- # TODO: we could consider saving the `peft_config` as well.
save_path = Path(save_directory, weight_name).as_posix()
save_function(lora_layers_to_save, save_path)
logger.info(f"Model weights saved in {save_path}")
@@ -434,7 +454,7 @@ def set_adapters(
weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
):
"""
- Set the currently active adapters for use in the UNet.
+ Set the currently active adapters for use in the diffusion network (e.g. unet, transformer, etc.).
Args:
adapter_names (`List[str]` or `str`):
@@ -456,7 +476,7 @@ def set_adapters(
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
+ pipeline.unet.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
```
"""
if not USE_PEFT_BACKEND:
@@ -654,7 +674,7 @@ def _fuse_lora_apply(self, module, adapter_names=None):
if self.lora_scale != 1.0:
module.scale_layer(self.lora_scale)
- # For BC with prevous PEFT versions, we need to check the signature
+ # For BC with previous PEFT versions, we need to check the signature
# of the `merge` method to see if it supports the `adapter_names` argument.
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
if "adapter_names" in supported_merge_kwargs:
@@ -682,11 +702,16 @@ def unload_lora(self):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for `unload_lora()`.")
+ from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from ..utils import recurse_remove_peft_layers
recurse_remove_peft_layers(self)
if hasattr(self, "peft_config"):
del self.peft_config
+ if hasattr(self, "_hf_peft_config_loaded"):
+ self._hf_peft_config_loaded = None
+
+ _maybe_remove_and_reapply_group_offloading(self)
def disable_lora(self):
"""
@@ -704,7 +729,7 @@ def disable_lora(self):
pipeline.load_lora_weights(
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)
- pipeline.disable_lora()
+ pipeline.unet.disable_lora()
```
"""
if not USE_PEFT_BACKEND:
@@ -727,7 +752,7 @@ def enable_lora(self):
pipeline.load_lora_weights(
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)
- pipeline.enable_lora()
+ pipeline.unet.enable_lora()
```
"""
if not USE_PEFT_BACKEND:
@@ -754,7 +779,7 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
pipeline.load_lora_weights(
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
)
- pipeline.delete_adapters("cinematic")
+ pipeline.unet.delete_adapters("cinematic")
```
"""
if not USE_PEFT_BACKEND:
@@ -769,3 +794,38 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
# Pop also the corresponding adapter from the config
if hasattr(self, "peft_config"):
self.peft_config.pop(adapter_name, None)
+
+ _maybe_remove_and_reapply_group_offloading(self)
+
+ def enable_lora_hotswap(
+ self, target_rank: int = 128, check_compiled: Literal["error", "warn", "ignore"] = "error"
+ ) -> None:
+ """Enables the possibility to hotswap LoRA adapters.
+
+ Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
+ the loaded adapters differ.
+
+ Args:
+ target_rank (`int`, *optional*, defaults to `128`):
+ The highest rank among all the adapters that will be loaded.
+
+ check_compiled (`str`, *optional*, defaults to `"error"`):
+ How to handle the case when the model is already compiled, which should generally be avoided. The
+ options are:
+ - "error" (default): raise an error
+ - "warn": issue a warning
+ - "ignore": do nothing
+ """
+ if getattr(self, "peft_config", {}):
+ if check_compiled == "error":
+ raise RuntimeError("Call `enable_lora_hotswap` before loading the first adapter.")
+ elif check_compiled == "warn":
+ logger.warning(
+ "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
+ )
+ elif check_compiled != "ignore":
+ raise ValueError(
+ f"check_compiles should be one of 'error', 'warn', or 'ignore', got '{check_compiled}' instead."
+ )
+
+ self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled}
diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py
index c2843fc7406a..667f79437985 100644
--- a/src/diffusers/loaders/single_file.py
+++ b/src/diffusers/loaders/single_file.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -453,7 +453,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
logger.warning(
"Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n"
"This may lead to errors if the model components are not correctly inferred. \n"
- "To avoid this warning, please explicity pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n"
+ "To avoid this warning, please explicitly pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n"
"e.g. `from_single_file(, config=) \n"
"or run `from_single_file` with `local_files_only=False` first to update the local cache directory with "
"the necessary config files.\n"
diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index dafdb3c26ddc..803fdfc2d952 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,15 +21,22 @@
from huggingface_hub.utils import validate_hf_hub_args
from typing_extensions import Self
+from .. import __version__
+from ..models.model_loading_utils import _caching_allocator_warmup, _determine_device_map, _expand_device_map
from ..quantizers import DiffusersAutoQuantizer
-from ..utils import deprecate, is_accelerate_available, logging
+from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
+from ..utils.torch_utils import empty_device_cache
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
convert_auraflow_transformer_checkpoint_to_diffusers,
convert_autoencoder_dc_checkpoint_to_diffusers,
+ convert_chroma_transformer_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
+ convert_cosmos_transformer_checkpoint_to_diffusers,
+ convert_flux2_transformer_checkpoint_to_diffusers,
convert_flux_transformer_checkpoint_to_diffusers,
+ convert_hidream_transformer_to_diffusers,
convert_hunyuan_video_transformer_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
@@ -42,6 +49,7 @@
convert_stable_cascade_unet_single_file_to_diffusers,
convert_wan_transformer_to_diffusers,
convert_wan_vae_to_diffusers,
+ convert_z_image_transformer_checkpoint_to_diffusers,
create_controlnet_diffusers_config_from_ldm,
create_unet_diffusers_config_from_ldm,
create_vae_diffusers_config_from_ldm,
@@ -57,8 +65,12 @@
if is_accelerate_available():
from accelerate import dispatch_model, init_empty_weights
- from ..models.modeling_utils import load_model_dict_into_meta
+ from ..models.model_loading_utils import load_model_dict_into_meta
+if is_torch_version(">=", "1.9.0") and is_accelerate_available():
+ _LOW_CPU_MEM_USAGE_DEFAULT = True
+else:
+ _LOW_CPU_MEM_USAGE_DEFAULT = False
SINGLE_FILE_LOADABLE_CLASSES = {
"StableCascadeUNet": {
@@ -95,6 +107,10 @@
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
+ "ChromaTransformer2DModel": {
+ "checkpoint_mapping_fn": convert_chroma_transformer_checkpoint_to_diffusers,
+ "default_subfolder": "transformer",
+ },
"LTXVideoTransformer3DModel": {
"checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
@@ -128,13 +144,41 @@
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
"default_subfolder": "transformer",
},
+ "WanVACETransformer3DModel": {
+ "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
+ "default_subfolder": "transformer",
+ },
"AutoencoderKLWan": {
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
"default_subfolder": "vae",
},
+ "HiDreamImageTransformer2DModel": {
+ "checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
+ "default_subfolder": "transformer",
+ },
+ "CosmosTransformer3DModel": {
+ "checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
+ "default_subfolder": "transformer",
+ },
+ "QwenImageTransformer2DModel": {
+ "checkpoint_mapping_fn": lambda x: x,
+ "default_subfolder": "transformer",
+ },
+ "Flux2Transformer2DModel": {
+ "checkpoint_mapping_fn": convert_flux2_transformer_checkpoint_to_diffusers,
+ "default_subfolder": "transformer",
+ },
+ "ZImageTransformer2DModel": {
+ "checkpoint_mapping_fn": convert_z_image_transformer_checkpoint_to_diffusers,
+ "default_subfolder": "transformer",
+ },
}
+def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
+ return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
+
+
def _get_single_file_loadable_mapping_class(cls):
diffusers_module = importlib.import_module(__name__.split(".")[0])
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
@@ -186,9 +230,8 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
original_config (`str`, *optional*):
Dict or path to a yaml file containing the configuration for the model in its original format.
If a dict is provided, it will be used to initialize the model configuration.
- torch_dtype (`str` or `torch.dtype`, *optional*):
- Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
- dtype is automatically derived from the model's weights.
+ torch_dtype (`torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@@ -208,6 +251,11 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 and
+ is_accelerate_available() else `False`): Speed up model loading only loading the pretrained weights and
+ not initializing the weights. This also tries to not use more than 1x model size in CPU memory
+ (including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0. If you are using
+ an older version of PyTorch, setting this argument to `True` will raise an error.
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
@@ -257,8 +305,15 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
config_revision = kwargs.pop("config_revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
quantization_config = kwargs.pop("quantization_config", None)
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
device = kwargs.pop("device", None)
disable_mmap = kwargs.pop("disable_mmap", False)
+ device_map = kwargs.pop("device_map", None)
+
+ user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
+ # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
+ if quantization_config is not None:
+ user_agent["quant"] = quantization_config.quant_method.value
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
@@ -278,6 +333,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
local_files_only=local_files_only,
revision=revision,
disable_mmap=disable_mmap,
+ user_agent=user_agent,
)
if quantization_config is not None:
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
@@ -355,19 +411,12 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
diffusers_model_config.update(model_kwargs)
- checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
- diffusers_format_checkpoint = checkpoint_mapping_fn(
- config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
- )
- if not diffusers_format_checkpoint:
- raise SingleFileComponentError(
- f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
- )
-
- ctx = init_empty_weights if is_accelerate_available() else nullcontext
+ ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
model = cls.from_config(diffusers_model_config)
+ model_state_dict = model.state_dict()
+
# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
@@ -380,6 +429,26 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
else:
keep_in_fp32_modules = []
+ # Now that the model is loaded, we can determine the `device_map`
+ device_map = _determine_device_map(model, device_map, None, torch_dtype, keep_in_fp32_modules, hf_quantizer)
+ if device_map is not None:
+ expanded_device_map = _expand_device_map(device_map, model_state_dict.keys())
+ _caching_allocator_warmup(model, expanded_device_map, torch_dtype, hf_quantizer)
+
+ checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
+
+ if _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint):
+ diffusers_format_checkpoint = checkpoint_mapping_fn(
+ config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
+ )
+ else:
+ diffusers_format_checkpoint = checkpoint
+
+ if not diffusers_format_checkpoint:
+ raise SingleFileComponentError(
+ f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
+ )
+
if hf_quantizer is not None:
hf_quantizer.preprocess_model(
model=model,
@@ -389,7 +458,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
)
device_map = None
- if is_accelerate_available():
+ if low_cpu_mem_usage:
param_device = torch.device(device) if device else torch.device("cpu")
empty_state_dict = model.state_dict()
unexpected_keys = [
@@ -405,6 +474,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
)
+ empty_device_cache()
else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index 42aee4a84822..b866a5a21ae3 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -44,7 +44,9 @@
is_transformers_available,
logging,
)
+from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from ..utils.hub_utils import _get_model_file
+from ..utils.torch_utils import empty_device_cache
if is_transformers_available():
@@ -53,11 +55,12 @@
if is_accelerate_available():
from accelerate import init_empty_weights
- from ..models.modeling_utils import load_model_dict_into_meta
+ from ..models.model_loading_utils import load_model_dict_into_meta
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
CHECKPOINT_KEY_NAMES = {
+ "v1": "model.diffusion_model.output_blocks.11.0.skip_connection.weight",
"v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
"xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
@@ -117,6 +120,7 @@
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
"lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
+ "z-image-turbo": "cap_embedder.0.weight",
"sana": [
"blocks.0.cross_attn.q_linear.weight",
"blocks.0.cross_attn.q_linear.bias",
@@ -125,6 +129,19 @@
],
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
"wan_vae": "decoder.middle.0.residual.0.gamma",
+ "wan_vace": "vace_blocks.0.after_proj.bias",
+ "hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
+ "cosmos-1.0": [
+ "net.x_embedder.proj.1.weight",
+ "net.blocks.block1.blocks.0.block.attn.to_q.0.weight",
+ "net.extra_pos_embedder.pos_emb_h",
+ ],
+ "cosmos-2.0": [
+ "net.x_embedder.proj.1.weight",
+ "net.blocks.0.self_attn.q_proj.weight",
+ "net.pos_embedder.dim_spatial_range",
+ ],
+ "flux2": ["model.diffusion_model.single_stream_modulation.lin.weight", "single_stream_modulation.lin.weight"],
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -174,8 +191,11 @@
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
+ "flux-2-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.2-dev"},
"ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
"ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
+ "ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"},
+ "ltx-video-0.9.7": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.7-dev"},
"autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
"autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
@@ -188,6 +208,18 @@
"wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
+ "wan-vace-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-1.3B-diffusers"},
+ "wan-vace-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-14B-diffusers"},
+ "hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
+ "cosmos-1.0-t2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"},
+ "cosmos-1.0-t2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Text2World"},
+ "cosmos-1.0-v2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"},
+ "cosmos-1.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Video2World"},
+ "cosmos-2.0-t2i-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Text2Image"},
+ "cosmos-2.0-t2i-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Text2Image"},
+ "cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
+ "cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
+ "z-image-turbo": {"pretrained_model_name_or_path": "Tongyi-MAI/Z-Image-Turbo"},
}
# Use to configure model sample size when original config is provided
@@ -359,6 +391,14 @@ def is_valid_url(url):
return False
+def _is_single_file_path_or_url(pretrained_model_name_or_path):
+ if not os.path.isfile(pretrained_model_name_or_path) or not is_valid_url(pretrained_model_name_or_path):
+ return False
+
+ repo_id, weight_name = _extract_repo_id_and_weights_name(pretrained_model_name_or_path)
+ return bool(repo_id and weight_name)
+
+
def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
if not is_valid_url(pretrained_model_name_or_path):
raise ValueError("Invalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.")
@@ -370,7 +410,6 @@ def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "")
match = re.match(pattern, pretrained_model_name_or_path)
if not match:
- logger.warning("Unable to identify the repo_id and weights_name from the provided URL.")
return repo_id, weights_name
repo_id = f"{match.group(1)}/{match.group(2)}"
@@ -403,13 +442,16 @@ def load_single_file_checkpoint(
local_files_only=None,
revision=None,
disable_mmap=False,
+ user_agent=None,
):
+ if user_agent is None:
+ user_agent = {"file_type": "single_file", "framework": "pytorch"}
+
if os.path.isfile(pretrained_model_link_or_path):
pretrained_model_link_or_path = pretrained_model_link_or_path
else:
repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
- user_agent = {"file_type": "single_file", "framework": "pytorch"}
pretrained_model_link_or_path = _get_model_file(
repo_id,
weights_name=weights_name,
@@ -443,7 +485,7 @@ def fetch_original_config(original_config_file, local_files_only=False):
"Please provide a valid local file path."
)
- original_config_file = BytesIO(requests.get(original_config_file).content)
+ original_config_file = BytesIO(requests.get(original_config_file, timeout=DIFFUSERS_REQUEST_TIMEOUT).content)
else:
raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")
@@ -618,6 +660,9 @@ def infer_diffusers_model_type(checkpoint):
else:
model_type = "animatediff_v3"
+ elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux2"]):
+ model_type = "flux-2-dev"
+
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
if any(
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
@@ -637,7 +682,12 @@ def infer_diffusers_model_type(checkpoint):
model_type = "flux-schnell"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
- if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
+ has_vae = "vae.encoder.conv_in.conv.bias" in checkpoint
+ if any(key.endswith("transformer_blocks.47.scale_shift_table") for key in checkpoint):
+ model_type = "ltx-video-0.9.7"
+ elif has_vae and checkpoint["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
+ model_type = "ltx-video-0.9.5"
+ elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
model_type = "ltx-video-0.9.1"
else:
model_type = "ltx-video"
@@ -673,6 +723,12 @@ def infer_diffusers_model_type(checkpoint):
):
model_type = "instruct-pix2pix"
+ elif (
+ CHECKPOINT_KEY_NAMES["z-image-turbo"] in checkpoint
+ and checkpoint[CHECKPOINT_KEY_NAMES["z-image-turbo"]].shape[0] == 2560
+ ):
+ model_type = "z-image-turbo"
+
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
model_type = "lumina2"
@@ -685,15 +741,44 @@ def infer_diffusers_model_type(checkpoint):
else:
target_key = "patch_embedding.weight"
- if checkpoint[target_key].shape[0] == 1536:
+ if CHECKPOINT_KEY_NAMES["wan_vace"] in checkpoint:
+ if checkpoint[target_key].shape[0] == 1536:
+ model_type = "wan-vace-1.3B"
+ elif checkpoint[target_key].shape[0] == 5120:
+ model_type = "wan-vace-14B"
+
+ elif checkpoint[target_key].shape[0] == 1536:
model_type = "wan-t2v-1.3B"
elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
model_type = "wan-t2v-14B"
else:
model_type = "wan-i2v-14B"
+
elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
# All Wan models use the same VAE so we can use the same default model repo to fetch the config
model_type = "wan-t2v-14B"
+
+ elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
+ model_type = "hidream"
+
+ elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-1.0"]):
+ x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-1.0"][0]].shape
+ if x_embedder_shape[1] == 68:
+ model_type = "cosmos-1.0-t2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-t2w-14B"
+ elif x_embedder_shape[1] == 72:
+ model_type = "cosmos-1.0-v2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-v2w-14B"
+ else:
+ raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 1.0 model.")
+
+ elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-2.0"]):
+ x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-2.0"][0]].shape
+ if x_embedder_shape[1] == 68:
+ model_type = "cosmos-2.0-t2i-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-t2i-14B"
+ elif x_embedder_shape[1] == 72:
+ model_type = "cosmos-2.0-v2w-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-v2w-14B"
+ else:
+ raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")
+
else:
model_type = "v1"
@@ -1626,6 +1711,7 @@ def create_diffusers_clip_model_from_ldm(
if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+ empty_device_cache()
else:
model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -2085,6 +2171,7 @@ def create_diffusers_t5_model_from_checkpoint(
if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+ empty_device_cache()
else:
model.load_state_dict(diffusers_format_checkpoint)
@@ -2271,7 +2358,7 @@ def swap_scale_shift(weight):
f"double_blocks.{i}.txt_attn.proj.bias"
)
- # single transfomer blocks
+ # single transformer blocks
for i in range(num_single_layers):
block_prefix = f"single_transformer_blocks.{i}."
# norm.linear <- single_blocks.0.modulation.lin
@@ -2402,14 +2489,41 @@ def remove_keys_(key: str, state_dict):
"last_scale_shift_table": "scale_shift_table",
}
+ VAE_095_RENAME_DICT = {
+ # decoder
+ "up_blocks.0": "mid_block",
+ "up_blocks.1": "up_blocks.0.upsamplers.0",
+ "up_blocks.2": "up_blocks.0",
+ "up_blocks.3": "up_blocks.1.upsamplers.0",
+ "up_blocks.4": "up_blocks.1",
+ "up_blocks.5": "up_blocks.2.upsamplers.0",
+ "up_blocks.6": "up_blocks.2",
+ "up_blocks.7": "up_blocks.3.upsamplers.0",
+ "up_blocks.8": "up_blocks.3",
+ # encoder
+ "down_blocks.0": "down_blocks.0",
+ "down_blocks.1": "down_blocks.0.downsamplers.0",
+ "down_blocks.2": "down_blocks.1",
+ "down_blocks.3": "down_blocks.1.downsamplers.0",
+ "down_blocks.4": "down_blocks.2",
+ "down_blocks.5": "down_blocks.2.downsamplers.0",
+ "down_blocks.6": "down_blocks.3",
+ "down_blocks.7": "down_blocks.3.downsamplers.0",
+ "down_blocks.8": "mid_block",
+ # common
+ "last_time_embedder": "time_embedder",
+ "last_scale_shift_table": "scale_shift_table",
+ }
+
VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
"per_channel_statistics.mean-of-stds": remove_keys_,
- "timestep_scale_multiplier": remove_keys_,
}
- if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
+ if converted_state_dict["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
+ VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
+ elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
for key in list(converted_state_dict.keys()):
@@ -2838,7 +2952,7 @@ def calculate_layers(keys, key_prefix):
def convert_lumina2_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
- # Original Lumina-Image-2 has an extra norm paramter that is unused
+ # Original Lumina-Image-2 has an extra norm parameter that is unused
# We just remove it here
checkpoint.pop("norm_final.weight", None)
@@ -3051,6 +3165,9 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+ # For the VACE model
+ "before_proj": "proj_in",
+ "after_proj": "proj_out",
}
for key in list(checkpoint.keys()):
@@ -3259,3 +3376,512 @@ def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
converted_state_dict[key] = value
return converted_state_dict
+
+
+def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
+ keys = list(checkpoint.keys())
+ for k in keys:
+ if "model.diffusion_model." in k:
+ checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+ return checkpoint
+
+
+def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {}
+ keys = list(checkpoint.keys())
+
+ for k in keys:
+ if "model.diffusion_model." in k:
+ checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+ num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
+ num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401
+ num_guidance_layers = (
+ list(set(int(k.split(".", 3)[2]) for k in checkpoint if "distilled_guidance_layer.layers." in k))[-1] + 1 # noqa: C401
+ )
+ mlp_ratio = 4.0
+ inner_dim = 3072
+
+ # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
+ # while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
+ def swap_scale_shift(weight):
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ return new_weight
+
+ # guidance
+ converted_state_dict["distilled_guidance_layer.in_proj.bias"] = checkpoint.pop(
+ "distilled_guidance_layer.in_proj.bias"
+ )
+ converted_state_dict["distilled_guidance_layer.in_proj.weight"] = checkpoint.pop(
+ "distilled_guidance_layer.in_proj.weight"
+ )
+ converted_state_dict["distilled_guidance_layer.out_proj.bias"] = checkpoint.pop(
+ "distilled_guidance_layer.out_proj.bias"
+ )
+ converted_state_dict["distilled_guidance_layer.out_proj.weight"] = checkpoint.pop(
+ "distilled_guidance_layer.out_proj.weight"
+ )
+ for i in range(num_guidance_layers):
+ block_prefix = f"distilled_guidance_layer.layers.{i}."
+ converted_state_dict[f"{block_prefix}linear_1.bias"] = checkpoint.pop(
+ f"distilled_guidance_layer.layers.{i}.in_layer.bias"
+ )
+ converted_state_dict[f"{block_prefix}linear_1.weight"] = checkpoint.pop(
+ f"distilled_guidance_layer.layers.{i}.in_layer.weight"
+ )
+ converted_state_dict[f"{block_prefix}linear_2.bias"] = checkpoint.pop(
+ f"distilled_guidance_layer.layers.{i}.out_layer.bias"
+ )
+ converted_state_dict[f"{block_prefix}linear_2.weight"] = checkpoint.pop(
+ f"distilled_guidance_layer.layers.{i}.out_layer.weight"
+ )
+ converted_state_dict[f"distilled_guidance_layer.norms.{i}.weight"] = checkpoint.pop(
+ f"distilled_guidance_layer.norms.{i}.scale"
+ )
+
+ # context_embedder
+ converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
+ converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
+
+ # x_embedder
+ converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
+ converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")
+
+ # double transformer blocks
+ for i in range(num_layers):
+ block_prefix = f"transformer_blocks.{i}."
+ # Q, K, V
+ sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
+ context_q, context_k, context_v = torch.chunk(
+ checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
+ )
+ sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+ checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
+ )
+ context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+ checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
+ # qk_norm
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
+ f"double_blocks.{i}.img_attn.norm.query_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
+ f"double_blocks.{i}.img_attn.norm.key_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
+ f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
+ f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
+ )
+ # ff img_mlp
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
+ f"double_blocks.{i}.img_mlp.0.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
+ converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
+ converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
+ f"double_blocks.{i}.txt_mlp.0.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
+ f"double_blocks.{i}.txt_mlp.0.bias"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
+ f"double_blocks.{i}.txt_mlp.2.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
+ f"double_blocks.{i}.txt_mlp.2.bias"
+ )
+ # output projections.
+ converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
+ f"double_blocks.{i}.img_attn.proj.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
+ f"double_blocks.{i}.img_attn.proj.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
+ f"double_blocks.{i}.txt_attn.proj.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
+ f"double_blocks.{i}.txt_attn.proj.bias"
+ )
+
+ # single transformer blocks
+ for i in range(num_single_layers):
+ block_prefix = f"single_transformer_blocks.{i}."
+ # Q, K, V, mlp
+ mlp_hidden_dim = int(inner_dim * mlp_ratio)
+ split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+ q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(
+ checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
+ converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
+ converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
+ # qk norm
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
+ f"single_blocks.{i}.norm.query_norm.scale"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
+ f"single_blocks.{i}.norm.key_norm.scale"
+ )
+ # output projections.
+ converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
+ converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")
+
+ converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
+ converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
+
+ return converted_state_dict
+
+
+def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+
+ def remove_keys_(key: str, state_dict):
+ state_dict.pop(key)
+
+ def rename_transformer_blocks_(key: str, state_dict):
+ block_index = int(key.split(".")[1].removeprefix("block"))
+ new_key = key
+ old_prefix = f"blocks.block{block_index}"
+ new_prefix = f"transformer_blocks.{block_index}"
+ new_key = new_prefix + new_key.removeprefix(old_prefix)
+ state_dict[new_key] = state_dict.pop(key)
+
+ TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
+ "t_embedder.1": "time_embed.t_embedder",
+ "affline_norm": "time_embed.norm",
+ ".blocks.0.block.attn": ".attn1",
+ ".blocks.1.block.attn": ".attn2",
+ ".blocks.2.block": ".ff",
+ ".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
+ ".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
+ ".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
+ ".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
+ ".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
+ ".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
+ "to_q.0": "to_q",
+ "to_q.1": "norm_q",
+ "to_k.0": "to_k",
+ "to_k.1": "norm_k",
+ "to_v.0": "to_v",
+ "layer1": "net.0.proj",
+ "layer2": "net.2",
+ "proj.1": "proj",
+ "x_embedder": "patch_embed",
+ "extra_pos_embedder": "learnable_pos_embed",
+ "final_layer.adaLN_modulation.1": "norm_out.linear_1",
+ "final_layer.adaLN_modulation.2": "norm_out.linear_2",
+ "final_layer.linear": "proj_out",
+ }
+
+ TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
+ "blocks.block": rename_transformer_blocks_,
+ "logvar.0.freqs": remove_keys_,
+ "logvar.0.phases": remove_keys_,
+ "logvar.1.weight": remove_keys_,
+ "pos_embedder.seq": remove_keys_,
+ }
+
+ TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
+ "t_embedder.1": "time_embed.t_embedder",
+ "t_embedding_norm": "time_embed.norm",
+ "blocks": "transformer_blocks",
+ "adaln_modulation_self_attn.1": "norm1.linear_1",
+ "adaln_modulation_self_attn.2": "norm1.linear_2",
+ "adaln_modulation_cross_attn.1": "norm2.linear_1",
+ "adaln_modulation_cross_attn.2": "norm2.linear_2",
+ "adaln_modulation_mlp.1": "norm3.linear_1",
+ "adaln_modulation_mlp.2": "norm3.linear_2",
+ "self_attn": "attn1",
+ "cross_attn": "attn2",
+ "q_proj": "to_q",
+ "k_proj": "to_k",
+ "v_proj": "to_v",
+ "output_proj": "to_out.0",
+ "q_norm": "norm_q",
+ "k_norm": "norm_k",
+ "mlp.layer1": "ff.net.0.proj",
+ "mlp.layer2": "ff.net.2",
+ "x_embedder.proj.1": "patch_embed.proj",
+ "final_layer.adaln_modulation.1": "norm_out.linear_1",
+ "final_layer.adaln_modulation.2": "norm_out.linear_2",
+ "final_layer.linear": "proj_out",
+ }
+
+ TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
+ "accum_video_sample_counter": remove_keys_,
+ "accum_image_sample_counter": remove_keys_,
+ "accum_iteration": remove_keys_,
+ "accum_train_in_hours": remove_keys_,
+ "pos_embedder.seq": remove_keys_,
+ "pos_embedder.dim_spatial_range": remove_keys_,
+ "pos_embedder.dim_temporal_range": remove_keys_,
+ "_extra_state": remove_keys_,
+ }
+
+ PREFIX_KEY = "net."
+ if "net.blocks.block1.blocks.0.block.attn.to_q.0.weight" in checkpoint:
+ TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
+ TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
+ else:
+ TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
+ TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
+
+ state_dict_keys = list(converted_state_dict.keys())
+ for key in state_dict_keys:
+ new_key = key[:]
+ if new_key.startswith(PREFIX_KEY):
+ new_key = new_key.removeprefix(PREFIX_KEY)
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+ state_dict_keys = list(converted_state_dict.keys())
+ for key in state_dict_keys:
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, converted_state_dict)
+
+ return converted_state_dict
+
+
+def convert_flux2_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+ FLUX2_TRANSFORMER_KEYS_RENAME_DICT = {
+ # Image and text input projections
+ "img_in": "x_embedder",
+ "txt_in": "context_embedder",
+ # Timestep and guidance embeddings
+ "time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
+ "time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
+ "guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1",
+ "guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2",
+ # Modulation parameters
+ "double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
+ "double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
+ "single_stream_modulation.lin": "single_stream_modulation.linear",
+ # Final output layer
+ # "final_layer.adaLN_modulation.1": "norm_out.linear", # Handle separately since we need to swap mod params
+ "final_layer.linear": "proj_out",
+ }
+
+ FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP = {
+ "final_layer.adaLN_modulation.1": "norm_out.linear",
+ }
+
+ FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = {
+ # Handle fused QKV projections separately as we need to break into Q, K, V projections
+ "img_attn.norm.query_norm": "attn.norm_q",
+ "img_attn.norm.key_norm": "attn.norm_k",
+ "img_attn.proj": "attn.to_out.0",
+ "img_mlp.0": "ff.linear_in",
+ "img_mlp.2": "ff.linear_out",
+ "txt_attn.norm.query_norm": "attn.norm_added_q",
+ "txt_attn.norm.key_norm": "attn.norm_added_k",
+ "txt_attn.proj": "attn.to_add_out",
+ "txt_mlp.0": "ff_context.linear_in",
+ "txt_mlp.2": "ff_context.linear_out",
+ }
+
+ FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = {
+ "linear1": "attn.to_qkv_mlp_proj",
+ "linear2": "attn.to_out",
+ "norm.query_norm": "attn.norm_q",
+ "norm.key_norm": "attn.norm_k",
+ }
+
+ def convert_flux2_single_stream_blocks(key: str, state_dict: dict[str, object]) -> None:
+ # Skip if not a weight, bias, or scale
+ if ".weight" not in key and ".bias" not in key and ".scale" not in key:
+ return
+
+ # Mapping:
+ # - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj
+ # - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out
+ # - single_blocks.{N}.norm.query_norm.scale --> single_transformer_blocks.{N}.attn.norm_q.weight
+ # - single_blocks.{N}.norm.key_norm.scale --> single_transformer_blocks.{N}.attn.norm_k.weight
+ new_prefix = "single_transformer_blocks"
+ if "single_blocks." in key:
+ parts = key.split(".")
+ block_idx = parts[1]
+ within_block_name = ".".join(parts[2:-1])
+ param_type = parts[-1]
+
+ if param_type == "scale":
+ param_type = "weight"
+
+ new_within_block_name = FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP[within_block_name]
+ new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
+
+ param = state_dict.pop(key)
+ state_dict[new_key] = param
+
+ return
+
+ def convert_ada_layer_norm_weights(key: str, state_dict: dict[str, object]) -> None:
+ # Skip if not a weight
+ if ".weight" not in key:
+ return
+
+ # If adaLN_modulation is in the key, swap scale and shift parameters
+ # Original implementation is (shift, scale); diffusers implementation is (scale, shift)
+ if "adaLN_modulation" in key:
+ key_without_param_type, param_type = key.rsplit(".", maxsplit=1)
+ # Assume all such keys are in the AdaLayerNorm key map
+ new_key_without_param_type = FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP[key_without_param_type]
+ new_key = ".".join([new_key_without_param_type, param_type])
+
+ swapped_weight = swap_scale_shift(state_dict.pop(key), 0)
+ state_dict[new_key] = swapped_weight
+
+ return
+
+ def convert_flux2_double_stream_blocks(key: str, state_dict: dict[str, object]) -> None:
+ # Skip if not a weight, bias, or scale
+ if ".weight" not in key and ".bias" not in key and ".scale" not in key:
+ return
+
+ new_prefix = "transformer_blocks"
+ if "double_blocks." in key:
+ parts = key.split(".")
+ block_idx = parts[1]
+ modality_block_name = parts[2] # img_attn, img_mlp, txt_attn, txt_mlp
+ within_block_name = ".".join(parts[2:-1])
+ param_type = parts[-1]
+
+ if param_type == "scale":
+ param_type = "weight"
+
+ if "qkv" in within_block_name:
+ fused_qkv_weight = state_dict.pop(key)
+ to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
+ if "img" in modality_block_name:
+ # double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v}
+ to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
+ new_q_name = "attn.to_q"
+ new_k_name = "attn.to_k"
+ new_v_name = "attn.to_v"
+ elif "txt" in modality_block_name:
+ # double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj}
+ to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
+ new_q_name = "attn.add_q_proj"
+ new_k_name = "attn.add_k_proj"
+ new_v_name = "attn.add_v_proj"
+ new_q_key = ".".join([new_prefix, block_idx, new_q_name, param_type])
+ new_k_key = ".".join([new_prefix, block_idx, new_k_name, param_type])
+ new_v_key = ".".join([new_prefix, block_idx, new_v_name, param_type])
+ state_dict[new_q_key] = to_q_weight
+ state_dict[new_k_key] = to_k_weight
+ state_dict[new_v_key] = to_v_weight
+ else:
+ new_within_block_name = FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP[within_block_name]
+ new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
+
+ param = state_dict.pop(key)
+ state_dict[new_key] = param
+ return
+
+ def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) -> None:
+ state_dict[new_key] = state_dict.pop(old_key)
+
+ TRANSFORMER_SPECIAL_KEYS_REMAP = {
+ "adaLN_modulation": convert_ada_layer_norm_weights,
+ "double_blocks": convert_flux2_double_stream_blocks,
+ "single_blocks": convert_flux2_single_stream_blocks,
+ }
+
+ converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+
+ # Handle official code --> diffusers key remapping via the remap dict
+ for key in list(converted_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in FLUX2_TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+
+ update_state_dict(converted_state_dict, key, new_key)
+
+ # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
+ # special_keys_remap
+ for key in list(converted_state_dict.keys()):
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, converted_state_dict)
+
+ return converted_state_dict
+
+
+def convert_z_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+ Z_IMAGE_KEYS_RENAME_DICT = {
+ "final_layer.": "all_final_layer.2-1.",
+ "x_embedder.": "all_x_embedder.2-1.",
+ ".attention.out.bias": ".attention.to_out.0.bias",
+ ".attention.k_norm.weight": ".attention.norm_k.weight",
+ ".attention.q_norm.weight": ".attention.norm_q.weight",
+ ".attention.out.weight": ".attention.to_out.0.weight",
+ }
+
+ def convert_z_image_fused_attention(key: str, state_dict: dict[str, object]) -> None:
+ if ".attention.qkv.weight" not in key:
+ return
+
+ fused_qkv_weight = state_dict.pop(key)
+ to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
+ new_q_name = key.replace(".attention.qkv.weight", ".attention.to_q.weight")
+ new_k_name = key.replace(".attention.qkv.weight", ".attention.to_k.weight")
+ new_v_name = key.replace(".attention.qkv.weight", ".attention.to_v.weight")
+
+ state_dict[new_q_name] = to_q_weight
+ state_dict[new_k_name] = to_k_weight
+ state_dict[new_v_name] = to_v_weight
+ return
+
+ TRANSFORMER_SPECIAL_KEYS_REMAP = {
+ ".attention.qkv.weight": convert_z_image_fused_attention,
+ }
+
+ def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) -> None:
+ state_dict[new_key] = state_dict.pop(old_key)
+
+ converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+
+ # Handle single file --> diffusers key remapping via the remap dict
+ for key in list(converted_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in Z_IMAGE_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+
+ update_state_dict(converted_state_dict, key, new_key)
+
+ # Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
+ # special_keys_remap
+ for key in list(converted_state_dict.keys()):
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, converted_state_dict)
+
+ return converted_state_dict
diff --git a/src/diffusers/loaders/textual_inversion.py b/src/diffusers/loaders/textual_inversion.py
index 9aeb81c3e911..63fc97ed431f 100644
--- a/src/diffusers/loaders/textual_inversion.py
+++ b/src/diffusers/loaders/textual_inversion.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -427,7 +427,8 @@ def load_textual_inversion(
logger.info(
"Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
)
- remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+ if is_sequential_cpu_offload or is_model_cpu_offload:
+ remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
# 7.2 save expected device and dtype
device = text_encoder.device
diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py
index 38a8a7ebe266..ef7b921b7ddf 100644
--- a/src/diffusers/loaders/transformer_flux.py
+++ b/src/diffusers/loaders/transformer_flux.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,12 +17,10 @@
ImageProjection,
MultiIPAdapterImageProjection,
)
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
-from ..utils import (
- is_accelerate_available,
- is_torch_version,
- logging,
-)
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
+from ..utils import is_accelerate_available, is_torch_version, logging
+from ..utils.torch_utils import empty_device_cache
if is_accelerate_available():
@@ -84,13 +82,12 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
+ empty_device_cache()
return image_projection
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
- from ..models.attention_processor import (
- FluxIPAdapterJointAttnProcessor2_0,
- )
+ from ..models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
if low_cpu_mem_usage:
if is_accelerate_available():
@@ -122,7 +119,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_
else:
cross_attention_dim = self.config.joint_attention_dim
hidden_size = self.inner_dim
- attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
+ attn_processor_class = FluxIPAdapterAttnProcessor
num_image_text_embeds = []
for state_dict in state_dicts:
if "proj.weight" in state_dict["image_proj"]:
@@ -158,6 +155,8 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_
key_id += 1
+ empty_device_cache()
+
return attn_procs
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py
index ece17e6728fa..e3728082efdd 100644
--- a/src/diffusers/loaders/transformer_sd3.py
+++ b/src/diffusers/loaders/transformer_sd3.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,8 +16,10 @@
from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from ..utils import is_accelerate_available, is_torch_version, logging
+from ..utils.torch_utils import empty_device_cache
logger = logging.get_logger(__name__)
@@ -80,6 +82,8 @@ def _convert_ip_adapter_attn_to_diffusers(
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
)
+ empty_device_cache()
+
return attn_procs
def _convert_ip_adapter_image_proj_to_diffusers(
@@ -123,7 +127,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(
key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj")
updated_state_dict[key] = value
- # Image projetion parameters
+ # Image projection parameters
embed_dim = updated_state_dict["proj_in.weight"].shape[1]
output_dim = updated_state_dict["proj_out.weight"].shape[0]
hidden_dim = updated_state_dict["proj_in.weight"].shape[0]
@@ -147,6 +151,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
+ empty_device_cache()
return image_proj
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 1d8aba900c85..c5e56af156fc 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -30,7 +30,8 @@
IPAdapterPlusImageProjection,
MultiIPAdapterImageProjection,
)
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
@@ -43,6 +44,7 @@
is_torch_version,
logging,
)
+from ..utils.torch_utils import empty_device_cache
from .lora_base import _func_optionally_disable_offloading
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from .utils import AttnProcsLayers
@@ -131,6 +133,8 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
)
```
"""
+ from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
+
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
@@ -155,10 +159,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
use_safetensors = True
allow_pickle = True
- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
@@ -206,6 +207,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
is_model_cpu_offload = False
is_sequential_cpu_offload = False
+ is_group_offload = False
if is_lora:
deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
@@ -214,7 +216,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
if is_custom_diffusion:
attn_processors = self._process_custom_diffusion(state_dict=state_dict)
elif is_lora:
- is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
+ is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._process_lora(
state_dict=state_dict,
unet_identifier_key=self.unet_name,
network_alphas=network_alphas,
@@ -233,7 +235,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
# For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
if is_custom_diffusion and _pipeline is not None:
- is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)
+ is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+ _pipeline=_pipeline
+ )
# only custom diffusion needs to set attn processors
self.set_attn_processor(attn_processors)
@@ -244,6 +248,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
+ elif is_group_offload:
+ for component in _pipeline.components.values():
+ if isinstance(component, torch.nn.Module):
+ _maybe_remove_and_reapply_group_offloading(component)
# Unsafe code />
def _process_custom_diffusion(self, state_dict):
@@ -310,6 +318,7 @@ def _process_lora(
is_model_cpu_offload = False
is_sequential_cpu_offload = False
+ is_group_offload = False
state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict
if len(state_dict_to_be_used) > 0:
@@ -359,7 +368,9 @@ def _process_lora(
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
- is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+ is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+ _pipeline
+ )
peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -392,22 +403,11 @@ def _process_lora(
if warn_msg:
logger.warning(warn_msg)
- return is_model_cpu_offload, is_sequential_cpu_offload
+ return is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload
@classmethod
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
def _optionally_disable_offloading(cls, _pipeline):
- """
- Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
-
- Args:
- _pipeline (`DiffusionPipeline`):
- The pipeline to disable offloading for.
-
- Returns:
- tuple:
- A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
- """
return _func_optionally_disable_offloading(_pipeline=_pipeline)
def save_attn_procs(
@@ -755,6 +755,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
+ empty_device_cache()
return image_projection
@@ -852,6 +853,8 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_
key_id += 2
+ empty_device_cache()
+
return attn_procs
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
diff --git a/src/diffusers/loaders/unet_loader_utils.py b/src/diffusers/loaders/unet_loader_utils.py
index 8f202ed4d44b..d5b0e83cbd9e 100644
--- a/src/diffusers/loaders/unet_loader_utils.py
+++ b/src/diffusers/loaders/unet_loader_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +14,8 @@
import copy
from typing import TYPE_CHECKING, Dict, List, Union
+from torch import nn
+
from ..utils import logging
@@ -52,7 +54,7 @@ def _maybe_expand_lora_scales(
weight_for_adapter,
blocks_with_transformer,
transformer_per_block,
- unet.state_dict(),
+ model=unet,
default_scale=default_scale,
)
for weight_for_adapter in weight_scales
@@ -65,7 +67,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
scales: Union[float, Dict],
blocks_with_transformer: Dict[str, int],
transformer_per_block: Dict[str, int],
- state_dict: None,
+ model: nn.Module,
default_scale: float = 1.0,
):
"""
@@ -154,6 +156,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
del scales[updown]
+ state_dict = model.state_dict()
for layer in scales.keys():
if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
raise ValueError(
diff --git a/src/diffusers/loaders/utils.py b/src/diffusers/loaders/utils.py
index 142d72bf6b77..2d39e7bfb7d2 100644
--- a/src/diffusers/loaders/utils.py
+++ b/src/diffusers/loaders/utils.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 719325de13ef..13c951872cd0 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,23 +24,26 @@
_import_structure = {}
if is_torch_available():
+ _import_structure["_modeling_parallel"] = ["ContextParallelConfig", "ParallelConfig"]
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
+ _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
+ _import_structure["auto_model"] = ["AutoModel"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
- _import_structure["autoencoders.autoencoder_kl_cogvideox"] = [
- "AutoencoderKLCogVideoX"
- ]
- _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = [
- "AutoencoderKLHunyuanVideo"
- ]
+ _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
+ _import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"]
+ _import_structure["autoencoders.autoencoder_kl_flux2"] = ["AutoencoderKLFlux2"]
+ _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
+ _import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
+ _import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
+ _import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
- _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = [
- "AutoencoderKLTemporalDecoder"
- ]
+ _import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
+ _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
@@ -59,10 +62,12 @@
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DMultiControlNetModel",
]
- _import_structure["controlnets.controlnet_sd3"] = [
- "SD3ControlNetModel",
- "SD3MultiControlNetModel",
+ _import_structure["controlnets.controlnet_qwenimage"] = [
+ "QwenImageControlNetModel",
+ "QwenImageMultiControlNetModel",
]
+ _import_structure["controlnets.controlnet_sana"] = ["SanaControlNetModel"]
+ _import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
_import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
_import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
_import_structure["controlnets.controlnet_xs"] = [
@@ -97,35 +102,40 @@
_import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
- _import_structure["transformers.transformer_allegro"] = [
- "AllegroTransformer3DModel"
- ]
- _import_structure["transformers.transformer_cogview3plus"] = [
- "CogView3PlusTransformer2DModel"
- ]
- _import_structure["transformers.transformer_cogview4"] = [
- "CogView4Transformer2DModel"
- ]
- _import_structure["transformers.transformer_easyanimate"] = [
- "EasyAnimateTransformer3DModel"
- ]
+ _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
+ _import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
+ _import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"]
+ _import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
+ _import_structure["transformers.transformer_chronoedit"] = ["ChronoEditTransformer3DModel"]
+ _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
+ _import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
+ _import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
+ _import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
- _import_structure["transformers.transformer_hunyuan_video"] = [
- "HunyuanVideoTransformer3DModel"
- ]
+ _import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
+ _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
+ _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
+ _import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]
+ _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
+ _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
+ _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = [
"Lumina2Transformer2DModel"
]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
- _import_structure["transformers.transformer_omnigen"] = [
- "OmniGenTransformer2DModel"
- ]
+ _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
+ _import_structure["transformers.transformer_ovis_image"] = ["OvisImageTransformer2DModel"]
+ _import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
+ _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
+ _import_structure["transformers.transformer_sana_video"] = ["SanaVideoTransformer3DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
- _import_structure["transformers.transformer_temporal"] = [
- "TransformerTemporalModel"
- ]
+ _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
+ _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
+ _import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"]
+ _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
+ _import_structure["transformers.transformer_z_image"] = ["ZImageTransformer2DModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
@@ -147,17 +157,26 @@
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
+ from ._modeling_parallel import ContextParallelConfig, ParallelConfig
from .adapter import MultiAdapter, T2IAdapter
+ from .attention_dispatch import AttentionBackendName, attention_backend
+ from .auto_model import AutoModel
from .autoencoders import (
AsymmetricAutoencoderKL,
AutoencoderDC,
AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
+ AutoencoderKLCosmos,
+ AutoencoderKLFlux2,
+ AutoencoderKLHunyuanImage,
+ AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
+ AutoencoderKLHunyuanVideo15,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
+ AutoencoderKLQwenImage,
AutoencoderKLTemporalDecoder,
AutoencoderKLWan,
AutoencoderOobleck,
@@ -176,6 +195,8 @@
HunyuanDiT2DMultiControlNetModel,
MultiControlNetModel,
MultiControlNetUnionModel,
+ QwenImageControlNetModel,
+ QwenImageMultiControlNetModel,
SanaControlNetModel,
SD3ControlNetModel,
SD3MultiControlNetModel,
@@ -187,31 +208,50 @@
from .transformers import (
AllegroTransformer3DModel,
AuraFlowTransformer2DModel,
+ BriaFiboTransformer2DModel,
+ BriaTransformer2DModel,
+ ChromaTransformer2DModel,
+ ChronoEditTransformer3DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
CogView4Transformer2DModel,
ConsisIDTransformer3DModel,
+ CosmosTransformer3DModel,
DiTTransformer2DModel,
DualTransformer2DModel,
EasyAnimateTransformer3DModel,
+ Flux2Transformer2DModel,
FluxTransformer2DModel,
+ HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
+ HunyuanImageTransformer2DModel,
+ HunyuanVideo15Transformer3DModel,
+ HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
+ Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
OmniGenTransformer2DModel,
+ OvisImageTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
+ PRXTransformer2DModel,
+ QwenImageTransformer2DModel,
SanaTransformer2DModel,
+ SanaVideoTransformer3DModel,
SD3Transformer2DModel,
+ SkyReelsV2Transformer3DModel,
StableAudioDiTModel,
T5FilmDecoder,
Transformer2DModel,
TransformerTemporalModel,
+ WanAnimateTransformer3DModel,
WanTransformer3DModel,
+ WanVACETransformer3DModel,
+ ZImageTransformer2DModel,
)
from .unets import (
I2VGenXLUNet,
diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py
new file mode 100644
index 000000000000..2a4eb520c796
--- /dev/null
+++ b/src/diffusers/models/_modeling_parallel.py
@@ -0,0 +1,263 @@
+# 🚨🚨🚨 Experimental parallelism support for Diffusers 🚨🚨🚨
+# Experimental changes are subject to change and APIs may break without warning.
+
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+
+import torch
+
+from ..utils import get_logger
+
+
+if TYPE_CHECKING:
+ pass
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+
+# TODO(aryan): add support for the following:
+# - Unified Attention
+# - More dispatcher attention backends
+# - CFG/Data Parallel
+# - Tensor Parallel
+
+
+@dataclass
+class ContextParallelConfig:
+ """
+ Configuration for context parallelism.
+
+ Args:
+ ring_degree (`int`, *optional*, defaults to `1`):
+ Number of devices to use for Ring Attention. Sequence is split across devices. Each device computes
+ attention between its local Q and KV chunks passed sequentially around ring. Lower memory (only holds 1/N
+ of KV at a time), overlaps compute with communication, but requires N iterations to see all tokens. Best
+ for long sequences with limited memory/bandwidth. Number of devices to use for ring attention within a
+ context parallel region. Must be a divisor of the total number of devices in the context parallel mesh.
+ ulysses_degree (`int`, *optional*, defaults to `1`):
+ Number of devices to use for Ulysses Attention. Sequence split is across devices. Each device computes
+ local QKV, then all-gathers all KV chunks to compute full attention in one pass. Higher memory (stores all
+ KV), requires high-bandwidth all-to-all communication, but lower latency. Best for moderate sequences with
+ good interconnect bandwidth.
+ convert_to_fp32 (`bool`, *optional*, defaults to `True`):
+ Whether to convert output and LSE to float32 for ring attention numerical stability.
+ rotate_method (`str`, *optional*, defaults to `"allgather"`):
+ Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
+ is supported.
+
+ """
+
+ ring_degree: Optional[int] = None
+ ulysses_degree: Optional[int] = None
+ convert_to_fp32: bool = True
+ # TODO: support alltoall
+ rotate_method: Literal["allgather", "alltoall"] = "allgather"
+
+ _rank: int = None
+ _world_size: int = None
+ _device: torch.device = None
+ _mesh: torch.distributed.device_mesh.DeviceMesh = None
+ _flattened_mesh: torch.distributed.device_mesh.DeviceMesh = None
+ _ring_mesh: torch.distributed.device_mesh.DeviceMesh = None
+ _ulysses_mesh: torch.distributed.device_mesh.DeviceMesh = None
+ _ring_local_rank: int = None
+ _ulysses_local_rank: int = None
+
+ def __post_init__(self):
+ if self.ring_degree is None:
+ self.ring_degree = 1
+ if self.ulysses_degree is None:
+ self.ulysses_degree = 1
+
+ if self.ring_degree == 1 and self.ulysses_degree == 1:
+ raise ValueError(
+ "Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference"
+ )
+ if self.ring_degree < 1 or self.ulysses_degree < 1:
+ raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
+ if self.ring_degree > 1 and self.ulysses_degree > 1:
+ raise ValueError(
+ "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
+ )
+ if self.rotate_method != "allgather":
+ raise NotImplementedError(
+ f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
+ )
+
+ @property
+ def mesh_shape(self) -> Tuple[int, int]:
+ return (self.ring_degree, self.ulysses_degree)
+
+ @property
+ def mesh_dim_names(self) -> Tuple[str, str]:
+ """Dimension names for the device mesh."""
+ return ("ring", "ulysses")
+
+ def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
+ self._rank = rank
+ self._world_size = world_size
+ self._device = device
+ self._mesh = mesh
+
+ if self.ulysses_degree * self.ring_degree > world_size:
+ raise ValueError(
+ f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
+ )
+
+ self._flattened_mesh = self._mesh._flatten()
+ self._ring_mesh = self._mesh["ring"]
+ self._ulysses_mesh = self._mesh["ulysses"]
+ self._ring_local_rank = self._ring_mesh.get_local_rank()
+ self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
+
+
+@dataclass
+class ParallelConfig:
+ """
+ Configuration for applying different parallelisms.
+
+ Args:
+ context_parallel_config (`ContextParallelConfig`, *optional*):
+ Configuration for context parallelism.
+ """
+
+ context_parallel_config: Optional[ContextParallelConfig] = None
+
+ _rank: int = None
+ _world_size: int = None
+ _device: torch.device = None
+ _mesh: torch.distributed.device_mesh.DeviceMesh = None
+
+ def setup(
+ self,
+ rank: int,
+ world_size: int,
+ device: torch.device,
+ *,
+ mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
+ ):
+ self._rank = rank
+ self._world_size = world_size
+ self._device = device
+ self._mesh = mesh
+ if self.context_parallel_config is not None:
+ self.context_parallel_config.setup(rank, world_size, device, mesh)
+
+
+@dataclass(frozen=True)
+class ContextParallelInput:
+ """
+ Configuration for splitting an input tensor across context parallel region.
+
+ Args:
+ split_dim (`int`):
+ The dimension along which to split the tensor.
+ expected_dims (`int`, *optional*):
+ The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the
+ tensor has the expected number of dimensions before splitting.
+ split_output (`bool`, *optional*, defaults to `False`):
+ Whether to split the output tensor of the layer along the given `split_dim` instead of the input tensor.
+ This is useful for layers whose outputs should be split after it does some preprocessing on the inputs (ex:
+ RoPE).
+ """
+
+ split_dim: int
+ expected_dims: Optional[int] = None
+ split_output: bool = False
+
+ def __repr__(self):
+ return f"ContextParallelInput(split_dim={self.split_dim}, expected_dims={self.expected_dims}, split_output={self.split_output})"
+
+
+@dataclass(frozen=True)
+class ContextParallelOutput:
+ """
+ Configuration for gathering an output tensor across context parallel region.
+
+ Args:
+ gather_dim (`int`):
+ The dimension along which to gather the tensor.
+ expected_dims (`int`, *optional*):
+ The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the
+ tensor has the expected number of dimensions before gathering.
+ """
+
+ gather_dim: int
+ expected_dims: Optional[int] = None
+
+ def __repr__(self):
+ return f"ContextParallelOutput(gather_dim={self.gather_dim}, expected_dims={self.expected_dims})"
+
+
+# A dictionary where keys denote the input to be split across context parallel region, and the
+# value denotes the sharding configuration.
+# If the key is a string, it denotes the name of the parameter in the forward function.
+# If the key is an integer, split_output must be set to True, and it denotes the index of the output
+# to be split across context parallel region.
+ContextParallelInputType = Dict[
+ Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]]
+]
+
+# A dictionary where keys denote the output to be gathered across context parallel region, and the
+# value denotes the gathering configuration.
+ContextParallelOutputType = Union[
+ ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...]
+]
+
+# A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
+# the module should be split/gathered across context parallel region.
+ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
+
+
+# Example of a ContextParallelModelPlan (QwenImageTransformer2DModel):
+#
+# Each model should define a _cp_plan attribute that contains information on how to shard/gather
+# tensors at different stages of the forward:
+#
+# ```python
+# _cp_plan = {
+# "": {
+# "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+# "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+# "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+# },
+# "pos_embed": {
+# 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+# 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+# },
+# "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+# }
+# ```
+#
+# The dictionary is a set of module names mapped to their respective CP plan. The inputs/outputs of layers will be
+# split/gathered according to this at the respective module level. Here, the following happens:
+# - "":
+# we specify that we want to split the various inputs across the sequence dim in the pre-forward hook (i.e. before
+# the actual forward logic of the QwenImageTransformer2DModel is run, we will splitthe inputs)
+# - "pos_embed":
+# we specify that we want to split the outputs of the RoPE layer. Since there are two outputs (imag & text freqs),
+# we can individually specify how they should be split
+# - "proj_out":
+# before returning to the user, we gather the entire sequence on each rank in the post-forward hook (after the linear
+# layer forward has run).
+#
+# ContextParallelInput:
+# specifies how to split the input tensor in the pre-forward or post-forward hook of the layer it is attached to
+#
+# ContextParallelOutput:
+# specifies how to gather the input tensor in the post-forward hook in the layer it is attached to
diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py
index 42e65d898cec..2d1fdb5f7d83 100644
--- a/src/diffusers/models/activations.py
+++ b/src/diffusers/models/activations.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -92,7 +92,7 @@ def forward(self, hidden_states):
class GEGLU(nn.Module):
r"""
- A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function.
+ A [variant](https://huggingface.co/papers/2002.05202) of the gated linear unit activation function.
Parameters:
dim_in (`int`): The number of channels in the input.
@@ -125,8 +125,8 @@ def forward(self, hidden_states, *args, **kwargs):
class SwiGLU(nn.Module):
r"""
- A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. It's similar to `GEGLU`
- but uses SiLU / Swish instead of GeLU.
+ A [variant](https://huggingface.co/papers/2002.05202) of the gated linear unit activation function. It's similar to
+ `GEGLU` but uses SiLU / Swish instead of GeLU.
Parameters:
dim_in (`int`): The number of channels in the input.
@@ -149,7 +149,7 @@ def forward(self, hidden_states):
class ApproximateGELU(nn.Module):
r"""
The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
- [paper](https://arxiv.org/abs/1606.08415).
+ [paper](https://huggingface.co/papers/1606.08415).
Parameters:
dim_in (`int`): The number of channels in the input.
diff --git a/src/diffusers/models/adapter.py b/src/diffusers/models/adapter.py
index 677a991f055e..e475fe6bee88 100644
--- a/src/diffusers/models/adapter.py
+++ b/src/diffusers/models/adapter.py
@@ -161,9 +161,8 @@ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]
pretrained_model_path (`os.PathLike`):
A path to a *directory* containing model weights saved using
[`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`.
- torch_dtype (`str` or `torch.dtype`, *optional*):
- Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
- will be automatically derived from the model's weights.
+ torch_dtype (`torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model under this dtype.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 93b11c2b43f0..8b583d1a1cce 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,23 +11,514 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple
+
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
+import torch.nn as nn
import torch.nn.functional as F
-from torch import nn
from ..utils import deprecate, logging
+from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
-from .attention_processor import Attention, JointAttnProcessor2_0
+from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0
from .embeddings import SinusoidalPositionalEmbedding
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
+if is_xformers_available():
+ import xformers as xops
+else:
+ xops = None
+
+
logger = logging.get_logger(__name__)
+class AttentionMixin:
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+ """
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ for module in self.modules():
+ if isinstance(module, AttentionModuleMixin) and module._supports_qkv_fusion:
+ module.fuse_projections()
+
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+ > [!WARNING] > This API is 🧪 experimental.
+ """
+ for module in self.modules():
+ if isinstance(module, AttentionModuleMixin) and module._supports_qkv_fusion:
+ module.unfuse_projections()
+
+
+class AttentionModuleMixin:
+ _default_processor_cls = None
+ _available_processors = []
+ _supports_qkv_fusion = True
+ fused_projections = False
+
+ def set_processor(self, processor: AttentionProcessor) -> None:
+ """
+ Set the attention processor to use.
+
+ Args:
+ processor (`AttnProcessor`):
+ The attention processor to use.
+ """
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
+ # pop `processor` from `self._modules`
+ if (
+ hasattr(self, "processor")
+ and isinstance(self.processor, torch.nn.Module)
+ and not isinstance(processor, torch.nn.Module)
+ ):
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+ self._modules.pop("processor")
+
+ self.processor = processor
+
+ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
+ """
+ Get the attention processor in use.
+
+ Args:
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
+ Set to `True` to return the deprecated LoRA attention processor.
+
+ Returns:
+ "AttentionProcessor": The attention processor in use.
+ """
+ if not return_deprecated_lora:
+ return self.processor
+
+ def set_attention_backend(self, backend: str):
+ from .attention_dispatch import AttentionBackendName
+
+ available_backends = {x.value for x in AttentionBackendName.__members__.values()}
+ if backend not in available_backends:
+ raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
+
+ backend = AttentionBackendName(backend.lower())
+ self.processor._attention_backend = backend
+
+ def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
+ """
+ Set whether to use NPU flash attention from `torch_npu` or not.
+
+ Args:
+ use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
+ """
+
+ if use_npu_flash_attention:
+ if not is_torch_npu_available():
+ raise ImportError("torch_npu is not available")
+
+ self.set_attention_backend("_native_npu")
+
+ def set_use_xla_flash_attention(
+ self,
+ use_xla_flash_attention: bool,
+ partition_spec: Optional[Tuple[Optional[str], ...]] = None,
+ is_flux=False,
+ ) -> None:
+ """
+ Set whether to use XLA flash attention from `torch_xla` or not.
+
+ Args:
+ use_xla_flash_attention (`bool`):
+ Whether to use pallas flash attention kernel from `torch_xla` or not.
+ partition_spec (`Tuple[]`, *optional*):
+ Specify the partition specification if using SPMD. Otherwise None.
+ is_flux (`bool`, *optional*, defaults to `False`):
+ Whether the model is a Flux model.
+ """
+ if use_xla_flash_attention:
+ if not is_torch_xla_available():
+ raise ImportError("torch_xla is not available")
+
+ self.set_attention_backend("_native_xla")
+
+ def set_use_memory_efficient_attention_xformers(
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+ ) -> None:
+ """
+ Set whether to use memory efficient attention from `xformers` or not.
+
+ Args:
+ use_memory_efficient_attention_xformers (`bool`):
+ Whether to use memory efficient attention from `xformers` or not.
+ attention_op (`Callable`, *optional*):
+ The attention operation to use. Defaults to `None` which uses the default attention operation from
+ `xformers`.
+ """
+ if use_memory_efficient_attention_xformers:
+ if not is_xformers_available():
+ raise ModuleNotFoundError(
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers",
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
+ " only available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ if is_xformers_available():
+ dtype = None
+ if attention_op is not None:
+ op_fw, op_bw = attention_op
+ dtype, *_ = op_fw.SUPPORTED_DTYPES
+ q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+ _ = xops.ops.memory_efficient_attention(q, q, q)
+ except Exception as e:
+ raise e
+
+ self.set_attention_backend("xformers")
+
+ @torch.no_grad()
+ def fuse_projections(self):
+ """
+ Fuse the query, key, and value projections into a single projection for efficiency.
+ """
+ # Skip if the AttentionModuleMixin subclass does not support fusion (for example, the QKV projections in Flux2
+ # single stream blocks are always fused)
+ if not self._supports_qkv_fusion:
+ logger.debug(
+ f"{self.__class__.__name__} does not support fusing QKV projections, so `fuse_projections` will no-op."
+ )
+ return
+
+ # Skip if already fused
+ if getattr(self, "fused_projections", False):
+ return
+
+ device = self.to_q.weight.data.device
+ dtype = self.to_q.weight.data.dtype
+
+ if hasattr(self, "is_cross_attention") and self.is_cross_attention:
+ # Fuse cross-attention key-value projections
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+ self.to_kv.weight.copy_(concatenated_weights)
+ if hasattr(self, "use_bias") and self.use_bias:
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+ self.to_kv.bias.copy_(concatenated_bias)
+ else:
+ # Fuse self-attention projections
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+ self.to_qkv.weight.copy_(concatenated_weights)
+ if hasattr(self, "use_bias") and self.use_bias:
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+ self.to_qkv.bias.copy_(concatenated_bias)
+
+ # Handle added projections for models like SD3, Flux, etc.
+ if (
+ getattr(self, "add_q_proj", None) is not None
+ and getattr(self, "add_k_proj", None) is not None
+ and getattr(self, "add_v_proj", None) is not None
+ ):
+ concatenated_weights = torch.cat(
+ [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
+ )
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_added_qkv = nn.Linear(
+ in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
+ )
+ self.to_added_qkv.weight.copy_(concatenated_weights)
+ if self.added_proj_bias:
+ concatenated_bias = torch.cat(
+ [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+ )
+ self.to_added_qkv.bias.copy_(concatenated_bias)
+
+ self.fused_projections = True
+
+ @torch.no_grad()
+ def unfuse_projections(self):
+ """
+ Unfuse the query, key, and value projections back to separate projections.
+ """
+ # Skip if the AttentionModuleMixin subclass does not support fusion (for example, the QKV projections in Flux2
+ # single stream blocks are always fused)
+ if not self._supports_qkv_fusion:
+ return
+
+ # Skip if not fused
+ if not getattr(self, "fused_projections", False):
+ return
+
+ # Remove fused projection layers
+ if hasattr(self, "to_qkv"):
+ delattr(self, "to_qkv")
+
+ if hasattr(self, "to_kv"):
+ delattr(self, "to_kv")
+
+ if hasattr(self, "to_added_qkv"):
+ delattr(self, "to_added_qkv")
+
+ self.fused_projections = False
+
+ def set_attention_slice(self, slice_size: int) -> None:
+ """
+ Set the slice size for attention computation.
+
+ Args:
+ slice_size (`int`):
+ The slice size for attention computation.
+ """
+ if hasattr(self, "sliceable_head_dim") and slice_size is not None and slice_size > self.sliceable_head_dim:
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+ processor = None
+
+ # Try to get a compatible processor for sliced attention
+ if slice_size is not None:
+ processor = self._get_compatible_processor("sliced")
+
+ # If no processor was found or slice_size is None, use default processor
+ if processor is None:
+ processor = self.default_processor_cls()
+
+ self.set_processor(processor)
+
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
+ """
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to reshape.
+
+ Returns:
+ `torch.Tensor`: The reshaped tensor.
+ """
+ head_size = self.heads
+ batch_size, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
+ """
+ Reshape the tensor for multi-head attention processing.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to reshape.
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor.
+
+ Returns:
+ `torch.Tensor`: The reshaped tensor.
+ """
+ head_size = self.heads
+ if tensor.ndim == 3:
+ batch_size, seq_len, dim = tensor.shape
+ extra_dim = 1
+ else:
+ batch_size, extra_dim, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3)
+
+ if out_dim == 3:
+ tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
+
+ return tensor
+
+ def get_attention_scores(
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """
+ Compute the attention scores.
+
+ Args:
+ query (`torch.Tensor`): The query tensor.
+ key (`torch.Tensor`): The key tensor.
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use.
+
+ Returns:
+ `torch.Tensor`: The attention probabilities/scores.
+ """
+ dtype = query.dtype
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ if attention_mask is None:
+ baddbmm_input = torch.empty(
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+ )
+ beta = 0
+ else:
+ baddbmm_input = attention_mask
+ beta = 1
+
+ attention_scores = torch.baddbmm(
+ baddbmm_input,
+ query,
+ key.transpose(-1, -2),
+ beta=beta,
+ alpha=self.scale,
+ )
+ del baddbmm_input
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+ del attention_scores
+
+ attention_probs = attention_probs.to(dtype)
+
+ return attention_probs
+
+ def prepare_attention_mask(
+ self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
+ ) -> torch.Tensor:
+ """
+ Prepare the attention mask for the attention computation.
+
+ Args:
+ attention_mask (`torch.Tensor`): The attention mask to prepare.
+ target_length (`int`): The target length of the attention mask.
+ batch_size (`int`): The batch size for repeating the attention mask.
+ out_dim (`int`, *optional*, defaults to `3`): Output dimension.
+
+ Returns:
+ `torch.Tensor`: The prepared attention mask.
+ """
+ head_size = self.heads
+ if attention_mask is None:
+ return attention_mask
+
+ current_length: int = attention_mask.shape[-1]
+ if current_length != target_length:
+ if attention_mask.device.type == "mps":
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+ # Instead, we can manually construct the padding tensor.
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
+ else:
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
+ # remaining_length: int = target_length - current_length
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+
+ if out_dim == 3:
+ if attention_mask.shape[0] < batch_size * head_size:
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+ elif out_dim == 4:
+ attention_mask = attention_mask.unsqueeze(1)
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+ return attention_mask
+
+ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+ """
+ Normalize the encoder hidden states.
+
+ Args:
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
+
+ Returns:
+ `torch.Tensor`: The normalized encoder hidden states.
+ """
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+ if isinstance(self.norm_cross, nn.LayerNorm):
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ elif isinstance(self.norm_cross, nn.GroupNorm):
+ # Group norm norms along the channels dimension and expects
+ # input to be in the shape of (N, C, *). In this case, we want
+ # to norm along the hidden dimension, so we need to move
+ # (batch_size, sequence_length, hidden_size) ->
+ # (batch_size, hidden_size, sequence_length)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ else:
+ assert False
+
+ return encoder_hidden_states
+
+
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
# "feed_forward_chunk_size" can be used to save memory
if hidden_states.shape[chunk_dim] % chunk_size != 0:
@@ -90,7 +581,7 @@ class JointTransformerBlock(nn.Module):
r"""
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
- Reference: https://arxiv.org/abs/2403.03206
+ Reference: https://huggingface.co/papers/2403.03206
Parameters:
dim (`int`): The number of channels in the input and output.
@@ -193,7 +684,7 @@ def forward(
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
- ):
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
joint_attention_kwargs = joint_attention_kwargs or {}
if self.use_dual_attention:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
@@ -892,8 +1383,8 @@ class FreeNoiseTransformerBlock(nn.Module):
The number of frames to be skipped before starting to process a new batch of `context_length` frames.
weighting_scheme (`str`, defaults to `"pyramid"`):
The weighting scheme to use for weighting averaging of processed latent frames. As described in the
- Equation 9. of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting
- used.
+ Equation 9. of the [FreeNoise](https://huggingface.co/papers/2310.15169) paper, "pyramid" is the default
+ setting used.
"""
def __init__(
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
new file mode 100644
index 000000000000..ffad94cc7f27
--- /dev/null
+++ b/src/diffusers/models/attention_dispatch.py
@@ -0,0 +1,2334 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import functools
+import inspect
+import math
+from dataclasses import dataclass
+from enum import Enum
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+
+
+if torch.distributed.is_available():
+ import torch.distributed._functional_collectives as funcol
+
+from ..utils import (
+ get_logger,
+ is_aiter_available,
+ is_aiter_version,
+ is_flash_attn_3_available,
+ is_flash_attn_available,
+ is_flash_attn_version,
+ is_kernels_available,
+ is_sageattention_available,
+ is_sageattention_version,
+ is_torch_npu_available,
+ is_torch_version,
+ is_torch_xla_available,
+ is_torch_xla_version,
+ is_xformers_available,
+ is_xformers_version,
+)
+from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
+
+
+if TYPE_CHECKING:
+ from ._modeling_parallel import ParallelConfig
+
+_REQUIRED_FLASH_VERSION = "2.6.3"
+_REQUIRED_AITER_VERSION = "0.1.5"
+_REQUIRED_SAGE_VERSION = "2.1.1"
+_REQUIRED_FLEX_VERSION = "2.5.0"
+_REQUIRED_XLA_VERSION = "2.2"
+_REQUIRED_XFORMERS_VERSION = "0.0.29"
+
+_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
+_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
+_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
+_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
+_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
+_CAN_USE_NPU_ATTN = is_torch_npu_available()
+_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
+_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)
+
+
+if _CAN_USE_FLASH_ATTN:
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
+else:
+ flash_attn_func = None
+ flash_attn_varlen_func = None
+ _wrapped_flash_attn_backward = None
+ _wrapped_flash_attn_forward = None
+
+
+if _CAN_USE_FLASH_ATTN_3:
+ from flash_attn_interface import flash_attn_func as flash_attn_3_func
+ from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+else:
+ flash_attn_3_func = None
+ flash_attn_3_varlen_func = None
+
+if _CAN_USE_AITER_ATTN:
+ from aiter import flash_attn_func as aiter_flash_attn_func
+else:
+ aiter_flash_attn_func = None
+
+if _CAN_USE_SAGE_ATTN:
+ from sageattention import (
+ sageattn,
+ sageattn_qk_int8_pv_fp8_cuda,
+ sageattn_qk_int8_pv_fp8_cuda_sm90,
+ sageattn_qk_int8_pv_fp16_cuda,
+ sageattn_qk_int8_pv_fp16_triton,
+ sageattn_varlen,
+ )
+else:
+ sageattn = None
+ sageattn_qk_int8_pv_fp16_cuda = None
+ sageattn_qk_int8_pv_fp16_triton = None
+ sageattn_qk_int8_pv_fp8_cuda = None
+ sageattn_qk_int8_pv_fp8_cuda_sm90 = None
+ sageattn_varlen = None
+
+
+if _CAN_USE_FLEX_ATTN:
+ # We cannot import the flex_attention function from the package directly because it is expected (from the
+ # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
+ # compiled function.
+ import torch.nn.attention.flex_attention as flex_attention
+
+
+if _CAN_USE_NPU_ATTN:
+ from torch_npu import npu_fusion_attention
+else:
+ npu_fusion_attention = None
+
+
+if _CAN_USE_XLA_ATTN:
+ from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
+else:
+ xla_flash_attention = None
+
+
+if _CAN_USE_XFORMERS_ATTN:
+ import xformers.ops as xops
+else:
+ xops = None
+
+# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
+if torch.__version__ >= "2.4.0":
+ _custom_op = torch.library.custom_op
+ _register_fake = torch.library.register_fake
+else:
+
+ def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
+ def wrap(func):
+ return func
+
+ return wrap if fn is None else fn
+
+ def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
+ def wrap(func):
+ return func
+
+ return wrap if fn is None else fn
+
+ _custom_op = custom_op_no_op
+ _register_fake = register_fake_no_op
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+# TODO(aryan): Add support for the following:
+# - Sage Attention++
+# - block sparse, radial and other attention methods
+# - CP with sage attention, flex, xformers, other missing backends
+# - Add support for normal and CP training with backends that don't support it yet
+
+
+class AttentionBackendName(str, Enum):
+ # EAGER = "eager"
+
+ # `flash-attn`
+ FLASH = "flash"
+ FLASH_HUB = "flash_hub"
+ FLASH_VARLEN = "flash_varlen"
+ FLASH_VARLEN_HUB = "flash_varlen_hub"
+ _FLASH_3 = "_flash_3"
+ _FLASH_VARLEN_3 = "_flash_varlen_3"
+ _FLASH_3_HUB = "_flash_3_hub"
+ _FLASH_3_VARLEN_HUB = "_flash_3_varlen_hub"
+
+ # `aiter`
+ AITER = "aiter"
+
+ # PyTorch native
+ FLEX = "flex"
+ NATIVE = "native"
+ _NATIVE_CUDNN = "_native_cudnn"
+ _NATIVE_EFFICIENT = "_native_efficient"
+ _NATIVE_FLASH = "_native_flash"
+ _NATIVE_MATH = "_native_math"
+ _NATIVE_NPU = "_native_npu"
+ _NATIVE_XLA = "_native_xla"
+
+ # `sageattention`
+ SAGE = "sage"
+ SAGE_HUB = "sage_hub"
+ SAGE_VARLEN = "sage_varlen"
+ _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
+ _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
+ _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda"
+ _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton"
+ # TODO: let's not add support for Sparge Attention now because it requires tuning per model
+ # We can look into supporting something "autotune"-ing in the future
+ # SPARGE = "sparge"
+
+ # `xformers`
+ XFORMERS = "xformers"
+
+
+class _AttentionBackendRegistry:
+ _backends = {}
+ _constraints = {}
+ _supported_arg_names = {}
+ _supports_context_parallel = set()
+ _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
+ _checks_enabled = DIFFUSERS_ATTN_CHECKS
+
+ @classmethod
+ def register(
+ cls,
+ backend: AttentionBackendName,
+ constraints: Optional[List[Callable]] = None,
+ supports_context_parallel: bool = False,
+ ):
+ logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}")
+
+ def decorator(func):
+ cls._backends[backend] = func
+ cls._constraints[backend] = constraints or []
+ cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
+ if supports_context_parallel:
+ cls._supports_context_parallel.add(backend.value)
+
+ return func
+
+ return decorator
+
+ @classmethod
+ def get_active_backend(cls):
+ return cls._active_backend, cls._backends[cls._active_backend]
+
+ @classmethod
+ def list_backends(cls):
+ return list(cls._backends.keys())
+
+ @classmethod
+ def _is_context_parallel_available(
+ cls,
+ backend: AttentionBackendName,
+ ) -> bool:
+ supports_context_parallel = backend.value in cls._supports_context_parallel
+ return supports_context_parallel
+
+
+@dataclass
+class _HubKernelConfig:
+ """Configuration for downloading and using a hub-based attention kernel."""
+
+ repo_id: str
+ function_attr: str
+ revision: Optional[str] = None
+ kernel_fn: Optional[Callable] = None
+
+
+# Registry for hub-based attention kernels
+_HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
+ # TODO: temporary revision for now. Remove when merged upstream into `main`.
+ AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
+ repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
+ ),
+ AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
+ repo_id="kernels-community/flash-attn3",
+ function_attr="flash_attn_varlen_func",
+ # revision="fake-ops-return-probs",
+ ),
+ AttentionBackendName.FLASH_HUB: _HubKernelConfig(
+ repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
+ ),
+ AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
+ repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
+ ),
+ AttentionBackendName.SAGE_HUB: _HubKernelConfig(
+ repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
+ ),
+}
+
+
+@contextlib.contextmanager
+def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
+ """
+ Context manager to set the active attention backend.
+ """
+ if backend not in _AttentionBackendRegistry._backends:
+ raise ValueError(f"Backend {backend} is not registered.")
+
+ backend = AttentionBackendName(backend)
+ _check_attention_backend_requirements(backend)
+ _maybe_download_kernel_for_backend(backend)
+
+ old_backend = _AttentionBackendRegistry._active_backend
+ _AttentionBackendRegistry._active_backend = backend
+
+ try:
+ yield
+ finally:
+ _AttentionBackendRegistry._active_backend = old_backend
+
+
+def dispatch_attention_fn(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ *,
+ backend: Optional[AttentionBackendName] = None,
+ parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ attention_kwargs = attention_kwargs or {}
+
+ if backend is None:
+ # If no backend is specified, we either use the default backend (set via the DIFFUSERS_ATTN_BACKEND environment
+ # variable), or we use a custom backend based on whether user is using the `attention_backend` context manager
+ backend_name, backend_fn = _AttentionBackendRegistry.get_active_backend()
+ else:
+ backend_name = AttentionBackendName(backend)
+ backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
+
+ kwargs = {
+ "query": query,
+ "key": key,
+ "value": value,
+ "attn_mask": attn_mask,
+ "dropout_p": dropout_p,
+ "is_causal": is_causal,
+ "scale": scale,
+ **attention_kwargs,
+ "_parallel_config": parallel_config,
+ }
+ if is_torch_version(">=", "2.5.0"):
+ kwargs["enable_gqa"] = enable_gqa
+
+ if _AttentionBackendRegistry._checks_enabled:
+ removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name])
+ if removed_kwargs:
+ logger.warning(f"Removing unsupported arguments for attention backend {backend_name}: {removed_kwargs}.")
+ for check in _AttentionBackendRegistry._constraints.get(backend_name):
+ check(**kwargs)
+
+ kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}
+ return backend_fn(**kwargs)
+
+
+# ===== Checks =====
+# A list of very simple functions to catch common errors quickly when debugging.
+
+
+def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None:
+ if attn_mask is not None and is_causal:
+ raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.")
+
+
+def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ if query.device != key.device or query.device != value.device:
+ raise ValueError("Query, key, and value must be on the same device.")
+ if query.dtype != key.dtype or query.dtype != value.dtype:
+ raise ValueError("Query, key, and value must have the same dtype.")
+
+
+def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ _check_device(query, key, value)
+ if query.device.type != "cuda":
+ raise ValueError("Query, key, and value must be on a CUDA device.")
+
+
+def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable:
+ def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ _check_device_cuda(query, key, value)
+ if torch.cuda.get_device_capability(query.device) < (major, minor):
+ raise ValueError(
+ f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}."
+ )
+
+ return check_device_cuda
+
+
+def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ if query.dtype != key.dtype:
+ raise ValueError("Query and key must have the same dtype.")
+ if query.dtype != value.dtype:
+ raise ValueError("Query and value must have the same dtype.")
+
+
+def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ _check_qkv_dtype_match(query, key, value)
+ if query.dtype not in (torch.bfloat16, torch.float16):
+ raise ValueError("Query, key, and value must be either bfloat16 or float16.")
+
+
+def _check_shape(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+) -> None:
+ # Expected shapes:
+ # query: (batch_size, seq_len_q, num_heads, head_dim)
+ # key: (batch_size, seq_len_kv, num_heads, head_dim)
+ # value: (batch_size, seq_len_kv, num_heads, head_dim)
+ # attn_mask: (seq_len_q, seq_len_kv) or (batch_size, seq_len_q, seq_len_kv)
+ # or (batch_size, num_heads, seq_len_q, seq_len_kv)
+ if query.shape[-1] != key.shape[-1]:
+ raise ValueError("Query and key must have the same head dimension.")
+ if key.shape[-3] != value.shape[-3]:
+ raise ValueError("Key and value must have the same sequence length.")
+ if attn_mask is not None and attn_mask.shape[-1] != key.shape[-3]:
+ raise ValueError("Attention mask must match the key's sequence length.")
+
+
+# ===== Helper functions =====
+
+
+def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
+ if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]:
+ if not _CAN_USE_FLASH_ATTN:
+ raise RuntimeError(
+ f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`."
+ )
+
+ elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
+ if not _CAN_USE_FLASH_ATTN_3:
+ raise RuntimeError(
+ f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
+ )
+
+ elif backend in [
+ AttentionBackendName.FLASH_HUB,
+ AttentionBackendName.FLASH_VARLEN_HUB,
+ AttentionBackendName._FLASH_3_HUB,
+ AttentionBackendName._FLASH_3_VARLEN_HUB,
+ AttentionBackendName.SAGE_HUB,
+ ]:
+ if not is_kernels_available():
+ raise RuntimeError(
+ f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
+ )
+
+ elif backend == AttentionBackendName.AITER:
+ if not _CAN_USE_AITER_ATTN:
+ raise RuntimeError(
+ f"Aiter Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`."
+ )
+
+ elif backend in [
+ AttentionBackendName.SAGE,
+ AttentionBackendName.SAGE_VARLEN,
+ AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
+ AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
+ AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
+ AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
+ ]:
+ if not _CAN_USE_SAGE_ATTN:
+ raise RuntimeError(
+ f"Sage Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`."
+ )
+
+ elif backend == AttentionBackendName.FLEX:
+ if not _CAN_USE_FLEX_ATTN:
+ raise RuntimeError(
+ f"Flex Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch>=2.5.0`."
+ )
+
+ elif backend == AttentionBackendName._NATIVE_NPU:
+ if not _CAN_USE_NPU_ATTN:
+ raise RuntimeError(
+ f"NPU Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_npu`."
+ )
+
+ elif backend == AttentionBackendName._NATIVE_XLA:
+ if not _CAN_USE_XLA_ATTN:
+ raise RuntimeError(
+ f"XLA Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_xla>={_REQUIRED_XLA_VERSION}`."
+ )
+
+ elif backend == AttentionBackendName.XFORMERS:
+ if not _CAN_USE_XFORMERS_ATTN:
+ raise RuntimeError(
+ f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`."
+ )
+
+
+@functools.lru_cache(maxsize=128)
+def _prepare_for_flash_attn_or_sage_varlen_without_mask(
+ batch_size: int,
+ seq_len_q: int,
+ seq_len_kv: int,
+ device: Optional[torch.device] = None,
+):
+ seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
+ seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
+ cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+ cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+ cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
+ cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
+ max_seqlen_q = seqlens_q.max().item()
+ max_seqlen_k = seqlens_k.max().item()
+ return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+
+
+def _prepare_for_flash_attn_or_sage_varlen_with_mask(
+ batch_size: int,
+ seq_len_q: int,
+ attn_mask: torch.Tensor,
+ device: Optional[torch.device] = None,
+):
+ seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
+ seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)
+ cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+ cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+ cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
+ cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
+ max_seqlen_q = seqlens_q.max().item()
+ max_seqlen_k = seqlens_k.max().item()
+ return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+
+
+def _prepare_for_flash_attn_or_sage_varlen(
+ batch_size: int,
+ seq_len_q: int,
+ seq_len_kv: int,
+ attn_mask: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+) -> None:
+ if attn_mask is None:
+ return _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, device)
+ return _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, device)
+
+
+def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor:
+ """
+ Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in
+ FlashAttention/Sage varlen.
+
+ Supports 1D to 4D shapes and common broadcasting patterns.
+ """
+ if attn_mask.dtype != torch.bool:
+ raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.")
+
+ if attn_mask.ndim == 1:
+ # [seq_len_k] -> broadcast across batch
+ attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k)
+
+ elif attn_mask.ndim == 2:
+ # [batch_size, seq_len_k]. Maybe broadcast across batch
+ if attn_mask.size(0) not in [1, batch_size]:
+ raise ValueError(
+ f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask."
+ )
+ attn_mask = attn_mask.expand(batch_size, seq_len_k)
+
+ elif attn_mask.ndim == 3:
+ # [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension
+ # We do this reduction because we know that arbitrary QK masks is not supported in Flash/Sage varlen.
+ if attn_mask.size(0) not in [1, batch_size]:
+ raise ValueError(
+ f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask."
+ )
+ attn_mask = attn_mask.any(dim=1)
+ attn_mask = attn_mask.expand(batch_size, seq_len_k)
+
+ elif attn_mask.ndim == 4:
+ # [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions
+ if attn_mask.size(0) not in [1, batch_size]:
+ raise ValueError(
+ f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask."
+ )
+ attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k) # [B, H, Q, K]
+ attn_mask = attn_mask.any(dim=(1, 2)) # [B, K]
+
+ else:
+ raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}")
+
+ if attn_mask.shape != (batch_size, seq_len_k):
+ raise ValueError(
+ f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})"
+ )
+
+ return attn_mask
+
+
+def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+ return q_idx >= kv_idx
+
+
+# ===== Helpers for downloading kernels =====
+def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
+ if backend not in _HUB_KERNELS_REGISTRY:
+ return
+ config = _HUB_KERNELS_REGISTRY[backend]
+
+ if config.kernel_fn is not None:
+ return
+
+ try:
+ from kernels import get_kernel
+
+ kernel_module = get_kernel(config.repo_id, revision=config.revision)
+ kernel_func = getattr(kernel_module, config.function_attr)
+
+ # Cache the downloaded kernel function in the config object
+ config.kernel_fn = kernel_func
+
+ except Exception as e:
+ logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
+ raise
+
+
+# ===== torch op registrations =====
+# Registrations are required for fullgraph tracing compatibility
+# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
+# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
+@_custom_op("_diffusers_flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
+def _wrapped_flash_attn_3(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ softmax_scale: Optional[float] = None,
+ causal: bool = False,
+ qv: Optional[torch.Tensor] = None,
+ q_descale: Optional[torch.Tensor] = None,
+ k_descale: Optional[torch.Tensor] = None,
+ v_descale: Optional[torch.Tensor] = None,
+ attention_chunk: int = 0,
+ softcap: float = 0.0,
+ num_splits: int = 1,
+ pack_gqa: Optional[bool] = None,
+ deterministic: bool = False,
+ sm_margin: int = 0,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Hardcoded for now because pytorch does not support tuple/int type hints
+ window_size = (-1, -1)
+ out, lse, *_ = flash_attn_3_func(
+ q=q,
+ k=k,
+ v=v,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ qv=qv,
+ q_descale=q_descale,
+ k_descale=k_descale,
+ v_descale=v_descale,
+ window_size=window_size,
+ attention_chunk=attention_chunk,
+ softcap=softcap,
+ num_splits=num_splits,
+ pack_gqa=pack_gqa,
+ deterministic=deterministic,
+ sm_margin=sm_margin,
+ )
+ lse = lse.permute(0, 2, 1)
+ return out, lse
+
+
+@_register_fake("_diffusers_flash_attn_3::_flash_attn_forward")
+def _(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ softmax_scale: Optional[float] = None,
+ causal: bool = False,
+ qv: Optional[torch.Tensor] = None,
+ q_descale: Optional[torch.Tensor] = None,
+ k_descale: Optional[torch.Tensor] = None,
+ v_descale: Optional[torch.Tensor] = None,
+ attention_chunk: int = 0,
+ softcap: float = 0.0,
+ num_splits: int = 1,
+ pack_gqa: Optional[bool] = None,
+ deterministic: bool = False,
+ sm_margin: int = 0,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ window_size = (-1, -1) # noqa: F841
+ # A lot of the parameters here are not yet used in any way within diffusers.
+ # We can safely ignore for now and keep the fake op shape propagation simple.
+ batch_size, seq_len, num_heads, head_dim = q.shape
+ lse_shape = (batch_size, seq_len, num_heads)
+ return torch.empty_like(q), q.new_empty(lse_shape)
+
+
+# ===== Helper functions to use attention backends with templated CP autograd functions =====
+
+
+def _native_attention_forward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ _save_ctx: bool = True,
+ _parallel_config: Optional["ParallelConfig"] = None,
+):
+ # Native attention does not return_lse
+ if return_lse:
+ raise ValueError("Native attention does not support return_lse=True")
+
+ # used for backward pass
+ if _save_ctx:
+ ctx.save_for_backward(query, key, value)
+ ctx.attn_mask = attn_mask
+ ctx.dropout_p = dropout_p
+ ctx.is_causal = is_causal
+ ctx.scale = scale
+ ctx.enable_gqa = enable_gqa
+
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ out = torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+ out = out.permute(0, 2, 1, 3)
+
+ return out
+
+
+def _native_attention_backward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args,
+ **kwargs,
+):
+ query, key, value = ctx.saved_tensors
+
+ query.requires_grad_(True)
+ key.requires_grad_(True)
+ value.requires_grad_(True)
+
+ query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ out = torch.nn.functional.scaled_dot_product_attention(
+ query=query_t,
+ key=key_t,
+ value=value_t,
+ attn_mask=ctx.attn_mask,
+ dropout_p=ctx.dropout_p,
+ is_causal=ctx.is_causal,
+ scale=ctx.scale,
+ enable_gqa=ctx.enable_gqa,
+ )
+ out = out.permute(0, 2, 1, 3)
+
+ grad_out_t = grad_out.permute(0, 2, 1, 3)
+ grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
+ outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
+ )
+
+ grad_query = grad_query_t.permute(0, 2, 1, 3)
+ grad_key = grad_key_t.permute(0, 2, 1, 3)
+ grad_value = grad_value_t.permute(0, 2, 1, 3)
+
+ return grad_query, grad_key, grad_value
+
+
+# https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
+# forward declaration:
+# aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+def _cudnn_attention_forward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ _save_ctx: bool = True,
+ _parallel_config: Optional["ParallelConfig"] = None,
+):
+ if enable_gqa:
+ raise ValueError("`enable_gqa` is not yet supported for cuDNN attention.")
+
+ tensors_to_save = ()
+
+ # Contiguous is a must here! Calling cuDNN backend with aten ops produces incorrect results
+ # if the input tensors are not contiguous.
+ query = query.transpose(1, 2).contiguous()
+ key = key.transpose(1, 2).contiguous()
+ value = value.transpose(1, 2).contiguous()
+ tensors_to_save += (query, key, value)
+
+ out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
+ torch.ops.aten._scaled_dot_product_cudnn_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_bias=attn_mask,
+ compute_log_sumexp=return_lse,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ return_debug_mask=False,
+ scale=scale,
+ )
+ )
+
+ tensors_to_save += (out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
+ if _save_ctx:
+ ctx.save_for_backward(*tensors_to_save)
+ ctx.dropout_p = dropout_p
+ ctx.is_causal = is_causal
+ ctx.scale = scale
+ ctx.attn_mask = attn_mask
+ ctx.max_q = max_q
+ ctx.max_k = max_k
+
+ out = out.transpose(1, 2).contiguous()
+ if lse is not None:
+ lse = lse.transpose(1, 2).contiguous()
+ return (out, lse) if return_lse else out
+
+
+# backward declaration:
+# aten::_scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
+def _cudnn_attention_backward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args,
+ **kwargs,
+):
+ query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
+
+ grad_out = grad_out.transpose(1, 2).contiguous()
+ key = key.transpose(1, 2).contiguous()
+ value = value.transpose(1, 2).contiguous()
+
+ # Cannot pass first 5 arguments as kwargs because: https://github.com/pytorch/pytorch/blob/d26ca5de058dbcf56ac52bb43e84dd98df2ace97/torch/_dynamo/variables/torch.py#L1341
+ grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_cudnn_attention_backward(
+ grad_out,
+ query,
+ key,
+ value,
+ out,
+ logsumexp=lse,
+ philox_seed=philox_seed,
+ philox_offset=philox_offset,
+ attn_bias=ctx.attn_mask,
+ cum_seq_q=cum_seq_q,
+ cum_seq_k=cum_seq_k,
+ max_q=ctx.max_q,
+ max_k=ctx.max_k,
+ dropout_p=ctx.dropout_p,
+ is_causal=ctx.is_causal,
+ scale=ctx.scale,
+ )
+ grad_query, grad_key, grad_value = (x.transpose(1, 2).contiguous() for x in (grad_query, grad_key, grad_value))
+
+ return grad_query, grad_key, grad_value
+
+
+# Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807
+def _flash_attention_forward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ _save_ctx: bool = True,
+ _parallel_config: Optional["ParallelConfig"] = None,
+):
+ if attn_mask is not None:
+ raise ValueError("`attn_mask` is not yet supported for flash-attn 2.")
+ if enable_gqa:
+ raise ValueError("`enable_gqa` is not yet supported for flash-attn 2.")
+
+ # Hardcoded for now
+ window_size = (-1, -1)
+ softcap = 0.0
+ alibi_slopes = None
+ deterministic = False
+ grad_enabled = any(x.requires_grad for x in (query, key, value))
+
+ if scale is None:
+ scale = query.shape[-1] ** (-0.5)
+
+ # flash-attn only returns LSE if dropout_p > 0. So, we need to workaround.
+ if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
+ dropout_p = dropout_p if dropout_p > 0 else 1e-30
+
+ with torch.set_grad_enabled(grad_enabled):
+ out, lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
+ query,
+ key,
+ value,
+ dropout_p,
+ scale,
+ is_causal,
+ window_size[0],
+ window_size[1],
+ softcap,
+ alibi_slopes,
+ return_lse,
+ )
+ lse = lse.permute(0, 2, 1)
+
+ if _save_ctx:
+ ctx.save_for_backward(query, key, value, out, lse, rng_state)
+ ctx.dropout_p = dropout_p
+ ctx.scale = scale
+ ctx.is_causal = is_causal
+ ctx.window_size = window_size
+ ctx.softcap = softcap
+ ctx.alibi_slopes = alibi_slopes
+ ctx.deterministic = deterministic
+
+ return (out, lse) if return_lse else out
+
+
+def _flash_attention_backward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args,
+ **kwargs,
+):
+ query, key, value, out, lse, rng_state = ctx.saved_tensors
+ grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
+
+ lse_d = _wrapped_flash_attn_backward( # noqa: F841
+ grad_out,
+ query,
+ key,
+ value,
+ out,
+ lse,
+ grad_query,
+ grad_key,
+ grad_value,
+ ctx.dropout_p,
+ ctx.scale,
+ ctx.is_causal,
+ ctx.window_size[0],
+ ctx.window_size[1],
+ ctx.softcap,
+ ctx.alibi_slopes,
+ ctx.deterministic,
+ rng_state,
+ )
+
+ # Head dimension may have been padded
+ grad_query = grad_query[..., : grad_out.shape[-1]]
+ grad_key = grad_key[..., : grad_out.shape[-1]]
+ grad_value = grad_value[..., : grad_out.shape[-1]]
+
+ return grad_query, grad_key, grad_value
+
+
+def _sage_attention_forward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ _save_ctx: bool = True,
+ _parallel_config: Optional["ParallelConfig"] = None,
+):
+ if attn_mask is not None:
+ raise ValueError("`attn_mask` is not yet supported for Sage attention.")
+ if dropout_p > 0.0:
+ raise ValueError("`dropout_p` is not yet supported for Sage attention.")
+ if enable_gqa:
+ raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
+
+ out = sageattn(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="NHD",
+ is_causal=is_causal,
+ sm_scale=scale,
+ return_lse=return_lse,
+ )
+ lse = None
+ if return_lse:
+ out, lse, *_ = out
+ lse = lse.permute(0, 2, 1)
+
+ return (out, lse) if return_lse else out
+
+
+def _sage_attention_backward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args,
+):
+ raise NotImplementedError("Backward pass is not implemented for Sage attention.")
+
+
+# ===== Context parallel =====
+
+
+# Reference:
+# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L827
+# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L246
+# For fullgraph=True tracing compatibility (since FakeTensor does not have a `wait` method):
+def _wait_tensor(tensor):
+ if isinstance(tensor, funcol.AsyncCollectiveTensor):
+ tensor = tensor.wait()
+ return tensor
+
+
+def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
+ shape = x.shape
+ # HACK: We need to flatten because despite making tensors contiguous, torch single-file-ization
+ # to benchmark triton codegen fails somewhere:
+ # buf25 = torch.ops._c10d_functional.all_to_all_single.default(buf24, [1, 1], [1, 1], '3')
+ # ValueError: Tensors must be contiguous
+ x = x.flatten()
+ x = funcol.all_to_all_single(x, None, None, group)
+ x = x.reshape(shape)
+ x = _wait_tensor(x)
+ return x
+
+
+class TemplatedRingAttention(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor],
+ dropout_p: float,
+ is_causal: bool,
+ scale: Optional[float],
+ enable_gqa: bool,
+ return_lse: bool,
+ forward_op,
+ backward_op,
+ _parallel_config: Optional["ParallelConfig"] = None,
+ ):
+ ring_mesh = _parallel_config.context_parallel_config._ring_mesh
+ rank = _parallel_config.context_parallel_config._ring_local_rank
+ world_size = _parallel_config.context_parallel_config.ring_degree
+ next_rank = (rank + 1) % world_size
+ prev_out = prev_lse = None
+
+ ctx.forward_op = forward_op
+ ctx.backward_op = backward_op
+ ctx.q_shape = query.shape
+ ctx.kv_shape = key.shape
+ ctx._parallel_config = _parallel_config
+
+ kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
+ kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group())
+ kv_buffer = kv_buffer.chunk(world_size)
+
+ for i in range(world_size):
+ if i > 0:
+ kv = kv_buffer[next_rank]
+ key_numel = key.numel()
+ key = kv[:key_numel].reshape_as(key)
+ value = kv[key_numel:].reshape_as(value)
+ next_rank = (next_rank + 1) % world_size
+
+ out, lse = forward_op(
+ ctx,
+ query,
+ key,
+ value,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ scale,
+ enable_gqa,
+ True,
+ _save_ctx=i == 0,
+ _parallel_config=_parallel_config,
+ )
+
+ if _parallel_config.context_parallel_config.convert_to_fp32:
+ out = out.to(torch.float32)
+ lse = lse.to(torch.float32)
+
+ lse = lse.unsqueeze(-1)
+ if prev_out is not None:
+ out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
+ lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
+ prev_out = out
+ prev_lse = lse
+
+ out = out.to(query.dtype)
+ lse = lse.squeeze(-1)
+
+ return (out, lse) if return_lse else out
+
+ @staticmethod
+ def backward(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args,
+ ):
+ ring_mesh = ctx._parallel_config.context_parallel_config._ring_mesh
+ rank = ctx._parallel_config.context_parallel_config._ring_local_rank
+ world_size = ctx._parallel_config.context_parallel_config.ring_degree
+ next_rank = (rank + 1) % world_size
+ next_ranks = list(range(1, world_size)) + [0]
+
+ accum_dtype = torch.float32 if ctx._parallel_config.context_parallel_config.convert_to_fp32 else grad_out.dtype
+ grad_query = torch.zeros(ctx.q_shape, dtype=accum_dtype, device=grad_out.device)
+ grad_key = torch.zeros(ctx.kv_shape, dtype=accum_dtype, device=grad_out.device)
+ grad_value = torch.zeros(ctx.kv_shape, dtype=accum_dtype, device=grad_out.device)
+ next_grad_kv = None
+
+ query, key, value, *_ = ctx.saved_tensors
+ kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
+ kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group())
+ kv_buffer = kv_buffer.chunk(world_size)
+
+ for i in range(world_size):
+ if i > 0:
+ kv = kv_buffer[next_rank]
+ key_numel = key.numel()
+ key = kv[:key_numel].reshape_as(key)
+ value = kv[key_numel:].reshape_as(value)
+ next_rank = (next_rank + 1) % world_size
+
+ grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx, grad_out)
+
+ if i > 0:
+ grad_kv_buffer = _wait_tensor(next_grad_kv)
+ grad_key_numel = grad_key.numel()
+ grad_key = grad_kv_buffer[:grad_key_numel].reshape_as(grad_key)
+ grad_value = grad_kv_buffer[grad_key_numel:].reshape_as(grad_value)
+
+ grad_query += grad_query_op
+ grad_key += grad_key_op
+ grad_value += grad_value_op
+
+ if i < world_size - 1:
+ grad_kv_buffer = torch.cat([grad_key.flatten(), grad_value.flatten()]).contiguous()
+ next_grad_kv = funcol.permute_tensor(grad_kv_buffer, next_ranks, group=ring_mesh.get_group())
+
+ grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))
+
+ return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+
+
+class TemplatedUlyssesAttention(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor],
+ dropout_p: float,
+ is_causal: bool,
+ scale: Optional[float],
+ enable_gqa: bool,
+ return_lse: bool,
+ forward_op,
+ backward_op,
+ _parallel_config: Optional["ParallelConfig"] = None,
+ ):
+ ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
+ world_size = _parallel_config.context_parallel_config.ulysses_degree
+ group = ulysses_mesh.get_group()
+
+ ctx.forward_op = forward_op
+ ctx.backward_op = backward_op
+ ctx._parallel_config = _parallel_config
+
+ B, S_Q_LOCAL, H, D = query.shape
+ _, S_KV_LOCAL, _, _ = key.shape
+ H_LOCAL = H // world_size
+ query = query.reshape(B, S_Q_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+ key = key.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+ value = value.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+ query, key, value = (_all_to_all_single(x, group) for x in (query, key, value))
+ query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))
+
+ out = forward_op(
+ ctx,
+ query,
+ key,
+ value,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ scale,
+ enable_gqa,
+ return_lse,
+ _save_ctx=True,
+ _parallel_config=_parallel_config,
+ )
+ if return_lse:
+ out, lse, *_ = out
+
+ out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
+ out = _all_to_all_single(out, group)
+ out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
+
+ if return_lse:
+ lse = lse.reshape(B, world_size, S_Q_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous()
+ lse = _all_to_all_single(lse, group)
+ lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous()
+ else:
+ lse = None
+
+ return (out, lse) if return_lse else out
+
+ @staticmethod
+ def backward(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args,
+ ):
+ ulysses_mesh = ctx._parallel_config.context_parallel_config._ulysses_mesh
+ world_size = ctx._parallel_config.context_parallel_config.ulysses_degree
+ group = ulysses_mesh.get_group()
+
+ B, S_LOCAL, H, D = grad_out.shape
+ H_LOCAL = H // world_size
+
+ grad_out = grad_out.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+ grad_out = _all_to_all_single(grad_out, group)
+ grad_out = grad_out.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
+
+ grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx, grad_out)
+
+ grad_query, grad_key, grad_value = (
+ x.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
+ for x in (grad_query_op, grad_key_op, grad_value_op)
+ )
+ grad_query, grad_key, grad_value = (_all_to_all_single(x, group) for x in (grad_query, grad_key, grad_value))
+ grad_query, grad_key, grad_value = (
+ x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
+ )
+
+ return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+
+
+def _templated_context_parallel_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ *,
+ forward_op,
+ backward_op,
+ _parallel_config: Optional["ParallelConfig"] = None,
+):
+ if attn_mask is not None:
+ raise ValueError("Attention mask is not yet supported for templated attention.")
+ if is_causal:
+ raise ValueError("Causal attention is not yet supported for templated attention.")
+ if enable_gqa:
+ raise ValueError("GQA is not yet supported for templated attention.")
+
+ # TODO: add support for unified attention with ring/ulysses degree both being > 1
+ if _parallel_config.context_parallel_config.ring_degree > 1:
+ return TemplatedRingAttention.apply(
+ query,
+ key,
+ value,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ scale,
+ enable_gqa,
+ return_lse,
+ forward_op,
+ backward_op,
+ _parallel_config,
+ )
+ elif _parallel_config.context_parallel_config.ulysses_degree > 1:
+ return TemplatedUlyssesAttention.apply(
+ query,
+ key,
+ value,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ scale,
+ enable_gqa,
+ return_lse,
+ forward_op,
+ backward_op,
+ _parallel_config,
+ )
+ else:
+ raise ValueError("Reaching this branch of code is unexpected. Please report a bug.")
+
+
+# ===== Attention backends =====
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.FLASH,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ supports_context_parallel=True,
+)
+def _flash_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ lse = None
+ if _parallel_config is None:
+ out = flash_attn_func(
+ q=query,
+ k=key,
+ v=value,
+ dropout_p=dropout_p,
+ softmax_scale=scale,
+ causal=is_causal,
+ return_attn_probs=return_lse,
+ )
+ if return_lse:
+ out, lse, *_ = out
+ else:
+ out = _templated_context_parallel_attention(
+ query,
+ key,
+ value,
+ None,
+ dropout_p,
+ is_causal,
+ scale,
+ False,
+ return_lse,
+ forward_op=_flash_attention_forward_op,
+ backward_op=_flash_attention_backward_op,
+ _parallel_config=_parallel_config,
+ )
+ if return_lse:
+ out, lse = out
+
+ return (out, lse) if return_lse else out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.FLASH_HUB,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ supports_context_parallel=False,
+)
+def _flash_attention_hub(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ lse = None
+ func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
+ out = func(
+ q=query,
+ k=key,
+ v=value,
+ dropout_p=dropout_p,
+ softmax_scale=scale,
+ causal=is_causal,
+ return_attn_probs=return_lse,
+ )
+ if return_lse:
+ out, lse, *_ = out
+
+ return (out, lse) if return_lse else out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.FLASH_VARLEN_HUB,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ supports_context_parallel=False,
+)
+def _flash_varlen_attention_hub(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ batch_size, seq_len_q, _, _ = query.shape
+ _, seq_len_kv, _, _ = key.shape
+
+ if attn_mask is not None:
+ attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+ (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+ _prepare_for_flash_attn_or_sage_varlen(
+ batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+ )
+ )
+
+ key_valid, value_valid = [], []
+ for b in range(batch_size):
+ valid_len = seqlens_k[b]
+ key_valid.append(key[b, :valid_len])
+ value_valid.append(value[b, :valid_len])
+
+ query_packed = query.flatten(0, 1)
+ key_packed = torch.cat(key_valid, dim=0)
+ value_packed = torch.cat(value_valid, dim=0)
+
+ func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB].kernel_fn
+ out = func(
+ q=query_packed,
+ k=key_packed,
+ v=value_packed,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
+ dropout_p=dropout_p,
+ softmax_scale=scale,
+ causal=is_causal,
+ return_attn_probs=return_lse,
+ )
+ out = out.unflatten(0, (batch_size, -1))
+
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.FLASH_VARLEN,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_varlen_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ batch_size, seq_len_q, _, _ = query.shape
+ _, seq_len_kv, _, _ = key.shape
+
+ if attn_mask is not None:
+ attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+ (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+ _prepare_for_flash_attn_or_sage_varlen(
+ batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+ )
+ )
+
+ key_valid, value_valid = [], []
+ for b in range(batch_size):
+ valid_len = seqlens_k[b]
+ key_valid.append(key[b, :valid_len])
+ value_valid.append(value[b, :valid_len])
+
+ query_packed = query.flatten(0, 1)
+ key_packed = torch.cat(key_valid, dim=0)
+ value_packed = torch.cat(value_valid, dim=0)
+
+ out = flash_attn_varlen_func(
+ q=query_packed,
+ k=key_packed,
+ v=value_packed,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
+ dropout_p=dropout_p,
+ softmax_scale=scale,
+ causal=is_causal,
+ return_attn_probs=return_lse,
+ )
+ out = out.unflatten(0, (batch_size, -1))
+
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._FLASH_3,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention_3(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ out, lse = _wrapped_flash_attn_3(
+ q=query,
+ k=key,
+ v=value,
+ softmax_scale=scale,
+ causal=is_causal,
+ )
+ return (out, lse) if return_lse else out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._FLASH_3_HUB,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ supports_context_parallel=False,
+)
+def _flash_attention_3_hub(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ window_size: Tuple[int, int] = (-1, -1),
+ softcap: float = 0.0,
+ deterministic: bool = False,
+ return_attn_probs: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ if _parallel_config:
+ raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
+
+ func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
+ out = func(
+ q=query,
+ k=key,
+ v=value,
+ softmax_scale=scale,
+ causal=is_causal,
+ qv=None,
+ q_descale=None,
+ k_descale=None,
+ v_descale=None,
+ window_size=window_size,
+ softcap=softcap,
+ num_splits=1,
+ pack_gqa=None,
+ deterministic=deterministic,
+ sm_margin=0,
+ return_attn_probs=return_attn_probs,
+ )
+ # When `return_attn_probs` is True, the above returns a tuple of
+ # actual outputs and lse.
+ return (out[0], out[1]) if return_attn_probs else out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._FLASH_3_VARLEN_HUB,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ supports_context_parallel=False,
+)
+def _flash_attention_3_varlen_hub(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ batch_size, seq_len_q, _, _ = query.shape
+ _, seq_len_kv, _, _ = key.shape
+
+ if attn_mask is not None:
+ attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+ (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+ _prepare_for_flash_attn_or_sage_varlen(
+ batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+ )
+ )
+
+ key_valid, value_valid = [], []
+ for b in range(batch_size):
+ valid_len = seqlens_k[b]
+ key_valid.append(key[b, :valid_len])
+ value_valid.append(value[b, :valid_len])
+
+ query_packed = query.flatten(0, 1)
+ key_packed = torch.cat(key_valid, dim=0)
+ value_packed = torch.cat(value_valid, dim=0)
+
+ func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB].kernel_fn
+ out, lse, *_ = func(
+ q=query_packed,
+ k=key_packed,
+ v=value_packed,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
+ softmax_scale=scale,
+ causal=is_causal,
+ )
+ out = out.unflatten(0, (batch_size, -1))
+
+ return (out, lse) if return_lse else out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._FLASH_VARLEN_3,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_varlen_attention_3(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ batch_size, seq_len_q, _, _ = query.shape
+ _, seq_len_kv, _, _ = key.shape
+
+ if attn_mask is not None:
+ attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+ (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+ _prepare_for_flash_attn_or_sage_varlen(
+ batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+ )
+ )
+
+ key_valid, value_valid = [], []
+ for b in range(batch_size):
+ valid_len = seqlens_k[b]
+ key_valid.append(key[b, :valid_len])
+ value_valid.append(value[b, :valid_len])
+
+ query_packed = query.flatten(0, 1)
+ key_packed = torch.cat(key_valid, dim=0)
+ value_packed = torch.cat(value_valid, dim=0)
+
+ out, lse, *_ = flash_attn_3_varlen_func(
+ q=query_packed,
+ k=key_packed,
+ v=value_packed,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
+ softmax_scale=scale,
+ causal=is_causal,
+ )
+ out = out.unflatten(0, (batch_size, -1))
+
+ return (out, lse) if return_lse else out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.AITER,
+ constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _aiter_flash_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ if not return_lse and torch.is_grad_enabled():
+ # aiter requires return_lse=True by assertion when gradients are enabled.
+ out, lse, *_ = aiter_flash_attn_func(
+ q=query,
+ k=key,
+ v=value,
+ dropout_p=dropout_p,
+ softmax_scale=scale,
+ causal=is_causal,
+ return_lse=True,
+ )
+ else:
+ out = aiter_flash_attn_func(
+ q=query,
+ k=key,
+ v=value,
+ dropout_p=dropout_p,
+ softmax_scale=scale,
+ causal=is_causal,
+ return_lse=return_lse,
+ )
+ if return_lse:
+ out, lse, *_ = out
+
+ return (out, lse) if return_lse else out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.FLEX,
+ constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
+)
+def _native_flex_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ # TODO: should we LRU cache the block mask creation?
+ score_mod = None
+ block_mask = None
+ batch_size, seq_len_q, num_heads, _ = query.shape
+ _, seq_len_kv, _, _ = key.shape
+
+ if attn_mask is None or isinstance(attn_mask, flex_attention.BlockMask):
+ block_mask = attn_mask
+ elif is_causal:
+ block_mask = flex_attention.create_block_mask(
+ _flex_attention_causal_mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, query.device
+ )
+ elif torch.is_tensor(attn_mask):
+ if attn_mask.ndim == 2:
+ attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
+
+ attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv)
+
+ if attn_mask.dtype == torch.bool:
+ # TODO: this probably does not work but verify!
+ def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+ return attn_mask[batch_idx, head_idx, q_idx, kv_idx]
+
+ block_mask = flex_attention.create_block_mask(
+ mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device
+ )
+ else:
+
+ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
+ return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx]
+ else:
+ raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.")
+
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ out = flex_attention.flex_attention(
+ query=query,
+ key=key,
+ value=value,
+ score_mod=score_mod,
+ block_mask=block_mask,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ return_lse=return_lse,
+ )
+ out = out.permute(0, 2, 1, 3)
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.NATIVE,
+ constraints=[_check_device, _check_shape],
+ supports_context_parallel=True,
+)
+def _native_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ if return_lse:
+ raise ValueError("Native attention backend does not support setting `return_lse=True`.")
+ if _parallel_config is None:
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ out = torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+ out = out.permute(0, 2, 1, 3)
+ else:
+ out = _templated_context_parallel_attention(
+ query,
+ key,
+ value,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ scale,
+ enable_gqa,
+ return_lse,
+ forward_op=_native_attention_forward_op,
+ backward_op=_native_attention_backward_op,
+ _parallel_config=_parallel_config,
+ )
+
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_CUDNN,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ supports_context_parallel=True,
+)
+def _native_cudnn_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ lse = None
+ if _parallel_config is None and not return_lse:
+ query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value))
+ with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
+ out = torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+ out = out.permute(0, 2, 1, 3)
+ else:
+ out = _templated_context_parallel_attention(
+ query,
+ key,
+ value,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ scale,
+ enable_gqa,
+ return_lse,
+ forward_op=_cudnn_attention_forward_op,
+ backward_op=_cudnn_attention_backward_op,
+ _parallel_config=_parallel_config,
+ )
+ if return_lse:
+ out, lse = out
+
+ return (out, lse) if return_lse else out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_EFFICIENT,
+ constraints=[_check_device, _check_shape],
+)
+def _native_efficient_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ if return_lse:
+ raise ValueError("Native efficient attention backend does not support setting `return_lse=True`.")
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
+ out = torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+ out = out.permute(0, 2, 1, 3)
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_FLASH,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _native_flash_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ if return_lse:
+ raise ValueError("Native flash attention backend does not support setting `return_lse=True`.")
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
+ out = torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=None, # not supported
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+ out = out.permute(0, 2, 1, 3)
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_MATH,
+ constraints=[_check_device, _check_shape],
+)
+def _native_math_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ if return_lse:
+ raise ValueError("Native math attention backend does not support setting `return_lse=True`.")
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
+ out = torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+ out = out.permute(0, 2, 1, 3)
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_NPU,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _native_npu_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ dropout_p: float = 0.0,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ if return_lse:
+ raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
+ query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
+ out = npu_fusion_attention(
+ query,
+ key,
+ value,
+ query.size(1), # num_heads
+ input_layout="BNSD",
+ pse=None,
+ scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
+ pre_tockens=65536,
+ next_tockens=65536,
+ keep_prob=1.0 - dropout_p,
+ sync=False,
+ inner_precise=0,
+ )[0]
+ out = out.transpose(1, 2).contiguous()
+ return out
+
+
+# Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_XLA,
+ constraints=[_check_device, _check_shape],
+)
+def _native_xla_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ if return_lse:
+ raise ValueError("XLA attention backend does not support setting `return_lse=True`.")
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ query = query / math.sqrt(query.shape[-1])
+ out = xla_flash_attention(
+ q=query,
+ k=key,
+ v=value,
+ causal=is_causal,
+ )
+ out = out.permute(0, 2, 1, 3)
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.SAGE,
+ constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ supports_context_parallel=True,
+)
+def _sage_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ lse = None
+ if _parallel_config is None:
+ out = sageattn(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="NHD",
+ is_causal=is_causal,
+ sm_scale=scale,
+ return_lse=return_lse,
+ )
+ if return_lse:
+ out, lse, *_ = out
+ else:
+ out = _templated_context_parallel_attention(
+ query,
+ key,
+ value,
+ None,
+ 0.0,
+ is_causal,
+ scale,
+ False,
+ return_lse,
+ forward_op=_sage_attention_forward_op,
+ backward_op=_sage_attention_backward_op,
+ _parallel_config=_parallel_config,
+ )
+ if return_lse:
+ out, lse = out
+
+ return (out, lse) if return_lse else out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.SAGE_HUB,
+ constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ supports_context_parallel=False,
+)
+def _sage_attention_hub(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ lse = None
+ func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
+ if _parallel_config is None:
+ out = func(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="NHD",
+ is_causal=is_causal,
+ sm_scale=scale,
+ return_lse=return_lse,
+ )
+ if return_lse:
+ out, lse, *_ = out
+
+ return (out, lse) if return_lse else out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.SAGE_VARLEN,
+ constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _sage_varlen_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ if return_lse:
+ raise ValueError("Sage varlen backend does not support setting `return_lse=True`.")
+
+ batch_size, seq_len_q, _, _ = query.shape
+ _, seq_len_kv, _, _ = key.shape
+
+ if attn_mask is not None:
+ attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+ (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+ _prepare_for_flash_attn_or_sage_varlen(
+ batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+ )
+ )
+
+ key_valid, value_valid = [], []
+ for b in range(batch_size):
+ valid_len = seqlens_k[b]
+ key_valid.append(key[b, :valid_len])
+ value_valid.append(value[b, :valid_len])
+
+ query_packed = query.flatten(0, 1)
+ key_packed = torch.cat(key_valid, dim=0)
+ value_packed = torch.cat(value_valid, dim=0)
+
+ out = sageattn_varlen(
+ q=query_packed,
+ k=key_packed,
+ v=value_packed,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
+ is_causal=is_causal,
+ sm_scale=scale,
+ )
+ out = out.unflatten(0, (batch_size, -1))
+
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
+ constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
+)
+def _sage_qk_int8_pv_fp8_cuda_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ return sageattn_qk_int8_pv_fp8_cuda(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="NHD",
+ is_causal=is_causal,
+ sm_scale=scale,
+ return_lse=return_lse,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
+ constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
+)
+def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ return sageattn_qk_int8_pv_fp8_cuda_sm90(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="NHD",
+ is_causal=is_causal,
+ sm_scale=scale,
+ return_lse=return_lse,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
+ constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
+)
+def _sage_qk_int8_pv_fp16_cuda_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ return sageattn_qk_int8_pv_fp16_cuda(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="NHD",
+ is_causal=is_causal,
+ sm_scale=scale,
+ return_lse=return_lse,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
+ constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
+)
+def _sage_qk_int8_pv_fp16_triton_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ return sageattn_qk_int8_pv_fp16_triton(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="NHD",
+ is_causal=is_causal,
+ sm_scale=scale,
+ return_lse=return_lse,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.XFORMERS,
+ constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
+)
+def _xformers_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ if return_lse:
+ raise ValueError("xformers attention backend does not support setting `return_lse=True`.")
+
+ batch_size, seq_len_q, num_heads_q, _ = query.shape
+ _, seq_len_kv, num_heads_kv, _ = key.shape
+
+ if is_causal:
+ attn_mask = xops.LowerTriangularMask()
+ elif attn_mask is not None:
+ if attn_mask.ndim == 2:
+ attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
+ elif attn_mask.ndim != 4:
+ raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
+ attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
+
+ if enable_gqa:
+ if num_heads_q % num_heads_kv != 0:
+ raise ValueError("Number of heads in query must be divisible by number of heads in key/value.")
+ num_heads_per_group = num_heads_q // num_heads_kv
+ query = query.unflatten(2, (num_heads_kv, -1))
+ key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)
+ value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)
+
+ out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale)
+
+ if enable_gqa:
+ out = out.flatten(2, 3)
+
+ return out
diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py
index 246f3afaf57c..1bde62e5c666 100644
--- a/src/diffusers/models/attention_flax.py
+++ b/src/diffusers/models/attention_flax.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,6 +19,11 @@
import jax
import jax.numpy as jnp
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
"""Multi-head dot product attention with a limited number of queries."""
@@ -75,7 +80,7 @@ def jax_memory_efficient_attention(
query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
):
r"""
- Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
+ Flax Memory-efficient multi-head dot product attention. https://huggingface.co/papers/2112.05682v2
https://github.com/AminRezaei0x443/memory-efficient-attention
Args:
@@ -121,7 +126,7 @@ def chunk_scanner(chunk_idx, _):
class FlaxAttention(nn.Module):
r"""
- A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762
+ A Flax multi-head attention module as described in: https://huggingface.co/papers/1706.03762
Parameters:
query_dim (:obj:`int`):
@@ -133,7 +138,7 @@ class FlaxAttention(nn.Module):
dropout (:obj:`float`, *optional*, defaults to 0.0):
Dropout rate
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
- enable memory efficient attention https://arxiv.org/abs/2112.05682
+ enable memory efficient attention https://huggingface.co/papers/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
@@ -151,6 +156,11 @@ class FlaxAttention(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
inner_dim = self.dim_head * self.heads
self.scale = self.dim_head**-0.5
@@ -244,7 +254,7 @@ def __call__(self, hidden_states, context=None, deterministic=True):
class FlaxBasicTransformerBlock(nn.Module):
r"""
A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:
- https://arxiv.org/abs/1706.03762
+ https://huggingface.co/papers/1706.03762
Parameters:
@@ -261,7 +271,7 @@ class FlaxBasicTransformerBlock(nn.Module):
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
- enable memory efficient attention https://arxiv.org/abs/2112.05682
+ enable memory efficient attention https://huggingface.co/papers/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
@@ -277,6 +287,11 @@ class FlaxBasicTransformerBlock(nn.Module):
split_head_dim: bool = False
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
# self attention (or cross_attention if only_cross_attention is True)
self.attn1 = FlaxAttention(
self.dim,
@@ -328,7 +343,7 @@ def __call__(self, hidden_states, context, deterministic=True):
class FlaxTransformer2DModel(nn.Module):
r"""
A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
- https://arxiv.org/pdf/1506.02025.pdf
+ https://huggingface.co/papers/1506.02025
Parameters:
@@ -347,7 +362,7 @@ class FlaxTransformer2DModel(nn.Module):
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
Parameters `dtype`
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
- enable memory efficient attention https://arxiv.org/abs/2112.05682
+ enable memory efficient attention https://huggingface.co/papers/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
@@ -365,6 +380,11 @@ class FlaxTransformer2DModel(nn.Module):
split_head_dim: bool = False
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
inner_dim = self.n_heads * self.d_head
@@ -436,7 +456,7 @@ class FlaxFeedForward(nn.Module):
Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
[`FeedForward`] class, with the following simplifications:
- The activation function is currently hardcoded to a gated linear unit from:
- https://arxiv.org/abs/2002.05202
+ https://huggingface.co/papers/2002.05202
- `dim_out` is equal to `dim`.
- The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`].
@@ -454,6 +474,11 @@ class FlaxFeedForward(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
# The second linear layer needs to be called
# net_2 for now to match the index of the Sequential layer
self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
@@ -468,7 +493,7 @@ def __call__(self, hidden_states, deterministic=True):
class FlaxGEGLU(nn.Module):
r"""
Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
- https://arxiv.org/abs/2002.05202.
+ https://huggingface.co/papers/2002.05202.
Parameters:
dim (:obj:`int`):
@@ -484,6 +509,11 @@ class FlaxGEGLU(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
inner_dim = self.dim * 4
self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.dropout)
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 34276a544160..66455d733aee 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -203,8 +203,8 @@ def __init__(
self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
elif qk_norm == "rms_norm":
- self.norm_q = RMSNorm(dim_head, eps=eps)
- self.norm_k = RMSNorm(dim_head, eps=eps)
+ self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
elif qk_norm == "rms_norm_across_heads":
# LTX applies qk norm across all heads
self.norm_q = RMSNorm(dim_head * heads, eps=eps)
@@ -2272,554 +2272,6 @@ def __call__(
return hidden_states
-class FluxAttnProcessor2_0:
- """Attention processor used typically in processing the SD3-like self-attention projections."""
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- query = attn.to_q(hidden_states)
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- if encoder_hidden_states is not None:
- # `context` projections.
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- return hidden_states, encoder_hidden_states
- else:
- return hidden_states
-
-
-class FluxAttnProcessor2_0_NPU:
- """Attention processor used typically in processing the SD3-like self-attention projections."""
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
- )
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- query = attn.to_q(hidden_states)
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- if encoder_hidden_states is not None:
- # `context` projections.
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- if query.dtype in (torch.float16, torch.bfloat16):
- hidden_states = torch_npu.npu_fusion_attention(
- query,
- key,
- value,
- attn.heads,
- input_layout="BNSD",
- pse=None,
- scale=1.0 / math.sqrt(query.shape[-1]),
- pre_tockens=65536,
- next_tockens=65536,
- keep_prob=1.0,
- sync=False,
- inner_precise=0,
- )[0]
- else:
- hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- return hidden_states, encoder_hidden_states
- else:
- return hidden_states
-
-
-class FusedFluxAttnProcessor2_0:
- """Attention processor used typically in processing the SD3-like self-attention projections."""
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
- )
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- qkv = attn.to_qkv(hidden_states)
- split_size = qkv.shape[-1] // 3
- query, key, value = torch.split(qkv, split_size, dim=-1)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- # `context` projections.
- if encoder_hidden_states is not None:
- encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
- split_size = encoder_qkv.shape[-1] // 3
- (
- encoder_hidden_states_query_proj,
- encoder_hidden_states_key_proj,
- encoder_hidden_states_value_proj,
- ) = torch.split(encoder_qkv, split_size, dim=-1)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- return hidden_states, encoder_hidden_states
- else:
- return hidden_states
-
-
-class FusedFluxAttnProcessor2_0_NPU:
- """Attention processor used typically in processing the SD3-like self-attention projections."""
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU"
- )
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- qkv = attn.to_qkv(hidden_states)
- split_size = qkv.shape[-1] // 3
- query, key, value = torch.split(qkv, split_size, dim=-1)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- # `context` projections.
- if encoder_hidden_states is not None:
- encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
- split_size = encoder_qkv.shape[-1] // 3
- (
- encoder_hidden_states_query_proj,
- encoder_hidden_states_key_proj,
- encoder_hidden_states_value_proj,
- ) = torch.split(encoder_qkv, split_size, dim=-1)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- if query.dtype in (torch.float16, torch.bfloat16):
- hidden_states = torch_npu.npu_fusion_attention(
- query,
- key,
- value,
- attn.heads,
- input_layout="BNSD",
- pse=None,
- scale=1.0 / math.sqrt(query.shape[-1]),
- pre_tockens=65536,
- next_tockens=65536,
- keep_prob=1.0,
- sync=False,
- inner_precise=0,
- )[0]
- else:
- hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- return hidden_states, encoder_hidden_states
- else:
- return hidden_states
-
-
-class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
- """Flux Attention processor for IP-Adapter."""
-
- def __init__(
- self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
- ):
- super().__init__()
-
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
- )
-
- self.hidden_size = hidden_size
- self.cross_attention_dim = cross_attention_dim
-
- if not isinstance(num_tokens, (tuple, list)):
- num_tokens = [num_tokens]
-
- if not isinstance(scale, list):
- scale = [scale] * len(num_tokens)
- if len(scale) != len(num_tokens):
- raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
- self.scale = scale
-
- self.to_k_ip = nn.ModuleList(
- [
- nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
- for _ in range(len(num_tokens))
- ]
- )
- self.to_v_ip = nn.ModuleList(
- [
- nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
- for _ in range(len(num_tokens))
- ]
- )
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ip_hidden_states: Optional[List[torch.Tensor]] = None,
- ip_adapter_masks: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- hidden_states_query_proj = attn.to_q(hidden_states)
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- if encoder_hidden_states is not None:
- # `context` projections.
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- # IP-adapter
- ip_query = hidden_states_query_proj
- ip_attn_output = torch.zeros_like(hidden_states)
-
- for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
- ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
- ):
- ip_key = to_k_ip(current_ip_hidden_states)
- ip_value = to_v_ip(current_ip_hidden_states)
-
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- current_ip_hidden_states = F.scaled_dot_product_attention(
- ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
- current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
- batch_size, -1, attn.heads * head_dim
- )
- current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
- ip_attn_output += scale * current_ip_hidden_states
-
- return hidden_states, encoder_hidden_states, ip_attn_output
- else:
- return hidden_states
-
-
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -3449,106 +2901,6 @@ def __call__(
return hidden_states
-class XLAFluxFlashAttnProcessor2_0:
- r"""
- Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
- """
-
- def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
- )
- if is_torch_xla_version("<", "2.3"):
- raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
- if is_spmd() and is_torch_xla_version("<", "2.4"):
- raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
- self.partition_spec = partition_spec
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- query = attn.to_q(hidden_states)
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- if encoder_hidden_states is not None:
- # `context` projections.
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- query /= math.sqrt(head_dim)
- hidden_states = flash_attention(query, key, value, causal=False)
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- return hidden_states, encoder_hidden_states
- else:
- return hidden_states
-
-
class MochiVaeAttnProcessor2_0:
r"""
Attention processor used in Mochi VAE.
@@ -3972,7 +3324,7 @@ class PAGHunyuanAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This
- variant of the processor employs [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377).
+ variant of the processor employs [Pertubed Attention Guidance](https://huggingface.co/papers/2403.17377).
"""
def __init__(self):
@@ -4095,7 +3447,7 @@ class PAGCFGHunyuanAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on query and key vector. This
- variant of the processor employs [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377).
+ variant of the processor employs [Pertubed Attention Guidance](https://huggingface.co/papers/2403.17377).
"""
def __init__(self):
@@ -4317,11 +3669,7 @@ class FusedAttnProcessor2_0:
fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is currently 🧪 experimental in nature and can change in future.
-
-
+ > [!WARNING] > This API is currently 🧪 experimental in nature and can change in future.
"""
def __init__(self):
@@ -4828,7 +4176,7 @@ def __call__(
class SpatialNorm(nn.Module):
"""
- Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
+ Spatially conditioned normalization as defined in https://huggingface.co/papers/2209.09002.
Args:
f_channels (`int`):
@@ -5693,7 +5041,7 @@ def __call__(
class PAGIdentitySelfAttnProcessor2_0:
r"""
Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
- PAG reference: https://arxiv.org/abs/2403.17377
+ PAG reference: https://huggingface.co/papers/2403.17377
"""
def __init__(self):
@@ -5792,7 +5140,7 @@ def __call__(
class PAGCFGIdentitySelfAttnProcessor2_0:
r"""
Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
- PAG reference: https://arxiv.org/abs/2403.17377
+ PAG reference: https://huggingface.co/papers/2403.17377
"""
def __init__(self):
@@ -5988,17 +5336,6 @@ def __init__(self):
pass
-class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
- r"""
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
- """
-
- def __init__(self):
- deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
- deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
- super().__init__()
-
-
class SanaLinearAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product linear attention.
@@ -6163,6 +5500,111 @@ def __call__(
return hidden_states
+class FluxAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`FluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
+ deprecate("FluxAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ from .transformers.transformer_flux import FluxAttnProcessor
+
+ return FluxAttnProcessor(*args, **kwargs)
+
+
+class FluxSingleAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`FluxSingleAttnProcessor` is deprecated and will be removed in a future version. Please use `FluxAttnProcessorSDPA` instead."
+ deprecate("FluxSingleAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ from .transformers.transformer_flux import FluxAttnProcessor
+
+ return FluxAttnProcessor(*args, **kwargs)
+
+
+class FusedFluxAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`FusedFluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
+ deprecate("FusedFluxAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ from .transformers.transformer_flux import FluxAttnProcessor
+
+ return FluxAttnProcessor(*args, **kwargs)
+
+
+class FluxIPAdapterJointAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`FluxIPAdapterJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxIPAdapterAttnProcessor`"
+ deprecate("FluxIPAdapterJointAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ from .transformers.transformer_flux import FluxIPAdapterAttnProcessor
+
+ return FluxIPAdapterAttnProcessor(*args, **kwargs)
+
+
+class FluxAttnProcessor2_0_NPU:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = (
+ "FluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
+ "alternative solution to use NPU Flash Attention will be provided in the future."
+ )
+ deprecate("FluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
+
+ from .transformers.transformer_flux import FluxAttnProcessor
+
+ processor = FluxAttnProcessor()
+ processor._attention_backend = "_native_npu"
+ return processor
+
+
+class FusedFluxAttnProcessor2_0_NPU:
+ def __new__(self):
+ deprecation_message = (
+ "FusedFluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
+ "alternative solution to use NPU Flash Attention will be provided in the future."
+ )
+ deprecate("FusedFluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
+
+ from .transformers.transformer_flux import FluxAttnProcessor
+
+ processor = FluxAttnProcessor()
+ processor._attention_backend = "_fused_npu"
+ return processor
+
+
+class XLAFluxFlashAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
+ """
+
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = (
+ "XLAFluxFlashAttnProcessor2_0 is deprecated and will be removed in diffusers 1.0.0. An "
+ "alternative solution to using XLA Flash Attention will be provided in the future."
+ )
+ deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+
+ if is_torch_xla_version("<", "2.3"):
+ raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
+ if is_spmd() and is_torch_xla_version("<", "2.4"):
+ raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
+
+ from .transformers.transformer_flux import FluxAttnProcessor
+
+ if len(args) > 0 or kwargs.get("partition_spec", None) is not None:
+ deprecation_message = (
+ "partition_spec was not used in the processor implementation when it was added. Passing it "
+ "is a no-op and support for it will be removed."
+ )
+ deprecate("partition_spec", "1.0.0", deprecation_message)
+
+ processor = FluxAttnProcessor(*args, **kwargs)
+ processor._attention_backend = "_native_xla"
+ return processor
+
+
ADDED_KV_ATTENTION_PROCESSORS = (
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
diff --git a/src/diffusers/models/auto_model.py b/src/diffusers/models/auto_model.py
new file mode 100644
index 000000000000..c96b4fa88c49
--- /dev/null
+++ b/src/diffusers/models/auto_model.py
@@ -0,0 +1,223 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Optional, Union
+
+from huggingface_hub.utils import validate_hf_hub_args
+
+from ..configuration_utils import ConfigMixin
+from ..utils import logging
+from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
+
+
+logger = logging.get_logger(__name__)
+
+
+class AutoModel(ConfigMixin):
+ config_name = "config.json"
+
+ def __init__(self, *args, **kwargs):
+ raise EnvironmentError(
+ f"{self.__class__.__name__} is designed to be instantiated "
+ f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
+ f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
+ )
+
+ @classmethod
+ @validate_hf_hub_args
+ def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLike]] = None, **kwargs):
+ r"""
+ Instantiate a pretrained PyTorch model from a pretrained model configuration.
+
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
+ train the model, set it back in training mode with `model.train()`.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`~ModelMixin.save_pretrained`].
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ torch_dtype (`torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info (`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ from_flax (`bool`, *optional*, defaults to `False`):
+ Load the model weights from a Flax checkpoint save file.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ mirror (`str`, *optional*):
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+ information.
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+ same device. Defaults to `None`, meaning that the model will be loaded on CPU.
+
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ max_memory (`Dict`, *optional*):
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
+ each GPU and the available CPU RAM if unset.
+ offload_folder (`str` or `os.PathLike`, *optional*):
+ The path to offload weights if `device_map` contains the value `"disk"`.
+ offload_state_dict (`bool`, *optional*):
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+ when there is some disk offload.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ variant (`str`, *optional*):
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
+ loading `from_flax`.
+ use_safetensors (`bool`, *optional*, defaults to `None`):
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
+ weights. If set to `False`, `safetensors` weights are not loaded.
+ disable_mmap ('bool', *optional*, defaults to 'False'):
+ Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+ is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
+ trust_remote_cocde (`bool`, *optional*, defaults to `False`):
+ Whether to trust remote code
+
+ > [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
+ with `hf > auth login`. You can also activate the special >
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a >
+ firewalled environment.
+
+ Example:
+
+ ```py
+ from diffusers import AutoModel
+
+ unet = AutoModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
+ ```
+
+ If you get the error message below, you need to finetune the weights for your downstream task:
+
+ ```bash
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
+ ```
+ """
+ subfolder = kwargs.pop("subfolder", None)
+ trust_remote_code = kwargs.pop("trust_remote_code", False)
+
+ hub_kwargs_names = [
+ "cache_dir",
+ "force_download",
+ "local_files_only",
+ "proxies",
+ "revision",
+ "token",
+ ]
+ hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names}
+
+ # load_config_kwargs uses the same hub kwargs minus subfolder and resume_download
+ load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder"]}
+
+ library = None
+ orig_class_name = None
+
+ # Always attempt to fetch model_index.json first
+ try:
+ cls.config_name = "model_index.json"
+ config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
+
+ if subfolder is not None and subfolder in config:
+ library, orig_class_name = config[subfolder]
+ load_config_kwargs.update({"subfolder": subfolder})
+
+ except EnvironmentError as e:
+ logger.debug(e)
+
+ # Unable to load from model_index.json so fallback to loading from config
+ if library is None and orig_class_name is None:
+ cls.config_name = "config.json"
+ config = cls.load_config(pretrained_model_or_path, subfolder=subfolder, **load_config_kwargs)
+
+ if "_class_name" in config:
+ # If we find a class name in the config, we can try to load the model as a diffusers model
+ orig_class_name = config["_class_name"]
+ library = "diffusers"
+ load_config_kwargs.update({"subfolder": subfolder})
+ elif "model_type" in config:
+ orig_class_name = "AutoModel"
+ library = "transformers"
+ load_config_kwargs.update({"subfolder": "" if subfolder is None else subfolder})
+ else:
+ raise ValueError(f"Couldn't find model associated with the config file at {pretrained_model_or_path}.")
+
+ has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
+ trust_remote_code = resolve_trust_remote_code(trust_remote_code, pretrained_model_or_path, has_remote_code)
+ if not has_remote_code and trust_remote_code:
+ raise ValueError(
+ "Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
+ )
+
+ if has_remote_code and trust_remote_code:
+ class_ref = config["auto_map"][cls.__name__]
+ module_file, class_name = class_ref.split(".")
+ module_file = module_file + ".py"
+ model_cls = get_class_from_dynamic_module(
+ pretrained_model_or_path,
+ subfolder=subfolder,
+ module_file=module_file,
+ class_name=class_name,
+ **hub_kwargs,
+ )
+ else:
+ from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
+
+ model_cls, _ = get_class_obj_and_candidates(
+ library_name=library,
+ class_name=orig_class_name,
+ importable_classes=ALL_IMPORTABLE_CLASSES,
+ pipelines=None,
+ is_pipeline_module=False,
+ )
+
+ if model_cls is None:
+ raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
+
+ kwargs = {**load_config_kwargs, **kwargs}
+ return model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py
index f8f49ce4c797..56df27f93cd7 100644
--- a/src/diffusers/models/autoencoders/__init__.py
+++ b/src/diffusers/models/autoencoders/__init__.py
@@ -3,10 +3,16 @@
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
+from .autoencoder_kl_cosmos import AutoencoderKLCosmos
+from .autoencoder_kl_flux2 import AutoencoderKLFlux2
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
+from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
+from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
+from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi
+from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_kl_wan import AutoencoderKLWan
from .autoencoder_oobleck import AutoencoderOobleck
diff --git a/src/diffusers/models/autoencoders/autoencoder_asym_kl.py b/src/diffusers/models/autoencoders/autoencoder_asym_kl.py
index c643dcc72a34..fa49fcfe79f8 100644
--- a/src/diffusers/models/autoencoders/autoencoder_asym_kl.py
+++ b/src/diffusers/models/autoencoders/autoencoder_asym_kl.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,13 +20,13 @@
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
-from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
-class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
+class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
- Designing a Better Asymmetric VQGAN for StableDiffusion https://arxiv.org/abs/2306.04632 . A VAE model with KL loss
- for encoding images into latents and decoding latent representations into images.
+ Designing a Better Asymmetric VQGAN for StableDiffusion https://huggingface.co/papers/2306.04632 . A VAE model with
+ KL loss for encoding images into latents and decoding latent representations into images.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
@@ -57,7 +57,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
- Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper.
"""
_skip_layerwise_casting_patterns = ["decoder"]
@@ -107,9 +107,6 @@ def __init__(
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
- self.use_slicing = False
- self.use_tiling = False
-
self.register_to_config(block_out_channels=up_block_out_channels)
self.register_to_config(force_upcast=False)
diff --git a/src/diffusers/models/autoencoders/autoencoder_dc.py b/src/diffusers/models/autoencoders/autoencoder_dc.py
index 9146aa5c7c6c..ec301ef8ad51 100644
--- a/src/diffusers/models/autoencoders/autoencoder_dc.py
+++ b/src/diffusers/models/autoencoders/autoencoder_dc.py
@@ -1,4 +1,4 @@
-# Copyright 2024 MIT, Tsinghua University, NVIDIA CORPORATION and The HuggingFace Team.
+# Copyright 2025 MIT, Tsinghua University, NVIDIA CORPORATION and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -27,7 +27,7 @@
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm, get_normalization
from ..transformers.sana_transformer import GLUMBConv
-from .vae import DecoderOutput, EncoderOutput
+from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput
class ResBlock(nn.Module):
@@ -102,7 +102,7 @@ def get_block(
attention_head_dim: int,
norm_type: str,
act_fn: str,
- qkv_mutliscales: Tuple[int] = (),
+ qkv_mutliscales: Tuple[int, ...] = (),
):
if block_type == "ResBlock":
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
@@ -206,8 +206,8 @@ def __init__(
latent_channels: int,
attention_head_dim: int = 32,
block_type: Union[str, Tuple[str]] = "ResBlock",
- block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
- layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
+ layers_per_block: Tuple[int, ...] = (2, 2, 2, 2, 2, 2),
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
downsample_block_type: str = "pixel_unshuffle",
out_shortcut: bool = True,
@@ -292,13 +292,14 @@ def __init__(
latent_channels: int,
attention_head_dim: int = 32,
block_type: Union[str, Tuple[str]] = "ResBlock",
- block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
- layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
+ layers_per_block: Tuple[int, ...] = (2, 2, 2, 2, 2, 2),
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
norm_type: Union[str, Tuple[str]] = "rms_norm",
act_fn: Union[str, Tuple[str]] = "silu",
upsample_block_type: str = "pixel_shuffle",
in_shortcut: bool = True,
+ conv_act_fn: str = "relu",
):
super().__init__()
@@ -349,7 +350,7 @@ def __init__(
channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]
self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
- self.conv_act = nn.ReLU()
+ self.conv_act = get_activation(conv_act_fn)
self.conv_out = None
if layers_per_block[0] > 0:
@@ -377,10 +378,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states
-class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
- An Autoencoder model introduced in [DCAE](https://arxiv.org/abs/2410.10733) and used in
- [SANA](https://arxiv.org/abs/2410.10629).
+ An Autoencoder model introduced in [DCAE](https://huggingface.co/papers/2410.10733) and used in
+ [SANA](https://huggingface.co/papers/2410.10629).
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
@@ -414,6 +415,12 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
The normalization type(s) to use in the decoder.
decoder_act_fns (`Union[str, Tuple[str]]`, defaults to `"silu"`):
The activation function(s) to use in the decoder.
+ encoder_out_shortcut (`bool`, defaults to `True`):
+ Whether to use shortcut at the end of the encoder.
+ decoder_in_shortcut (`bool`, defaults to `True`):
+ Whether to use shortcut at the beginning of the decoder.
+ decoder_conv_act_fn (`str`, defaults to `"relu"`):
+ The activation function to use at the end of the decoder.
scaling_factor (`float`, defaults to `1.0`):
The multiplicative inverse of the root mean square of the latent features. This is used to scale the latent
space to have unit variance when training the diffusion model. The latents are scaled with the formula `z =
@@ -433,14 +440,17 @@ def __init__(
decoder_block_types: Union[str, Tuple[str]] = "ResBlock",
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
- encoder_layers_per_block: Tuple[int] = (2, 2, 2, 3, 3, 3),
- decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3, 3, 3),
+ encoder_layers_per_block: Tuple[int, ...] = (2, 2, 2, 3, 3, 3),
+ decoder_layers_per_block: Tuple[int, ...] = (3, 3, 3, 3, 3, 3),
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
upsample_block_type: str = "pixel_shuffle",
downsample_block_type: str = "pixel_unshuffle",
decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
decoder_act_fns: Union[str, Tuple[str]] = "silu",
+ encoder_out_shortcut: bool = True,
+ decoder_in_shortcut: bool = True,
+ decoder_conv_act_fn: str = "relu",
scaling_factor: float = 1.0,
) -> None:
super().__init__()
@@ -454,6 +464,7 @@ def __init__(
layers_per_block=encoder_layers_per_block,
qkv_multiscales=encoder_qkv_multiscales,
downsample_block_type=downsample_block_type,
+ out_shortcut=encoder_out_shortcut,
)
self.decoder = Decoder(
in_channels=in_channels,
@@ -466,6 +477,8 @@ def __init__(
norm_type=decoder_norm_types,
act_fn=decoder_act_fns,
upsample_block_type=upsample_block_type,
+ in_shortcut=decoder_in_shortcut,
+ conv_act_fn=decoder_conv_act_fn,
)
self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)
@@ -523,27 +536,6 @@ def enable_tiling(
self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
- def disable_tiling(self) -> None:
- r"""
- Disable tiled AE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_tiling = False
-
- def enable_slicing(self) -> None:
- r"""
- Enable sliced AE decoding. When this option is enabled, the AE will split the input tensor in slices to compute
- decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self) -> None:
- r"""
- Disable sliced AE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = x.shape
@@ -604,7 +596,7 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp
returned.
"""
if self.use_slicing and z.size(0) > 1:
- decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z)
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py
index 357df0c31087..95991dca3304 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -21,21 +21,23 @@
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import deprecate
from ...utils.accelerate_utils import apply_forward_hook
+from ..attention import AttentionMixin
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
- AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
FusedAttnProcessor2_0,
)
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
-from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
+from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
-class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
+class AutoencoderKL(
+ ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin
+):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
@@ -60,11 +62,11 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapter
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
- Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper.
force_upcast (`bool`, *optional*, default to `True`):
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
- can be fine-tuned / trained to a lower range without loosing too much precision in which case
- `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
+ can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast`
+ can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
mid_block_add_attention (`bool`, *optional*, default to `True`):
If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the
mid_block will only have resnet blocks
@@ -72,15 +74,16 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapter
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
+ _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
- down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
- up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
- block_out_channels: Tuple[int] = (64,),
+ down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
+ up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
+ block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
@@ -90,7 +93,7 @@ def __init__(
shift_factor: Optional[float] = None,
latents_mean: Optional[Tuple[float]] = None,
latents_std: Optional[Tuple[float]] = None,
- force_upcast: float = True,
+ force_upcast: bool = True,
use_quant_conv: bool = True,
use_post_quant_conv: bool = True,
mid_block_add_attention: bool = True,
@@ -138,95 +141,6 @@ def __init__(
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
- def enable_tiling(self, use_tiling: bool = True):
- r"""
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
- compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
- processing larger images.
- """
- self.use_tiling = use_tiling
-
- def disable_tiling(self):
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.enable_tiling(False)
-
- def enable_slicing(self):
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self):
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
@@ -532,11 +446,7 @@ def fuse_qkv_projections(self):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -556,11 +466,7 @@ def fuse_qkv_projections(self):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
index a76277366c09..6756586460d3 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The RhymesAI and The HuggingFace Team.
+# Copyright 2025 The RhymesAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -28,6 +28,7 @@
from ..modeling_utils import ModelMixin
from ..resnet import ResnetBlock2D
from ..upsampling import Upsample2D
+from .vae import AutoencoderMixin
class AllegroTemporalConvLayer(nn.Module):
@@ -673,7 +674,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
return sample
-class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
+class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
[Allegro](https://github.com/rhymes-ai/Allegro).
@@ -712,11 +713,11 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
- Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper.
force_upcast (`bool`, default to `True`):
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
- can be fine-tuned / trained to a lower range without loosing too much precision in which case
- `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
+ can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast`
+ can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
"""
_supports_gradient_checkpointing = True
@@ -795,35 +796,6 @@ def __init__(
sample_size - self.tile_overlap_w,
)
- def enable_tiling(self) -> None:
- r"""
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
- compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
- processing larger images.
- """
- self.use_tiling = True
-
- def disable_tiling(self) -> None:
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_tiling = False
-
- def enable_slicing(self) -> None:
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self) -> None:
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
def _encode(self, x: torch.Tensor) -> torch.Tensor:
# TODO(aryan)
# if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
index e2b26396899f..79433f7b9232 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -29,7 +29,7 @@
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..upsampling import CogVideoXUpsample3D
-from .vae import DecoderOutput, DiagonalGaussianDistribution
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -148,8 +148,8 @@ def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = Non
class CogVideoXSpatialNorm3D(nn.Module):
r"""
- Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific
- to 3D-video like data.
+ Spatially conditioned normalization as defined in https://huggingface.co/papers/2209.09002. This implementation is
+ specific to 3D-video like data.
CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model.
@@ -955,7 +955,7 @@ def forward(
return hidden_states, new_conv_cache
-class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[CogVideoX](https://github.com/THUDM/CogVideo).
@@ -980,11 +980,11 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
- Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper.
force_upcast (`bool`, *optional*, default to `True`):
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
- can be fine-tuned / trained to a lower range without loosing too much precision in which case
- `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
+ can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast`
+ can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
"""
_supports_gradient_checkpointing = True
@@ -995,19 +995,19 @@ def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
- down_block_types: Tuple[str] = (
+ down_block_types: Tuple[str, ...] = (
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
),
- up_block_types: Tuple[str] = (
+ up_block_types: Tuple[str, ...] = (
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
),
- block_out_channels: Tuple[int] = (128, 256, 256, 512),
+ block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
latent_channels: int = 16,
layers_per_block: int = 3,
act_fn: str = "silu",
@@ -1124,27 +1124,6 @@ def enable_tiling(
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
- def disable_tiling(self) -> None:
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_tiling = False
-
- def enable_slicing(self) -> None:
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self) -> None:
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py b/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
new file mode 100644
index 000000000000..b17522d1c424
--- /dev/null
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
@@ -0,0 +1,1089 @@
+# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import get_logger
+from ...utils.accelerate_utils import apply_forward_hook
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from .vae import AutoencoderMixin, DecoderOutput, IdentityDistribution
+
+
+logger = get_logger(__name__)
+
+
+# fmt: off
+# These latents and means are from CV8x8x8-1.0. Each checkpoint has different values, but since this is the main VAE used,
+# we will default to these values.
+LATENTS_MEAN = [0.11362758, -0.0171717, 0.03071163, 0.02046862, 0.01931456, 0.02138567, 0.01999342, 0.02189187, 0.02011935, 0.01872694, 0.02168613, 0.02207148, 0.01986941, 0.01770413, 0.02067643, 0.02028245, 0.19125476, 0.04556972, 0.0595558, 0.05315534, 0.05496629, 0.05356264, 0.04856596, 0.05327453, 0.05410472, 0.05597149, 0.05524866, 0.05181874, 0.05071663, 0.05204537, 0.0564108, 0.05518042, 0.01306714, 0.03341161, 0.03847246, 0.02810185, 0.02790166, 0.02920026, 0.02823597, 0.02631033, 0.0278531, 0.02880507, 0.02977769, 0.03145441, 0.02888389, 0.03280773, 0.03484927, 0.03049198, -0.00197727, 0.07534957, 0.04963879, 0.05530893, 0.05410828, 0.05252541, 0.05029899, 0.05321025, 0.05149245, 0.0511921, 0.04643495, 0.04604527, 0.04631618, 0.04404101, 0.04403536, 0.04499495, -0.02994183, -0.04787003, -0.01064558, -0.01779824, -0.01490502, -0.02157517, -0.0204778, -0.02180816, -0.01945375, -0.02062863, -0.02192209, -0.02520639, -0.02246656, -0.02427533, -0.02683363, -0.02762006, 0.08019473, -0.13005368, -0.07568636, -0.06082374, -0.06036175, -0.05875364, -0.05921887, -0.05869788, -0.05273941, -0.052565, -0.05346428, -0.05456541, -0.053657, -0.05656897, -0.05728589, -0.05321847, 0.16718403, -0.00390146, 0.0379406, 0.0356561, 0.03554131, 0.03924074, 0.03873615, 0.04187329, 0.04226924, 0.04378717, 0.04684274, 0.05117614, 0.04547792, 0.05251586, 0.05048339, 0.04950784, 0.09564418, 0.0547128, 0.08183969, 0.07978633, 0.08076023, 0.08108605, 0.08011818, 0.07965573, 0.08187773, 0.08350263, 0.08101469, 0.0786941, 0.0774442, 0.07724521, 0.07830418, 0.07599796, -0.04987567, 0.05923908, -0.01058746, -0.01177603, -0.01116162, -0.01364149, -0.01546014, -0.0117213, -0.01780043, -0.01648314, -0.02100247, -0.02104417, -0.02482123, -0.02611689, -0.02561143, -0.02597336, -0.05364667, 0.08211684, 0.04686937, 0.04605641, 0.04304186, 0.0397355, 0.03686767, 0.04087112, 0.03704741, 0.03706401, 0.03120073, 0.03349091, 0.03319963, 0.03205781, 0.03195127, 0.03180481, 0.16427967, -0.11048453, -0.04595276, -0.04982893, -0.05213465, -0.04809378, -0.05080318, -0.04992863, -0.04493337, -0.0467619, -0.04884703, -0.04627892, -0.04913311, -0.04955709, -0.04533982, -0.04570218, -0.10612928, -0.05121198, -0.06761009, -0.07251801, -0.07265285, -0.07417855, -0.07202412, -0.07499027, -0.07625481, -0.07535747, -0.07638787, -0.07920305, -0.07596069, -0.07959418, -0.08265036, -0.07955471, -0.16888915, 0.0753242, 0.04062594, 0.03375093, 0.03337452, 0.03699376, 0.03651138, 0.03611023, 0.03555622, 0.03378554, 0.0300498, 0.03395559, 0.02941847, 0.03156432, 0.03431173, 0.03016853, -0.03415358, -0.01699573, -0.04029295, -0.04912157, -0.0498858, -0.04917918, -0.04918056, -0.0525189, -0.05325506, -0.05341973, -0.04983329, -0.04883146, -0.04985548, -0.04736718, -0.0462027, -0.04836091, 0.02055675, 0.03419799, -0.02907669, -0.04350509, -0.04156144, -0.04234421, -0.04446109, -0.04461774, -0.04882839, -0.04822346, -0.04502493, -0.0506244, -0.05146913, -0.04655267, -0.04862994, -0.04841615, 0.20312774, -0.07208502, -0.03635615, -0.03556088, -0.04246174, -0.04195838, -0.04293778, -0.04071276, -0.04240569, -0.04125213, -0.04395144, -0.03959096, -0.04044993, -0.04015875, -0.04088107, -0.03885176]
+LATENTS_STD = [0.56700271, 0.65488982, 0.65589428, 0.66524369, 0.66619784, 0.6666382, 0.6720838, 0.66955978, 0.66928875, 0.67108786, 0.67092526, 0.67397463, 0.67894882, 0.67668313, 0.67769569, 0.67479557, 0.85245121, 0.8688373, 0.87348086, 0.88459337, 0.89135885, 0.8910504, 0.89714909, 0.89947474, 0.90201765, 0.90411824, 0.90692616, 0.90847772, 0.90648711, 0.91006982, 0.91033435, 0.90541548, 0.84960359, 0.85863352, 0.86895317, 0.88460612, 0.89245003, 0.89451706, 0.89931005, 0.90647358, 0.90338236, 0.90510076, 0.91008312, 0.90961218, 0.9123717, 0.91313171, 0.91435546, 0.91565102, 0.91877103, 0.85155135, 0.857804, 0.86998034, 0.87365264, 0.88161767, 0.88151032, 0.88758916, 0.89015514, 0.89245576, 0.89276224, 0.89450496, 0.90054202, 0.89994133, 0.90136105, 0.90114892, 0.77755755, 0.81456852, 0.81911844, 0.83137071, 0.83820474, 0.83890373, 0.84401101, 0.84425181, 0.84739357, 0.84798753, 0.85249585, 0.85114998, 0.85160935, 0.85626358, 0.85677862, 0.85641026, 0.69903517, 0.71697885, 0.71696913, 0.72583169, 0.72931731, 0.73254126, 0.73586977, 0.73734969, 0.73664582, 0.74084908, 0.74399322, 0.74471819, 0.74493188, 0.74824578, 0.75024873, 0.75274801, 0.8187142, 0.82251883, 0.82616025, 0.83164483, 0.84072375, 0.8396467, 0.84143305, 0.84880769, 0.8503468, 0.85196948, 0.85211051, 0.85386664, 0.85410017, 0.85439342, 0.85847849, 0.85385275, 0.67583984, 0.68259847, 0.69198853, 0.69928843, 0.70194328, 0.70467001, 0.70755547, 0.70917857, 0.71007699, 0.70963502, 0.71064079, 0.71027333, 0.71291167, 0.71537536, 0.71902508, 0.71604162, 0.72450989, 0.71979928, 0.72057378, 0.73035461, 0.73329622, 0.73660028, 0.73891461, 0.74279994, 0.74105692, 0.74002433, 0.74257588, 0.74416119, 0.74543899, 0.74694443, 0.74747062, 0.74586403, 0.90176988, 0.90990674, 0.91106802, 0.92163783, 0.92390233, 0.93056196, 0.93482202, 0.93642414, 0.93858379, 0.94064975, 0.94078934, 0.94325715, 0.94955301, 0.94814706, 0.95144123, 0.94923073, 0.49853548, 0.64968109, 0.6427654, 0.64966393, 0.6487664, 0.65203559, 0.6584242, 0.65351611, 0.65464371, 0.6574859, 0.65626335, 0.66123748, 0.66121179, 0.66077942, 0.66040152, 0.66474909, 0.61986589, 0.69138134, 0.6884557, 0.6955843, 0.69765401, 0.70015347, 0.70529598, 0.70468754, 0.70399523, 0.70479989, 0.70887572, 0.71126866, 0.7097227, 0.71249932, 0.71231949, 0.71175605, 0.35586974, 0.68723857, 0.68973219, 0.69958478, 0.6943453, 0.6995818, 0.70980215, 0.69899458, 0.70271689, 0.70095056, 0.69912851, 0.70522696, 0.70392174, 0.70916915, 0.70585734, 0.70373541, 0.98101336, 0.89024764, 0.89607251, 0.90678179, 0.91308665, 0.91812348, 0.91980827, 0.92480654, 0.92635667, 0.92887944, 0.93338072, 0.93468094, 0.93619436, 0.93906063, 0.94191772, 0.94471723, 0.83202779, 0.84106231, 0.84463632, 0.85829508, 0.86319661, 0.86751342, 0.86914337, 0.87085921, 0.87286359, 0.87537396, 0.87931138, 0.88054478, 0.8811838, 0.88872558, 0.88942474, 0.88934827, 0.44025335, 0.63061613, 0.63110614, 0.63601959, 0.6395812, 0.64104342, 0.65019929, 0.6502797, 0.64355946, 0.64657205, 0.64847094, 0.64728117, 0.64972943, 0.65162975, 0.65328044, 0.64914775]
+_WAVELETS = {
+ "haar": torch.tensor([0.7071067811865476, 0.7071067811865476]),
+ "rearrange": torch.tensor([1.0, 1.0]),
+}
+# fmt: on
+
+
+class CosmosCausalConv3d(nn.Conv3d):
+ def __init__(
+ self,
+ in_channels: int = 1,
+ out_channels: int = 1,
+ kernel_size: Union[int, Tuple[int, int, int]] = (3, 3, 3),
+ dilation: Union[int, Tuple[int, int, int]] = (1, 1, 1),
+ stride: Union[int, Tuple[int, int, int]] = (1, 1, 1),
+ padding: int = 1,
+ pad_mode: str = "constant",
+ ) -> None:
+ kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
+ dilation = (dilation, dilation, dilation) if isinstance(dilation, int) else dilation
+ stride = (stride, stride, stride) if isinstance(stride, int) else stride
+
+ _, height_kernel_size, width_kernel_size = kernel_size
+ assert height_kernel_size % 2 == 1 and width_kernel_size % 2 == 1
+
+ super().__init__(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ dilation=dilation,
+ )
+
+ self.pad_mode = pad_mode
+ self.temporal_pad = dilation[0] * (kernel_size[0] - 1) + (1 - stride[0])
+ self.spatial_pad = (padding, padding, padding, padding)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states_prev = hidden_states[:, :, :1, ...].repeat(1, 1, self.temporal_pad, 1, 1)
+ hidden_states = torch.cat([hidden_states_prev, hidden_states], dim=2)
+ hidden_states = F.pad(hidden_states, (*self.spatial_pad, 0, 0), mode=self.pad_mode, value=0.0)
+ return super().forward(hidden_states)
+
+
+class CosmosCausalGroupNorm(torch.nn.Module):
+ def __init__(self, in_channels: int, num_groups: int = 1):
+ super().__init__()
+ self.norm = nn.GroupNorm(
+ num_groups=num_groups,
+ num_channels=in_channels,
+ eps=1e-6,
+ affine=True,
+ )
+ self.num_groups = num_groups
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if self.num_groups == 1:
+ batch_size = hidden_states.size(0)
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) # [B, C, T, H, W] -> [B * T, C, H, W]
+ hidden_states = self.norm(hidden_states)
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(
+ 0, 2, 1, 3, 4
+ ) # [B * T, C, H, W] -> [B, C, T, H, W]
+ else:
+ hidden_states = self.norm(hidden_states)
+ return hidden_states
+
+
+class CosmosPatchEmbed3d(nn.Module):
+ def __init__(self, patch_size: int = 1, patch_method: str = "haar") -> None:
+ super().__init__()
+
+ self.patch_size = patch_size
+ self.patch_method = patch_method
+
+ wavelets = _WAVELETS.get(patch_method).clone()
+ arange = torch.arange(wavelets.shape[0])
+
+ self.register_buffer("wavelets", wavelets, persistent=False)
+ self.register_buffer("_arange", arange, persistent=False)
+
+ def _dwt(self, hidden_states: torch.Tensor, mode: str = "reflect", rescale=False) -> torch.Tensor:
+ dtype = hidden_states.dtype
+ wavelets = self.wavelets
+
+ n = wavelets.shape[0]
+ g = hidden_states.shape[1]
+ hl = wavelets.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
+ hh = (wavelets * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1)
+ hh = hh.to(dtype=dtype)
+ hl = hl.to(dtype=dtype)
+
+ # Handles temporal axis
+ hidden_states = F.pad(hidden_states, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode).to(
+ dtype
+ )
+ xl = F.conv3d(hidden_states, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
+ xh = F.conv3d(hidden_states, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
+
+ # Handles spatial axes
+ xll = F.conv3d(xl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
+ xlh = F.conv3d(xl, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
+ xhl = F.conv3d(xh, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
+ xhh = F.conv3d(xh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
+
+ xlll = F.conv3d(xll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
+ xllh = F.conv3d(xll, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
+ xlhl = F.conv3d(xlh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
+ xlhh = F.conv3d(xlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
+ xhll = F.conv3d(xhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
+ xhlh = F.conv3d(xhl, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
+ xhhl = F.conv3d(xhh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
+ xhhh = F.conv3d(xhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
+
+ hidden_states = torch.cat([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1)
+ if rescale:
+ hidden_states = hidden_states / 8**0.5
+ return hidden_states
+
+ def _haar(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ xi, xv = torch.split(hidden_states, [1, hidden_states.shape[2] - 1], dim=2)
+ hidden_states = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
+ for _ in range(int(math.log2(self.patch_size))):
+ hidden_states = self._dwt(hidden_states, rescale=True)
+ return hidden_states
+
+ def _arrange(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ xi, xv = torch.split(hidden_states, [1, hidden_states.shape[2] - 1], dim=2)
+ hidden_states = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p = self.patch_size
+
+ hidden_states = hidden_states.reshape(
+ batch_size, num_channels, num_frames // p, p, height // p, p, width // p, p
+ )
+ hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4).contiguous()
+ return hidden_states
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if self.patch_method == "haar":
+ return self._haar(hidden_states)
+ elif self.patch_method == "rearrange":
+ return self._arrange(hidden_states)
+ else:
+ raise ValueError(f"Unsupported patch method: {self.patch_method}")
+
+
+class CosmosUnpatcher3d(nn.Module):
+ def __init__(self, patch_size: int = 1, patch_method: str = "haar"):
+ super().__init__()
+
+ self.patch_size = patch_size
+ self.patch_method = patch_method
+
+ wavelets = _WAVELETS.get(patch_method).clone()
+ arange = torch.arange(wavelets.shape[0])
+
+ self.register_buffer("wavelets", wavelets, persistent=False)
+ self.register_buffer("_arange", arange, persistent=False)
+
+ def _idwt(self, hidden_states: torch.Tensor, rescale: bool = False) -> torch.Tensor:
+ device = hidden_states.device
+ dtype = hidden_states.dtype
+ h = self.wavelets.to(device)
+
+ g = hidden_states.shape[1] // 8 # split into 8 spatio-temporal filtered tesnors.
+ hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
+ hh = (h * ((-1) ** self._arange.to(device))).reshape(1, 1, -1).repeat(g, 1, 1)
+ hl = hl.to(dtype=dtype)
+ hh = hh.to(dtype=dtype)
+
+ xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(hidden_states, 8, dim=1)
+
+ # Handle height transposed convolutions
+ xll = F.conv_transpose3d(xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
+ xll = F.conv_transpose3d(xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xll
+
+ xlh = F.conv_transpose3d(xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
+ xlh = F.conv_transpose3d(xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlh
+
+ xhl = F.conv_transpose3d(xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
+ xhl = F.conv_transpose3d(xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhl
+
+ xhh = F.conv_transpose3d(xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
+ xhh = F.conv_transpose3d(xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhh
+
+ # Handles width transposed convolutions
+ xl = F.conv_transpose3d(xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
+ xl = F.conv_transpose3d(xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xl
+ xh = F.conv_transpose3d(xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
+ xh = F.conv_transpose3d(xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xh
+
+ # Handles time axis transposed convolutions
+ hidden_states = F.conv_transpose3d(xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
+ hidden_states = (
+ F.conv_transpose3d(xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + hidden_states
+ )
+
+ if rescale:
+ hidden_states = hidden_states * 8**0.5
+
+ return hidden_states
+
+ def _ihaar(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ for _ in range(int(math.log2(self.patch_size))):
+ hidden_states = self._idwt(hidden_states, rescale=True)
+ hidden_states = hidden_states[:, :, self.patch_size - 1 :, ...]
+ return hidden_states
+
+ def _irearrange(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ p = self.patch_size
+ hidden_states = hidden_states.unflatten(1, (-1, p, p, p))
+ hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4)
+ hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+ hidden_states = hidden_states[:, :, p - 1 :, ...]
+ return hidden_states
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if self.patch_method == "haar":
+ return self._ihaar(hidden_states)
+ elif self.patch_method == "rearrange":
+ return self._irearrange(hidden_states)
+ else:
+ raise ValueError("Unknown patch method: " + self.patch_method)
+
+
+class CosmosConvProjection3d(nn.Module):
+ def __init__(self, in_channels: int, out_channels: int) -> None:
+ super().__init__()
+
+ self.conv_s = CosmosCausalConv3d(in_channels, out_channels, kernel_size=(1, 3, 3), stride=1, padding=1)
+ self.conv_t = CosmosCausalConv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=0)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.conv_s(hidden_states)
+ hidden_states = self.conv_t(hidden_states)
+ return hidden_states
+
+
+class CosmosResnetBlock3d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_groups: int = 1,
+ ) -> None:
+ super().__init__()
+ out_channels = out_channels or in_channels
+
+ self.norm1 = CosmosCausalGroupNorm(in_channels, num_groups)
+ self.conv1 = CosmosConvProjection3d(in_channels, out_channels)
+
+ self.norm2 = CosmosCausalGroupNorm(out_channels, num_groups)
+ self.dropout = nn.Dropout(dropout)
+ self.conv2 = CosmosConvProjection3d(out_channels, out_channels)
+
+ if in_channels != out_channels:
+ self.conv_shortcut = CosmosCausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+ else:
+ self.conv_shortcut = nn.Identity()
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ residual = hidden_states
+ residual = self.conv_shortcut(residual)
+
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = F.silu(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = F.silu(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ return hidden_states + residual
+
+
+class CosmosDownsample3d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ spatial_downsample: bool = True,
+ temporal_downsample: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.spatial_downsample = spatial_downsample
+ self.temporal_downsample = temporal_downsample
+
+ self.conv1 = nn.Identity()
+ self.conv2 = nn.Identity()
+ self.conv3 = nn.Identity()
+
+ if spatial_downsample:
+ self.conv1 = CosmosCausalConv3d(
+ in_channels, in_channels, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=0
+ )
+ if temporal_downsample:
+ self.conv2 = CosmosCausalConv3d(
+ in_channels, in_channels, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=0
+ )
+ if spatial_downsample or temporal_downsample:
+ self.conv3 = CosmosCausalConv3d(
+ in_channels, in_channels, kernel_size=(1, 1, 1), stride=(1, 1, 1), padding=0
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if not self.spatial_downsample and not self.temporal_downsample:
+ return hidden_states
+
+ if self.spatial_downsample:
+ pad = (0, 1, 0, 1, 0, 0)
+ hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
+ conv_out = self.conv1(hidden_states)
+ pool_out = F.avg_pool3d(hidden_states, kernel_size=(1, 2, 2), stride=(1, 2, 2))
+ hidden_states = conv_out + pool_out
+
+ if self.temporal_downsample:
+ hidden_states = torch.cat([hidden_states[:, :, :1, ...], hidden_states], dim=2)
+ conv_out = self.conv2(hidden_states)
+ pool_out = F.avg_pool3d(hidden_states, kernel_size=(2, 1, 1), stride=(2, 1, 1))
+ hidden_states = conv_out + pool_out
+
+ hidden_states = self.conv3(hidden_states)
+ return hidden_states
+
+
+class CosmosUpsample3d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ spatial_upsample: bool = True,
+ temporal_upsample: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.spatial_upsample = spatial_upsample
+ self.temporal_upsample = temporal_upsample
+
+ self.conv1 = nn.Identity()
+ self.conv2 = nn.Identity()
+ self.conv3 = nn.Identity()
+
+ if temporal_upsample:
+ self.conv1 = CosmosCausalConv3d(
+ in_channels, in_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=0
+ )
+ if spatial_upsample:
+ self.conv2 = CosmosCausalConv3d(
+ in_channels, in_channels, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=1
+ )
+ if spatial_upsample or temporal_upsample:
+ self.conv3 = CosmosCausalConv3d(
+ in_channels, in_channels, kernel_size=(1, 1, 1), stride=(1, 1, 1), padding=0
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if not self.spatial_upsample and not self.temporal_upsample:
+ return hidden_states
+
+ if self.temporal_upsample:
+ num_frames = hidden_states.size(2)
+ time_factor = int(1.0 + 1.0 * (num_frames > 1))
+ hidden_states = hidden_states.repeat_interleave(int(time_factor), dim=2)
+ hidden_states = hidden_states[..., time_factor - 1 :, :, :]
+ hidden_states = self.conv1(hidden_states) + hidden_states
+
+ if self.spatial_upsample:
+ hidden_states = hidden_states.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4)
+ hidden_states = self.conv2(hidden_states) + hidden_states
+
+ hidden_states = self.conv3(hidden_states)
+ return hidden_states
+
+
+class CosmosCausalAttention(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ num_groups: int = 1,
+ dropout: float = 0.0,
+ processor: Union["CosmosSpatialAttentionProcessor2_0", "CosmosTemporalAttentionProcessor2_0"] = None,
+ ) -> None:
+ super().__init__()
+ self.num_attention_heads = num_attention_heads
+
+ self.norm = CosmosCausalGroupNorm(attention_head_dim, num_groups=num_groups)
+ self.to_q = CosmosCausalConv3d(attention_head_dim, attention_head_dim, kernel_size=1, stride=1, padding=0)
+ self.to_k = CosmosCausalConv3d(attention_head_dim, attention_head_dim, kernel_size=1, stride=1, padding=0)
+ self.to_v = CosmosCausalConv3d(attention_head_dim, attention_head_dim, kernel_size=1, stride=1, padding=0)
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(
+ CosmosCausalConv3d(attention_head_dim, attention_head_dim, kernel_size=1, stride=1, padding=0)
+ )
+ self.to_out.append(nn.Dropout(dropout))
+
+ self.processor = processor
+ if self.processor is None:
+ raise ValueError("CosmosCausalAttention requires a processor.")
+
+ def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ return self.processor(self, hidden_states=hidden_states, attention_mask=attention_mask)
+
+
+class CosmosSpatialAttentionProcessor2_0:
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "CosmosSpatialAttentionProcessor2_0 requires PyTorch 2.0 or higher. To use it, please upgrade PyTorch."
+ )
+
+ def __call__(
+ self, attn: CosmosCausalAttention, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = attn.norm(hidden_states)
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ # [B, C, T, H, W] -> [B * T, H * W, C]
+ query = query.permute(0, 2, 3, 4, 1).flatten(2, 3).flatten(0, 1)
+ key = key.permute(0, 2, 3, 4, 1).flatten(2, 3).flatten(0, 1)
+ value = value.permute(0, 2, 3, 4, 1).flatten(2, 3).flatten(0, 1)
+
+ # [B * T, H * W, C] -> [B * T, N, H * W, C // N]
+ query = query.unflatten(2, (attn.num_attention_heads, -1)).transpose(1, 2)
+ key = key.unflatten(2, (attn.num_attention_heads, -1)).transpose(1, 2)
+ value = value.unflatten(2, (attn.num_attention_heads, -1)).transpose(1, 2)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).type_as(query)
+ hidden_states = hidden_states.unflatten(1, (height, width)).unflatten(0, (batch_size, num_frames))
+ hidden_states = hidden_states.permute(0, 4, 1, 2, 3)
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states + residual
+
+
+class CosmosTemporalAttentionProcessor2_0:
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "CosmosSpatialAttentionProcessor2_0 requires PyTorch 2.0 or higher. To use it, please upgrade PyTorch."
+ )
+
+ def __call__(
+ self, attn: CosmosCausalAttention, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = attn.norm(hidden_states)
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ # [B, C, T, H, W] -> [B * T, H * W, C]
+ query = query.permute(0, 3, 4, 2, 1).flatten(0, 2)
+ key = key.permute(0, 3, 4, 2, 1).flatten(0, 2)
+ value = value.permute(0, 3, 4, 2, 1).flatten(0, 2)
+
+ # [B * T, H * W, C] -> [B * T, N, H * W, C // N]
+ query = query.unflatten(2, (attn.num_attention_heads, -1)).transpose(1, 2)
+ key = key.unflatten(2, (attn.num_attention_heads, -1)).transpose(1, 2)
+ value = value.unflatten(2, (attn.num_attention_heads, -1)).transpose(1, 2)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).type_as(query)
+ hidden_states = hidden_states.unflatten(0, (batch_size, height, width))
+ hidden_states = hidden_states.permute(0, 4, 3, 1, 2)
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states + residual
+
+
+class CosmosDownBlock3d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int,
+ dropout: float,
+ use_attention: bool,
+ use_downsample: bool,
+ spatial_downsample: bool,
+ temporal_downsample: bool,
+ ) -> None:
+ super().__init__()
+
+ resnets, attentions, temp_attentions = [], [], []
+ in_channel, out_channel = in_channels, out_channels
+
+ for _ in range(num_layers):
+ resnets.append(CosmosResnetBlock3d(in_channel, out_channel, dropout, num_groups=1))
+ in_channel = out_channel
+
+ if use_attention:
+ attentions.append(
+ CosmosCausalAttention(
+ num_attention_heads=1,
+ attention_head_dim=out_channel,
+ num_groups=1,
+ dropout=dropout,
+ processor=CosmosSpatialAttentionProcessor2_0(),
+ )
+ )
+ temp_attentions.append(
+ CosmosCausalAttention(
+ num_attention_heads=1,
+ attention_head_dim=out_channel,
+ num_groups=1,
+ dropout=dropout,
+ processor=CosmosTemporalAttentionProcessor2_0(),
+ )
+ )
+ else:
+ attentions.append(None)
+ temp_attentions.append(None)
+
+ self.resnets = nn.ModuleList(resnets)
+ self.attentions = nn.ModuleList(attentions)
+ self.temp_attentions = nn.ModuleList(temp_attentions)
+
+ self.downsamplers = None
+ if use_downsample:
+ self.downsamplers = nn.ModuleList([])
+ self.downsamplers.append(CosmosDownsample3d(out_channel, spatial_downsample, temporal_downsample))
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ for resnet, attention, temp_attention in zip(self.resnets, self.attentions, self.temp_attentions):
+ hidden_states = resnet(hidden_states)
+ if attention is not None:
+ hidden_states = attention(hidden_states)
+ if temp_attention is not None:
+ num_frames = hidden_states.size(2)
+ attention_mask = torch.tril(hidden_states.new_ones(num_frames, num_frames)).bool()
+ hidden_states = temp_attention(hidden_states, attention_mask)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class CosmosMidBlock3d(nn.Module):
+ def __init__(self, in_channels: int, num_layers: int, dropout: float, num_groups: int = 1) -> None:
+ super().__init__()
+
+ resnets, attentions, temp_attentions = [], [], []
+
+ resnets.append(CosmosResnetBlock3d(in_channels, in_channels, dropout, num_groups))
+ for _ in range(num_layers):
+ attentions.append(
+ CosmosCausalAttention(
+ num_attention_heads=1,
+ attention_head_dim=in_channels,
+ num_groups=num_groups,
+ dropout=dropout,
+ processor=CosmosSpatialAttentionProcessor2_0(),
+ )
+ )
+ temp_attentions.append(
+ CosmosCausalAttention(
+ num_attention_heads=1,
+ attention_head_dim=in_channels,
+ num_groups=num_groups,
+ dropout=dropout,
+ processor=CosmosTemporalAttentionProcessor2_0(),
+ )
+ )
+ resnets.append(CosmosResnetBlock3d(in_channels, in_channels, dropout, num_groups))
+
+ self.resnets = nn.ModuleList(resnets)
+ self.attentions = nn.ModuleList(attentions)
+ self.temp_attentions = nn.ModuleList(temp_attentions)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.resnets[0](hidden_states)
+
+ for attention, temp_attention, resnet in zip(self.attentions, self.temp_attentions, self.resnets[1:]):
+ num_frames = hidden_states.size(2)
+ attention_mask = torch.tril(hidden_states.new_ones(num_frames, num_frames)).bool()
+
+ hidden_states = attention(hidden_states)
+ hidden_states = temp_attention(hidden_states, attention_mask)
+ hidden_states = resnet(hidden_states)
+
+ return hidden_states
+
+
+class CosmosUpBlock3d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int,
+ dropout: float,
+ use_attention: bool,
+ use_upsample: bool,
+ spatial_upsample: bool,
+ temporal_upsample: bool,
+ ) -> None:
+ super().__init__()
+
+ resnets, attention, temp_attentions = [], [], []
+ in_channel, out_channel = in_channels, out_channels
+
+ for _ in range(num_layers):
+ resnets.append(CosmosResnetBlock3d(in_channel, out_channel, dropout, num_groups=1))
+ in_channel = out_channel
+
+ if use_attention:
+ attention.append(
+ CosmosCausalAttention(
+ num_attention_heads=1,
+ attention_head_dim=out_channel,
+ num_groups=1,
+ dropout=dropout,
+ processor=CosmosSpatialAttentionProcessor2_0(),
+ )
+ )
+ temp_attentions.append(
+ CosmosCausalAttention(
+ num_attention_heads=1,
+ attention_head_dim=out_channel,
+ num_groups=1,
+ dropout=dropout,
+ processor=CosmosTemporalAttentionProcessor2_0(),
+ )
+ )
+ else:
+ attention.append(None)
+ temp_attentions.append(None)
+
+ self.resnets = nn.ModuleList(resnets)
+ self.attentions = nn.ModuleList(attention)
+ self.temp_attentions = nn.ModuleList(temp_attentions)
+
+ self.upsamplers = None
+ if use_upsample:
+ self.upsamplers = nn.ModuleList([])
+ self.upsamplers.append(CosmosUpsample3d(out_channel, spatial_upsample, temporal_upsample))
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ for resnet, attention, temp_attention in zip(self.resnets, self.attentions, self.temp_attentions):
+ hidden_states = resnet(hidden_states)
+ if attention is not None:
+ hidden_states = attention(hidden_states)
+ if temp_attention is not None:
+ num_frames = hidden_states.size(2)
+ attention_mask = torch.tril(hidden_states.new_ones(num_frames, num_frames)).bool()
+ hidden_states = temp_attention(hidden_states, attention_mask)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class CosmosEncoder3d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 16,
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
+ num_resnet_blocks: int = 2,
+ attention_resolutions: Tuple[int, ...] = (32,),
+ resolution: int = 1024,
+ patch_size: int = 4,
+ patch_type: str = "haar",
+ dropout: float = 0.0,
+ spatial_compression_ratio: int = 8,
+ temporal_compression_ratio: int = 8,
+ ) -> None:
+ super().__init__()
+ inner_dim = in_channels * patch_size**3
+ num_spatial_layers = int(math.log2(spatial_compression_ratio)) - int(math.log2(patch_size))
+ num_temporal_layers = int(math.log2(temporal_compression_ratio)) - int(math.log2(patch_size))
+
+ # 1. Input patching & projection
+ self.patch_embed = CosmosPatchEmbed3d(patch_size, patch_type)
+
+ self.conv_in = CosmosConvProjection3d(inner_dim, block_out_channels[0])
+
+ # 2. Down blocks
+ current_resolution = resolution // patch_size
+ down_blocks = []
+ for i in range(len(block_out_channels) - 1):
+ in_channel = block_out_channels[i]
+ out_channel = block_out_channels[i + 1]
+
+ use_attention = current_resolution in attention_resolutions
+ spatial_downsample = temporal_downsample = False
+ if i < len(block_out_channels) - 2:
+ use_downsample = True
+ spatial_downsample = i < num_spatial_layers
+ temporal_downsample = i < num_temporal_layers
+ current_resolution = current_resolution // 2
+ else:
+ use_downsample = False
+
+ down_blocks.append(
+ CosmosDownBlock3d(
+ in_channel,
+ out_channel,
+ num_resnet_blocks,
+ dropout,
+ use_attention,
+ use_downsample,
+ spatial_downsample,
+ temporal_downsample,
+ )
+ )
+ self.down_blocks = nn.ModuleList(down_blocks)
+
+ # 3. Mid block
+ self.mid_block = CosmosMidBlock3d(block_out_channels[-1], num_layers=1, dropout=dropout, num_groups=1)
+
+ # 4. Output norm & projection
+ self.norm_out = CosmosCausalGroupNorm(block_out_channels[-1], num_groups=1)
+ self.conv_out = CosmosConvProjection3d(block_out_channels[-1], out_channels)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.patch_embed(hidden_states)
+ hidden_states = self.conv_in(hidden_states)
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for block in self.down_blocks:
+ hidden_states = self._gradient_checkpointing_func(block, hidden_states)
+ hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
+ else:
+ for block in self.down_blocks:
+ hidden_states = block(hidden_states)
+ hidden_states = self.mid_block(hidden_states)
+
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = F.silu(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+ return hidden_states
+
+
+class CosmosDecoder3d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 16,
+ out_channels: int = 3,
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
+ num_resnet_blocks: int = 2,
+ attention_resolutions: Tuple[int, ...] = (32,),
+ resolution: int = 1024,
+ patch_size: int = 4,
+ patch_type: str = "haar",
+ dropout: float = 0.0,
+ spatial_compression_ratio: int = 8,
+ temporal_compression_ratio: int = 8,
+ ) -> None:
+ super().__init__()
+ inner_dim = out_channels * patch_size**3
+ num_spatial_layers = int(math.log2(spatial_compression_ratio)) - int(math.log2(patch_size))
+ num_temporal_layers = int(math.log2(temporal_compression_ratio)) - int(math.log2(patch_size))
+ reversed_block_out_channels = list(reversed(block_out_channels))
+
+ # 1. Input projection
+ self.conv_in = CosmosConvProjection3d(in_channels, reversed_block_out_channels[0])
+
+ # 2. Mid block
+ self.mid_block = CosmosMidBlock3d(reversed_block_out_channels[0], num_layers=1, dropout=dropout, num_groups=1)
+
+ # 3. Up blocks
+ current_resolution = (resolution // patch_size) // 2 ** (len(block_out_channels) - 2)
+ up_blocks = []
+ for i in range(len(block_out_channels) - 1):
+ in_channel = reversed_block_out_channels[i]
+ out_channel = reversed_block_out_channels[i + 1]
+
+ use_attention = current_resolution in attention_resolutions
+ spatial_upsample = temporal_upsample = False
+ if i < len(block_out_channels) - 2:
+ use_upsample = True
+ temporal_upsample = 0 < i < num_temporal_layers + 1
+ spatial_upsample = temporal_upsample or (
+ i < num_spatial_layers and num_spatial_layers > num_temporal_layers
+ )
+ current_resolution = current_resolution * 2
+ else:
+ use_upsample = False
+
+ up_blocks.append(
+ CosmosUpBlock3d(
+ in_channel,
+ out_channel,
+ num_resnet_blocks + 1,
+ dropout,
+ use_attention,
+ use_upsample,
+ spatial_upsample,
+ temporal_upsample,
+ )
+ )
+ self.up_blocks = nn.ModuleList(up_blocks)
+
+ # 4. Output norm & projection & unpatching
+ self.norm_out = CosmosCausalGroupNorm(reversed_block_out_channels[-1], num_groups=1)
+ self.conv_out = CosmosConvProjection3d(reversed_block_out_channels[-1], inner_dim)
+
+ self.unpatch_embed = CosmosUnpatcher3d(patch_size, patch_type)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.conv_in(hidden_states)
+ hidden_states = self.mid_block(hidden_states)
+
+ for block in self.up_blocks:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(block, hidden_states)
+ else:
+ hidden_states = block(hidden_states)
+
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = F.silu(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+ hidden_states = self.unpatch_embed(hidden_states)
+ return hidden_states
+
+
+class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
+ r"""
+ Autoencoder used in [Cosmos](https://huggingface.co/papers/2501.03575).
+
+ Args:
+ in_channels (`int`, defaults to `3`):
+ Number of input channels.
+ out_channels (`int`, defaults to `3`):
+ Number of output channels.
+ latent_channels (`int`, defaults to `16`):
+ Number of latent channels.
+ encoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
+ Number of output channels for each encoder down block.
+ decode_block_out_channels (`Tuple[int, ...]`, defaults to `(256, 512, 512, 512)`):
+ Number of output channels for each decoder up block.
+ attention_resolutions (`Tuple[int, ...]`, defaults to `(32,)`):
+ List of image/video resolutions at which to apply attention.
+ resolution (`int`, defaults to `1024`):
+ Base image/video resolution used for computing whether a block should have attention layers.
+ num_layers (`int`, defaults to `2`):
+ Number of resnet blocks in each encoder/decoder block.
+ patch_size (`int`, defaults to `4`):
+ Patch size used for patching the input image/video.
+ patch_type (`str`, defaults to `haar`):
+ Patch type used for patching the input image/video. Can be either `haar` or `rearrange`.
+ scaling_factor (`float`, defaults to `1.0`):
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+ Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper. Not applicable in
+ Cosmos, but we default to 1.0 for consistency.
+ spatial_compression_ratio (`int`, defaults to `8`):
+ The spatial compression ratio to apply in the VAE. The number of downsample blocks is determined using
+ this.
+ temporal_compression_ratio (`int`, defaults to `8`):
+ The temporal compression ratio to apply in the VAE. The number of downsample blocks is determined using
+ this.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ latent_channels: int = 16,
+ encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
+ decode_block_out_channels: Tuple[int, ...] = (256, 512, 512, 512),
+ attention_resolutions: Tuple[int, ...] = (32,),
+ resolution: int = 1024,
+ num_layers: int = 2,
+ patch_size: int = 4,
+ patch_type: str = "haar",
+ scaling_factor: float = 1.0,
+ spatial_compression_ratio: int = 8,
+ temporal_compression_ratio: int = 8,
+ latents_mean: Optional[List[float]] = LATENTS_MEAN,
+ latents_std: Optional[List[float]] = LATENTS_STD,
+ ) -> None:
+ super().__init__()
+
+ self.encoder = CosmosEncoder3d(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ block_out_channels=encoder_block_out_channels,
+ num_resnet_blocks=num_layers,
+ attention_resolutions=attention_resolutions,
+ resolution=resolution,
+ patch_size=patch_size,
+ patch_type=patch_type,
+ spatial_compression_ratio=spatial_compression_ratio,
+ temporal_compression_ratio=temporal_compression_ratio,
+ )
+ self.decoder = CosmosDecoder3d(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ block_out_channels=decode_block_out_channels,
+ num_resnet_blocks=num_layers,
+ attention_resolutions=attention_resolutions,
+ resolution=resolution,
+ patch_size=patch_size,
+ patch_type=patch_type,
+ spatial_compression_ratio=spatial_compression_ratio,
+ temporal_compression_ratio=temporal_compression_ratio,
+ )
+
+ self.quant_conv = CosmosCausalConv3d(latent_channels, latent_channels, kernel_size=1, padding=0)
+ self.post_quant_conv = CosmosCausalConv3d(latent_channels, latent_channels, kernel_size=1, padding=0)
+
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+ # to perform decoding of a single video latent at a time.
+ self.use_slicing = False
+
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+ # intermediate tiles together, the memory requirement can be lowered.
+ self.use_tiling = False
+
+ # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
+ # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered.
+ self.use_framewise_encoding = False
+ self.use_framewise_decoding = False
+
+ # This can be configured based on the amount of GPU memory available.
+ # `16` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs.
+ # Setting it to higher values results in higher memory usage.
+ self.num_sample_frames_batch_size = 16
+ self.num_latent_frames_batch_size = 2
+
+ # The minimal tile height and width for spatial tiling to be used
+ self.tile_sample_min_height = 512
+ self.tile_sample_min_width = 512
+ self.tile_sample_min_num_frames = 16
+
+ # The minimal distance between two spatial tiles
+ self.tile_sample_stride_height = 448
+ self.tile_sample_stride_width = 448
+ self.tile_sample_stride_num_frames = 8
+
+ def enable_tiling(
+ self,
+ tile_sample_min_height: Optional[int] = None,
+ tile_sample_min_width: Optional[int] = None,
+ tile_sample_min_num_frames: Optional[int] = None,
+ tile_sample_stride_height: Optional[float] = None,
+ tile_sample_stride_width: Optional[float] = None,
+ tile_sample_stride_num_frames: Optional[float] = None,
+ ) -> None:
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+
+ Args:
+ tile_sample_min_height (`int`, *optional*):
+ The minimum height required for a sample to be separated into tiles across the height dimension.
+ tile_sample_min_width (`int`, *optional*):
+ The minimum width required for a sample to be separated into tiles across the width dimension.
+ tile_sample_stride_height (`int`, *optional*):
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+ no tiling artifacts produced across the height dimension.
+ tile_sample_stride_width (`int`, *optional*):
+ The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+ artifacts produced across the width dimension.
+ """
+ self.use_tiling = True
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+ self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
+ self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+ self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+ self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
+
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.encoder(x)
+ enc = self.quant_conv(x)
+ return enc
+
+ @apply_forward_hook
+ def encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self._encode(x)
+
+ posterior = IdentityDistribution(h)
+
+ if not return_dict:
+ return (posterior,)
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+
+ if not return_dict:
+ return (dec,)
+ return DecoderOutput(sample=dec)
+
+ @apply_forward_hook
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z).sample
+
+ if not return_dict:
+ return (decoded,)
+ return DecoderOutput(sample=decoded)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[Tuple[torch.Tensor], DecoderOutput]:
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z).sample
+ if not return_dict:
+ return (dec,)
+ return DecoderOutput(sample=dec)
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_flux2.py b/src/diffusers/models/autoencoders/autoencoder_kl_flux2.py
new file mode 100644
index 000000000000..3325d33c06bf
--- /dev/null
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_flux2.py
@@ -0,0 +1,488 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...loaders.single_file_model import FromOriginalModelMixin
+from ...utils import deprecate
+from ...utils.accelerate_utils import apply_forward_hook
+from ..attention import AttentionMixin
+from ..attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ Attention,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+ FusedAttnProcessor2_0,
+)
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
+
+
+class AutoencoderKLFlux2(
+ ModelMixin, AutoencoderMixin, AttentionMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin
+):
+ r"""
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+ Tuple of downsample block types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+ Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+ Tuple of block output channels.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
+ force_upcast (`bool`, *optional*, default to `True`):
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
+ can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast`
+ can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
+ mid_block_add_attention (`bool`, *optional*, default to `True`):
+ If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the
+ mid_block will only have resnet blocks
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str, ...] = (
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ ),
+ up_block_types: Tuple[str, ...] = (
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ ),
+ block_out_channels: Tuple[int, ...] = (
+ 128,
+ 256,
+ 512,
+ 512,
+ ),
+ layers_per_block: int = 2,
+ act_fn: str = "silu",
+ latent_channels: int = 32,
+ norm_num_groups: int = 32,
+ sample_size: int = 1024, # YiYi notes: not sure
+ force_upcast: bool = True,
+ use_quant_conv: bool = True,
+ use_post_quant_conv: bool = True,
+ mid_block_add_attention: bool = True,
+ batch_norm_eps: float = 1e-4,
+ batch_norm_momentum: float = 0.1,
+ patch_size: Tuple[int, int] = (2, 2),
+ ):
+ super().__init__()
+
+ # pass init params to Encoder
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ double_z=True,
+ mid_block_add_attention=mid_block_add_attention,
+ )
+
+ # pass init params to Decoder
+ self.decoder = Decoder(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ norm_num_groups=norm_num_groups,
+ act_fn=act_fn,
+ mid_block_add_attention=mid_block_add_attention,
+ )
+
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
+
+ self.bn = nn.BatchNorm2d(
+ math.prod(patch_size) * latent_channels,
+ eps=batch_norm_eps,
+ momentum=batch_norm_momentum,
+ affine=False,
+ track_running_stats=True,
+ )
+
+ self.use_slicing = False
+ self.use_tiling = False
+
+ # only relevant if vae tiling is enabled
+ self.tile_sample_min_size = self.config.sample_size
+ sample_size = (
+ self.config.sample_size[0]
+ if isinstance(self.config.sample_size, (list, tuple))
+ else self.config.sample_size
+ )
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
+ self.tile_overlap_factor = 0.25
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, height, width = x.shape
+
+ if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
+ return self._tiled_encode(x)
+
+ enc = self.encoder(x)
+ if self.quant_conv is not None:
+ enc = self.quant_conv(enc)
+
+ return enc
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.Tensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ """
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded images. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self._encode(x)
+
+ posterior = DiagonalGaussianDistribution(h)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
+ return self.tiled_decode(z, return_dict=return_dict)
+
+ if self.post_quant_conv is not None:
+ z = self.post_quant_conv(z)
+
+ dec = self.decoder(z)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ @apply_forward_hook
+ def decode(
+ self, z: torch.FloatTensor, return_dict: bool = True, generator=None
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ """
+ Decode a batch of images.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z).sample
+
+ if not return_dict:
+ return (decoded,)
+
+ return DecoderOutput(sample=decoded)
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+ return b
+
+ def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+ r"""Encode a batch of images using a tiled encoder.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+ output, but they should be much less noticeable.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+
+ Returns:
+ `torch.Tensor`:
+ The latent representation of the encoded videos.
+ """
+
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_latent_min_size - blend_extent
+
+ # Split the image into 512x512 tiles and encode them separately.
+ rows = []
+ for i in range(0, x.shape[2], overlap_size):
+ row = []
+ for j in range(0, x.shape[3], overlap_size):
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+ tile = self.encoder(tile)
+ if self.config.use_quant_conv:
+ tile = self.quant_conv(tile)
+ row.append(tile)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=3))
+
+ enc = torch.cat(result_rows, dim=2)
+ return enc
+
+ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
+ r"""Encode a batch of images using a tiled encoder.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+ output, but they should be much less noticeable.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
+ `tuple` is returned.
+ """
+ deprecation_message = (
+ "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
+ "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
+ "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value."
+ )
+ deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)
+
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_latent_min_size - blend_extent
+
+ # Split the image into 512x512 tiles and encode them separately.
+ rows = []
+ for i in range(0, x.shape[2], overlap_size):
+ row = []
+ for j in range(0, x.shape[3], overlap_size):
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+ tile = self.encoder(tile)
+ if self.config.use_quant_conv:
+ tile = self.quant_conv(tile)
+ row.append(tile)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=3))
+
+ moments = torch.cat(result_rows, dim=2)
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of images using a tiled decoder.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_sample_min_size - blend_extent
+
+ # Split z into overlapping 64x64 tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, z.shape[2], overlap_size):
+ row = []
+ for j in range(0, z.shape[3], overlap_size):
+ tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
+ if self.config.use_post_quant_conv:
+ tile = self.post_quant_conv(tile)
+ decoded = self.decoder(tile)
+ row.append(decoded)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=3))
+
+ dec = torch.cat(result_rows, dim=2)
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Args:
+ sample (`torch.Tensor`): Input sample.
+ sample_posterior (`bool`, *optional*, defaults to `False`):
+ Whether to sample from the posterior.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z).sample
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+ > [!WARNING] > This API is 🧪 experimental.
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+ > [!WARNING] > This API is 🧪 experimental.
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
index 089e641d8852..ddc0aed6b0ff 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The Hunyuan Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
@@ -27,7 +26,7 @@
from ..attention_processor import Attention
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
-from .vae import DecoderOutput, DiagonalGaussianDistribution
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -625,7 +624,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states
-class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
+class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).
@@ -654,7 +653,7 @@ def __init__(
"HunyuanVideoUpBlock3D",
"HunyuanVideoUpBlock3D",
),
- block_out_channels: Tuple[int] = (128, 256, 512, 512),
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: int = 2,
act_fn: str = "silu",
norm_num_groups: int = 32,
@@ -764,27 +763,6 @@ def enable_tiling(
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
- def disable_tiling(self) -> None:
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_tiling = False
-
- def enable_slicing(self) -> None:
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self) -> None:
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
@@ -829,7 +807,7 @@ def encode(
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
- tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage.py
new file mode 100644
index 000000000000..616d0d415840
--- /dev/null
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage.py
@@ -0,0 +1,709 @@
+# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
+from ...utils import logging
+from ...utils.accelerate_utils import apply_forward_hook
+from ..activations import get_activation
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from .vae import DecoderOutput, DiagonalGaussianDistribution
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class HunyuanImageResnetBlock(nn.Module):
+ r"""
+ Residual block with two convolutions and optional channel change.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
+ """
+
+ def __init__(self, in_channels: int, out_channels: int, non_linearity: str = "silu") -> None:
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.nonlinearity = get_activation(non_linearity)
+
+ # layers
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if in_channels != out_channels:
+ self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+ else:
+ self.conv_shortcut = None
+
+ def forward(self, x):
+ # Apply shortcut connection
+ residual = x
+
+ # First normalization and activation
+ x = self.norm1(x)
+ x = self.nonlinearity(x)
+
+ x = self.conv1(x)
+ x = self.norm2(x)
+ x = self.nonlinearity(x)
+ x = self.conv2(x)
+
+ if self.conv_shortcut is not None:
+ x = self.conv_shortcut(x)
+ # Add residual connection
+ return x + residual
+
+
+class HunyuanImageAttentionBlock(nn.Module):
+ r"""
+ Self-attention with a single head.
+
+ Args:
+ in_channels (int): The number of channels in the input tensor.
+ """
+
+ def __init__(self, in_channels: int):
+ super().__init__()
+
+ # layers
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ self.to_q = nn.Conv2d(in_channels, in_channels, 1)
+ self.to_k = nn.Conv2d(in_channels, in_channels, 1)
+ self.to_v = nn.Conv2d(in_channels, in_channels, 1)
+ self.proj = nn.Conv2d(in_channels, in_channels, 1)
+
+ def forward(self, x):
+ identity = x
+ x = self.norm(x)
+
+ # compute query, key, value
+ query = self.to_q(x)
+ key = self.to_k(x)
+ value = self.to_v(x)
+
+ batch_size, channels, height, width = query.shape
+ query = query.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
+ key = key.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
+ value = value.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
+
+ # apply attention
+ x = F.scaled_dot_product_attention(query, key, value)
+
+ x = x.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
+ # output projection
+ x = self.proj(x)
+
+ return x + identity
+
+
+class HunyuanImageDownsample(nn.Module):
+ """
+ Downsampling block for spatial reduction.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ """
+
+ def __init__(self, in_channels: int, out_channels: int):
+ super().__init__()
+ factor = 4
+ if out_channels % factor != 0:
+ raise ValueError(f"out_channels % factor != 0: {out_channels % factor}")
+
+ self.conv = nn.Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
+ self.group_size = factor * in_channels // out_channels
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ h = self.conv(x)
+
+ B, C, H, W = h.shape
+ h = h.reshape(B, C, H // 2, 2, W // 2, 2)
+ h = h.permute(0, 3, 5, 1, 2, 4) # b, r1, r2, c, h, w
+ h = h.reshape(B, 4 * C, H // 2, W // 2)
+
+ B, C, H, W = x.shape
+ shortcut = x.reshape(B, C, H // 2, 2, W // 2, 2)
+ shortcut = shortcut.permute(0, 3, 5, 1, 2, 4) # b, r1, r2, c, h, w
+ shortcut = shortcut.reshape(B, 4 * C, H // 2, W // 2)
+
+ B, C, H, W = shortcut.shape
+ shortcut = shortcut.view(B, h.shape[1], self.group_size, H, W).mean(dim=2)
+ return h + shortcut
+
+
+class HunyuanImageUpsample(nn.Module):
+ """
+ Upsampling block for spatial expansion.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ """
+
+ def __init__(self, in_channels: int, out_channels: int):
+ super().__init__()
+ factor = 4
+ self.conv = nn.Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
+ self.repeats = factor * out_channels // in_channels
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ h = self.conv(x)
+
+ B, C, H, W = h.shape
+ h = h.reshape(B, 2, 2, C // 4, H, W) # b, r1, r2, c, h, w
+ h = h.permute(0, 3, 4, 1, 5, 2) # b, c, h, r1, w, r2
+ h = h.reshape(B, C // 4, H * 2, W * 2)
+
+ shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
+
+ B, C, H, W = shortcut.shape
+ shortcut = shortcut.reshape(B, 2, 2, C // 4, H, W) # b, r1, r2, c, h, w
+ shortcut = shortcut.permute(0, 3, 4, 1, 5, 2) # b, c, h, r1, w, r2
+ shortcut = shortcut.reshape(B, C // 4, H * 2, W * 2)
+ return h + shortcut
+
+
+class HunyuanImageMidBlock(nn.Module):
+ """
+ Middle block for HunyuanImageVAE encoder and decoder.
+
+ Args:
+ in_channels (int): Number of input channels.
+ num_layers (int): Number of layers.
+ """
+
+ def __init__(self, in_channels: int, num_layers: int = 1):
+ super().__init__()
+
+ resnets = [HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels)]
+
+ attentions = []
+ for _ in range(num_layers):
+ attentions.append(HunyuanImageAttentionBlock(in_channels))
+ resnets.append(HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels))
+
+ self.resnets = nn.ModuleList(resnets)
+ self.attentions = nn.ModuleList(attentions)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.resnets[0](x)
+
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ x = attn(x)
+ x = resnet(x)
+
+ return x
+
+
+class HunyuanImageEncoder2D(nn.Module):
+ r"""
+ Encoder network that compresses input to latent representation.
+
+ Args:
+ in_channels (int): Number of input channels.
+ z_channels (int): Number of latent channels.
+ block_out_channels (list of int): Output channels for each block.
+ num_res_blocks (int): Number of residual blocks per block.
+ spatial_compression_ratio (int): Spatial downsampling factor.
+ non_linearity (str): Type of non-linearity to use. Default is "silu".
+ downsample_match_channel (bool): Whether to match channels during downsampling.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ z_channels: int,
+ block_out_channels: Tuple[int, ...],
+ num_res_blocks: int,
+ spatial_compression_ratio: int,
+ non_linearity: str = "silu",
+ downsample_match_channel: bool = True,
+ ):
+ super().__init__()
+ if block_out_channels[-1] % (2 * z_channels) != 0:
+ raise ValueError(
+ f"block_out_channels[-1 has to be divisible by 2 * out_channels, you have block_out_channels = {block_out_channels[-1]} and out_channels = {z_channels}"
+ )
+
+ self.in_channels = in_channels
+ self.z_channels = z_channels
+ self.block_out_channels = block_out_channels
+ self.num_res_blocks = num_res_blocks
+ self.spatial_compression_ratio = spatial_compression_ratio
+
+ self.group_size = block_out_channels[-1] // (2 * z_channels)
+ self.nonlinearity = get_activation(non_linearity)
+
+ # init block
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
+
+ # downsample blocks
+ self.down_blocks = nn.ModuleList([])
+
+ block_in_channel = block_out_channels[0]
+ for i in range(len(block_out_channels)):
+ block_out_channel = block_out_channels[i]
+ # residual blocks
+ for _ in range(num_res_blocks):
+ self.down_blocks.append(
+ HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel)
+ )
+ block_in_channel = block_out_channel
+
+ # downsample block
+ if i < np.log2(spatial_compression_ratio) and i != len(block_out_channels) - 1:
+ if downsample_match_channel:
+ block_out_channel = block_out_channels[i + 1]
+ self.down_blocks.append(
+ HunyuanImageDownsample(in_channels=block_in_channel, out_channels=block_out_channel)
+ )
+ block_in_channel = block_out_channel
+
+ # middle blocks
+ self.mid_block = HunyuanImageMidBlock(in_channels=block_out_channels[-1], num_layers=1)
+
+ # output blocks
+ # Output layers
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[-1], eps=1e-6, affine=True)
+ self.conv_out = nn.Conv2d(block_out_channels[-1], 2 * z_channels, kernel_size=3, stride=1, padding=1)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.conv_in(x)
+
+ ## downsamples
+ for down_block in self.down_blocks:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ x = self._gradient_checkpointing_func(down_block, x)
+ else:
+ x = down_block(x)
+
+ ## middle
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ x = self._gradient_checkpointing_func(self.mid_block, x)
+ else:
+ x = self.mid_block(x)
+
+ ## head
+ B, C, H, W = x.shape
+ residual = x.view(B, C // self.group_size, self.group_size, H, W).mean(dim=2)
+
+ x = self.norm_out(x)
+ x = self.nonlinearity(x)
+ x = self.conv_out(x)
+ return x + residual
+
+
+class HunyuanImageDecoder2D(nn.Module):
+ r"""
+ Decoder network that reconstructs output from latent representation.
+
+ Args:
+ z_channels : int
+ Number of latent channels.
+ out_channels : int
+ Number of output channels.
+ block_out_channels : Tuple[int, ...]
+ Output channels for each block.
+ num_res_blocks : int
+ Number of residual blocks per block.
+ spatial_compression_ratio : int
+ Spatial upsampling factor.
+ upsample_match_channel : bool
+ Whether to match channels during upsampling.
+ non_linearity (str): Type of non-linearity to use. Default is "silu".
+ """
+
+ def __init__(
+ self,
+ z_channels: int,
+ out_channels: int,
+ block_out_channels: Tuple[int, ...],
+ num_res_blocks: int,
+ spatial_compression_ratio: int,
+ upsample_match_channel: bool = True,
+ non_linearity: str = "silu",
+ ):
+ super().__init__()
+ if block_out_channels[0] % z_channels != 0:
+ raise ValueError(
+ f"block_out_channels[0] should be divisible by z_channels but has block_out_channels[0] = {block_out_channels[0]} and z_channels = {z_channels}"
+ )
+
+ self.z_channels = z_channels
+ self.block_out_channels = block_out_channels
+ self.num_res_blocks = num_res_blocks
+ self.repeat = block_out_channels[0] // z_channels
+ self.spatial_compression_ratio = spatial_compression_ratio
+ self.nonlinearity = get_activation(non_linearity)
+
+ self.conv_in = nn.Conv2d(z_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
+
+ # Middle blocks with attention
+ self.mid_block = HunyuanImageMidBlock(in_channels=block_out_channels[0], num_layers=1)
+
+ # Upsampling blocks
+ block_in_channel = block_out_channels[0]
+ self.up_blocks = nn.ModuleList()
+ for i in range(len(block_out_channels)):
+ block_out_channel = block_out_channels[i]
+ for _ in range(self.num_res_blocks + 1):
+ self.up_blocks.append(
+ HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel)
+ )
+ block_in_channel = block_out_channel
+
+ if i < np.log2(spatial_compression_ratio) and i != len(block_out_channels) - 1:
+ if upsample_match_channel:
+ block_out_channel = block_out_channels[i + 1]
+ self.up_blocks.append(HunyuanImageUpsample(block_in_channel, block_out_channel))
+ block_in_channel = block_out_channel
+
+ # Output layers
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[-1], eps=1e-6, affine=True)
+ self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ h = self.conv_in(x) + x.repeat_interleave(repeats=self.repeat, dim=1)
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ h = self._gradient_checkpointing_func(self.mid_block, h)
+ else:
+ h = self.mid_block(h)
+
+ for up_block in self.up_blocks:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ h = self._gradient_checkpointing_func(up_block, h)
+ else:
+ h = up_block(h)
+ h = self.norm_out(h)
+ h = self.nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+ r"""
+ A VAE model for 2D images with spatial tiling support.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+ """
+
+ _supports_gradient_checkpointing = False
+
+ # fmt: off
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ latent_channels: int,
+ block_out_channels: Tuple[int, ...],
+ layers_per_block: int,
+ spatial_compression_ratio: int,
+ sample_size: int,
+ scaling_factor: float = None,
+ downsample_match_channel: bool = True,
+ upsample_match_channel: bool = True,
+ ) -> None:
+ # fmt: on
+ super().__init__()
+
+ self.encoder = HunyuanImageEncoder2D(
+ in_channels=in_channels,
+ z_channels=latent_channels,
+ block_out_channels=block_out_channels,
+ num_res_blocks=layers_per_block,
+ spatial_compression_ratio=spatial_compression_ratio,
+ downsample_match_channel=downsample_match_channel,
+ )
+
+ self.decoder = HunyuanImageDecoder2D(
+ z_channels=latent_channels,
+ out_channels=out_channels,
+ block_out_channels=list(reversed(block_out_channels)),
+ num_res_blocks=layers_per_block,
+ spatial_compression_ratio=spatial_compression_ratio,
+ upsample_match_channel=upsample_match_channel,
+ )
+
+ # Tiling and slicing configuration
+ self.use_slicing = False
+ self.use_tiling = False
+
+ # Tiling parameters
+ self.tile_sample_min_size = sample_size
+ self.tile_latent_min_size = sample_size // spatial_compression_ratio
+ self.tile_overlap_factor = 0.25
+
+ def enable_tiling(
+ self,
+ tile_sample_min_size: Optional[int] = None,
+ tile_overlap_factor: Optional[float] = None,
+ ) -> None:
+ r"""
+ Enable spatial tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles
+ to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
+ allow processing larger images.
+
+ Args:
+ tile_sample_min_size (`int`, *optional*):
+ The minimum size required for a sample to be separated into tiles across the spatial dimension.
+ tile_overlap_factor (`float`, *optional*):
+ The overlap factor required for a latent to be separated into tiles across the spatial dimension.
+ """
+ self.use_tiling = True
+ self.tile_sample_min_size = tile_sample_min_size or self.tile_sample_min_size
+ self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
+ self.tile_latent_min_size = self.tile_sample_min_size // self.config.spatial_compression_ratio
+
+ def disable_tiling(self) -> None:
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_tiling = False
+
+ def enable_slicing(self) -> None:
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self) -> None:
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
+ def _encode(self, x: torch.Tensor):
+
+ batch_size, num_channels, height, width = x.shape
+
+ if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
+ return self.tiled_encode(x)
+
+ enc = self.encoder(x)
+
+ return enc
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.Tensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ r"""
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded videos. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self._encode(x)
+ posterior = DiagonalGaussianDistribution(h)
+
+ if not return_dict:
+ return (posterior,)
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.Tensor, return_dict: bool = True):
+
+ batch_size, num_channels, height, width = z.shape
+
+ if self.use_tiling and (width > self.tile_latent_min_size or height > self.tile_latent_min_size):
+ return self.tiled_decode(z, return_dict=return_dict)
+
+ dec = self.decoder(z)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ @apply_forward_hook
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of images.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z).sample
+
+ if not return_dict:
+ return (decoded,)
+ return DecoderOutput(sample=decoded)
+
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (
+ y / blend_extent
+ )
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (
+ x / blend_extent
+ )
+ return b
+
+ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Encode input using spatial tiling strategy.
+
+ Args:
+ x (`torch.Tensor`): Input tensor of shape (B, C, T, H, W).
+
+ Returns:
+ `torch.Tensor`:
+ The latent representation of the encoded images.
+ """
+ _, _, _, height, width = x.shape
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_latent_min_size - blend_extent
+
+ rows = []
+ for i in range(0, height, overlap_size):
+ row = []
+ for j in range(0, width, overlap_size):
+ tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+ tile = self.encoder(tile)
+ row.append(tile)
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=-1))
+
+ moments = torch.cat(result_rows, dim=-2)
+
+ return moments
+
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ """
+ Decode latent using spatial tiling strategy.
+
+ Args:
+ z (`torch.Tensor`): Latent tensor of shape (B, C, H, W).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ _, _, height, width = z.shape
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_sample_min_size - blend_extent
+
+ rows = []
+ for i in range(0, height, overlap_size):
+ row = []
+ for j in range(0, width, overlap_size):
+ tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
+ decoded = self.decoder(tile)
+ row.append(decoded)
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=-1))
+
+ dec = torch.cat(result_rows, dim=-2)
+ if not return_dict:
+ return (dec,)
+ return DecoderOutput(sample=dec)
+
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ """
+ Args:
+ sample (`torch.Tensor`): Input sample.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ posterior = self.encode(sample).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z, return_dict=return_dict)
+
+ return dec
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py
new file mode 100644
index 000000000000..2249063a9f00
--- /dev/null
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py
@@ -0,0 +1,934 @@
+# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import logging
+from ...utils.accelerate_utils import apply_forward_hook
+from ..activations import get_activation
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from .vae import DecoderOutput, DiagonalGaussianDistribution
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class HunyuanImageRefinerCausalConv3d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int, int]] = 3,
+ stride: Union[int, Tuple[int, int, int]] = 1,
+ padding: Union[int, Tuple[int, int, int]] = 0,
+ dilation: Union[int, Tuple[int, int, int]] = 1,
+ bias: bool = True,
+ pad_mode: str = "replicate",
+ ) -> None:
+ super().__init__()
+
+ kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
+
+ self.pad_mode = pad_mode
+ self.time_causal_padding = (
+ kernel_size[0] // 2,
+ kernel_size[0] // 2,
+ kernel_size[1] // 2,
+ kernel_size[1] // 2,
+ kernel_size[2] - 1,
+ 0,
+ )
+
+ self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode)
+ return self.conv(hidden_states)
+
+
+class HunyuanImageRefinerRMS_norm(nn.Module):
+ r"""
+ A custom RMS normalization layer.
+
+ Args:
+ dim (int): The number of dimensions to normalize over.
+ channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
+ Default is True.
+ images (bool, optional): Whether the input represents image data. Default is True.
+ bias (bool, optional): Whether to include a learnable bias term. Default is False.
+ """
+
+ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
+ super().__init__()
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
+
+ self.channel_first = channel_first
+ self.scale = dim**0.5
+ self.gamma = nn.Parameter(torch.ones(shape))
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
+
+ def forward(self, x):
+ return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
+
+
+class HunyuanImageRefinerAttnBlock(nn.Module):
+ def __init__(self, in_channels: int):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = HunyuanImageRefinerRMS_norm(in_channels, images=False)
+
+ self.to_q = nn.Conv3d(in_channels, in_channels, kernel_size=1)
+ self.to_k = nn.Conv3d(in_channels, in_channels, kernel_size=1)
+ self.to_v = nn.Conv3d(in_channels, in_channels, kernel_size=1)
+ self.proj_out = nn.Conv3d(in_channels, in_channels, kernel_size=1)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ identity = x
+
+ x = self.norm(x)
+
+ query = self.to_q(x)
+ key = self.to_k(x)
+ value = self.to_v(x)
+
+ batch_size, channels, frames, height, width = query.shape
+
+ query = query.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
+ key = key.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
+ value = value.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
+
+ x = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None)
+
+ # batch_size, 1, frames * height * width, channels
+
+ x = x.squeeze(1).reshape(batch_size, frames, height, width, channels).permute(0, 4, 1, 2, 3)
+ x = self.proj_out(x)
+
+ return x + identity
+
+
+class HunyuanImageRefinerUpsampleDCAE(nn.Module):
+ def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True):
+ super().__init__()
+ factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2
+ self.conv = HunyuanImageRefinerCausalConv3d(in_channels, out_channels * factor, kernel_size=3)
+
+ self.add_temporal_upsample = add_temporal_upsample
+ self.repeats = factor * out_channels // in_channels
+
+ @staticmethod
+ def _dcae_upsample_rearrange(tensor, r1=1, r2=2, r3=2):
+ """
+ Convert (b, r1*r2*r3*c, f, h, w) -> (b, c, r1*f, r2*h, r3*w)
+
+ Args:
+ tensor: Input tensor of shape (b, r1*r2*r3*c, f, h, w)
+ r1: temporal upsampling factor
+ r2: height upsampling factor
+ r3: width upsampling factor
+ """
+ b, packed_c, f, h, w = tensor.shape
+ factor = r1 * r2 * r3
+ c = packed_c // factor
+
+ tensor = tensor.view(b, r1, r2, r3, c, f, h, w)
+ tensor = tensor.permute(0, 4, 5, 1, 6, 2, 7, 3)
+ return tensor.reshape(b, c, f * r1, h * r2, w * r3)
+
+ def forward(self, x: torch.Tensor):
+ r1 = 2 if self.add_temporal_upsample else 1
+ h = self.conv(x)
+ if self.add_temporal_upsample:
+ h = self._dcae_upsample_rearrange(h, r1=1, r2=2, r3=2)
+ h = h[:, : h.shape[1] // 2]
+
+ # shortcut computation
+ shortcut = self._dcae_upsample_rearrange(x, r1=1, r2=2, r3=2)
+ shortcut = shortcut.repeat_interleave(repeats=self.repeats // 2, dim=1)
+
+ else:
+ h = self._dcae_upsample_rearrange(h, r1=r1, r2=2, r3=2)
+ shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
+ shortcut = self._dcae_upsample_rearrange(shortcut, r1=r1, r2=2, r3=2)
+ return h + shortcut
+
+
+class HunyuanImageRefinerDownsampleDCAE(nn.Module):
+ def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
+ super().__init__()
+ factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2
+ assert out_channels % factor == 0
+ # self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
+ self.conv = HunyuanImageRefinerCausalConv3d(in_channels, out_channels // factor, kernel_size=3)
+
+ self.add_temporal_downsample = add_temporal_downsample
+ self.group_size = factor * in_channels // out_channels
+
+ @staticmethod
+ def _dcae_downsample_rearrange(tensor, r1=1, r2=2, r3=2):
+ """
+ Convert (b, c, r1*f, r2*h, r3*w) -> (b, r1*r2*r3*c, f, h, w)
+
+ This packs spatial/temporal dimensions into channels (opposite of upsample)
+ """
+ b, c, packed_f, packed_h, packed_w = tensor.shape
+ f, h, w = packed_f // r1, packed_h // r2, packed_w // r3
+
+ tensor = tensor.view(b, c, f, r1, h, r2, w, r3)
+ tensor = tensor.permute(0, 3, 5, 7, 1, 2, 4, 6)
+ return tensor.reshape(b, r1 * r2 * r3 * c, f, h, w)
+
+ def forward(self, x: torch.Tensor):
+ r1 = 2 if self.add_temporal_downsample else 1
+ h = self.conv(x)
+ if self.add_temporal_downsample:
+ # h = rearrange(h, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
+ h = self._dcae_downsample_rearrange(h, r1=1, r2=2, r3=2)
+ h = torch.cat([h, h], dim=1)
+ # shortcut computation
+ # shortcut = rearrange(x, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
+ shortcut = self._dcae_downsample_rearrange(x, r1=1, r2=2, r3=2)
+ B, C, T, H, W = shortcut.shape
+ shortcut = shortcut.view(B, h.shape[1], self.group_size // 2, T, H, W).mean(dim=2)
+ else:
+ # h = rearrange(h, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
+ h = self._dcae_downsample_rearrange(h, r1=r1, r2=2, r3=2)
+ # shortcut = rearrange(x, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
+ shortcut = self._dcae_downsample_rearrange(x, r1=r1, r2=2, r3=2)
+ B, C, T, H, W = shortcut.shape
+ shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
+
+ return h + shortcut
+
+
+class HunyuanImageRefinerResnetBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: Optional[int] = None,
+ non_linearity: str = "swish",
+ ) -> None:
+ super().__init__()
+ out_channels = out_channels or in_channels
+
+ self.nonlinearity = get_activation(non_linearity)
+
+ self.norm1 = HunyuanImageRefinerRMS_norm(in_channels, images=False)
+ self.conv1 = HunyuanImageRefinerCausalConv3d(in_channels, out_channels, kernel_size=3)
+
+ self.norm2 = HunyuanImageRefinerRMS_norm(out_channels, images=False)
+ self.conv2 = HunyuanImageRefinerCausalConv3d(out_channels, out_channels, kernel_size=3)
+
+ self.conv_shortcut = None
+ if in_channels != out_channels:
+ self.conv_shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ residual = hidden_states
+
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ residual = self.conv_shortcut(residual)
+
+ return hidden_states + residual
+
+
+class HunyuanImageRefinerMidBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ num_layers: int = 1,
+ add_attention: bool = True,
+ ) -> None:
+ super().__init__()
+ self.add_attention = add_attention
+
+ # There is always at least one resnet
+ resnets = [
+ HunyuanImageRefinerResnetBlock(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ if self.add_attention:
+ attentions.append(HunyuanImageRefinerAttnBlock(in_channels))
+ else:
+ attentions.append(None)
+
+ resnets.append(
+ HunyuanImageRefinerResnetBlock(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.resnets[0](hidden_states)
+
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if attn is not None:
+ hidden_states = attn(hidden_states)
+ hidden_states = resnet(hidden_states)
+
+ return hidden_states
+
+
+class HunyuanImageRefinerDownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ downsample_out_channels: Optional[int] = None,
+ add_temporal_downsample: int = True,
+ ) -> None:
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ HunyuanImageRefinerResnetBlock(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if downsample_out_channels is not None:
+ self.downsamplers = nn.ModuleList(
+ [
+ HunyuanImageRefinerDownsampleDCAE(
+ out_channels,
+ out_channels=downsample_out_channels,
+ add_temporal_downsample=add_temporal_downsample,
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class HunyuanImageRefinerUpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ upsample_out_channels: Optional[int] = None,
+ add_temporal_upsample: bool = True,
+ ) -> None:
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ resnets.append(
+ HunyuanImageRefinerResnetBlock(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if upsample_out_channels is not None:
+ self.upsamplers = nn.ModuleList(
+ [
+ HunyuanImageRefinerUpsampleDCAE(
+ out_channels,
+ out_channels=upsample_out_channels,
+ add_temporal_upsample=add_temporal_upsample,
+ )
+ ]
+ )
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for resnet in self.resnets:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
+
+ else:
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class HunyuanImageRefinerEncoder3D(nn.Module):
+ r"""
+ 3D vae encoder for HunyuanImageRefiner.
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 64,
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024),
+ layers_per_block: int = 2,
+ temporal_compression_ratio: int = 4,
+ spatial_compression_ratio: int = 16,
+ downsample_match_channel: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.group_size = block_out_channels[-1] // self.out_channels
+
+ self.conv_in = HunyuanImageRefinerCausalConv3d(in_channels, block_out_channels[0], kernel_size=3)
+ self.mid_block = None
+ self.down_blocks = nn.ModuleList([])
+
+ input_channel = block_out_channels[0]
+ for i in range(len(block_out_channels)):
+ add_spatial_downsample = i < np.log2(spatial_compression_ratio)
+ output_channel = block_out_channels[i]
+ if not add_spatial_downsample:
+ down_block = HunyuanImageRefinerDownBlock3D(
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ downsample_out_channels=None,
+ add_temporal_downsample=False,
+ )
+ input_channel = output_channel
+ else:
+ add_temporal_downsample = i >= np.log2(spatial_compression_ratio // temporal_compression_ratio)
+ downsample_out_channels = block_out_channels[i + 1] if downsample_match_channel else output_channel
+ down_block = HunyuanImageRefinerDownBlock3D(
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ downsample_out_channels=downsample_out_channels,
+ add_temporal_downsample=add_temporal_downsample,
+ )
+ input_channel = downsample_out_channels
+
+ self.down_blocks.append(down_block)
+
+ self.mid_block = HunyuanImageRefinerMidBlock(in_channels=block_out_channels[-1])
+
+ self.norm_out = HunyuanImageRefinerRMS_norm(block_out_channels[-1], images=False)
+ self.conv_act = nn.SiLU()
+ self.conv_out = HunyuanImageRefinerCausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.conv_in(hidden_states)
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for down_block in self.down_blocks:
+ hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
+
+ hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
+ else:
+ for down_block in self.down_blocks:
+ hidden_states = down_block(hidden_states)
+
+ hidden_states = self.mid_block(hidden_states)
+
+ # short_cut = rearrange(hidden_states, "b (c r) f h w -> b c r f h w", r=self.group_size).mean(dim=2)
+ batch_size, _, frame, height, width = hidden_states.shape
+ short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2)
+
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+
+ hidden_states += short_cut
+
+ return hidden_states
+
+
+class HunyuanImageRefinerDecoder3D(nn.Module):
+ r"""
+ Causal decoder for 3D video-like data used for HunyuanImage-2.1 Refiner.
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 32,
+ out_channels: int = 3,
+ block_out_channels: Tuple[int, ...] = (1024, 1024, 512, 256, 128),
+ layers_per_block: int = 2,
+ spatial_compression_ratio: int = 16,
+ temporal_compression_ratio: int = 4,
+ upsample_match_channel: bool = True,
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.repeat = block_out_channels[0] // self.in_channels
+
+ self.conv_in = HunyuanImageRefinerCausalConv3d(self.in_channels, block_out_channels[0], kernel_size=3)
+ self.up_blocks = nn.ModuleList([])
+
+ # mid
+ self.mid_block = HunyuanImageRefinerMidBlock(in_channels=block_out_channels[0])
+
+ # up
+ input_channel = block_out_channels[0]
+ for i in range(len(block_out_channels)):
+ output_channel = block_out_channels[i]
+
+ add_spatial_upsample = i < np.log2(spatial_compression_ratio)
+ add_temporal_upsample = i < np.log2(temporal_compression_ratio)
+ if add_spatial_upsample or add_temporal_upsample:
+ upsample_out_channels = block_out_channels[i + 1] if upsample_match_channel else output_channel
+ up_block = HunyuanImageRefinerUpBlock3D(
+ num_layers=self.layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ upsample_out_channels=upsample_out_channels,
+ add_temporal_upsample=add_temporal_upsample,
+ )
+ input_channel = upsample_out_channels
+ else:
+ up_block = HunyuanImageRefinerUpBlock3D(
+ num_layers=self.layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ upsample_out_channels=None,
+ add_temporal_upsample=False,
+ )
+ input_channel = output_channel
+
+ self.up_blocks.append(up_block)
+
+ # out
+ self.norm_out = HunyuanImageRefinerRMS_norm(block_out_channels[-1], images=False)
+ self.conv_act = nn.SiLU()
+ self.conv_out = HunyuanImageRefinerCausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.conv_in(hidden_states) + hidden_states.repeat_interleave(repeats=self.repeat, dim=1)
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
+
+ for up_block in self.up_blocks:
+ hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
+ else:
+ hidden_states = self.mid_block(hidden_states)
+
+ for up_block in self.up_blocks:
+ hidden_states = up_block(hidden_states)
+
+ # post-process
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+ return hidden_states
+
+
+class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
+ r"""
+ A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
+ HunyuanImage-2.1 Refiner.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ latent_channels: int = 32,
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024),
+ layers_per_block: int = 2,
+ spatial_compression_ratio: int = 16,
+ temporal_compression_ratio: int = 4,
+ downsample_match_channel: bool = True,
+ upsample_match_channel: bool = True,
+ scaling_factor: float = 1.03682,
+ ) -> None:
+ super().__init__()
+
+ self.encoder = HunyuanImageRefinerEncoder3D(
+ in_channels=in_channels,
+ out_channels=latent_channels * 2,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ temporal_compression_ratio=temporal_compression_ratio,
+ spatial_compression_ratio=spatial_compression_ratio,
+ downsample_match_channel=downsample_match_channel,
+ )
+
+ self.decoder = HunyuanImageRefinerDecoder3D(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ block_out_channels=list(reversed(block_out_channels)),
+ layers_per_block=layers_per_block,
+ temporal_compression_ratio=temporal_compression_ratio,
+ spatial_compression_ratio=spatial_compression_ratio,
+ upsample_match_channel=upsample_match_channel,
+ )
+
+ self.spatial_compression_ratio = spatial_compression_ratio
+ self.temporal_compression_ratio = temporal_compression_ratio
+
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+ # to perform decoding of a single video latent at a time.
+ self.use_slicing = False
+
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+ # intermediate tiles together, the memory requirement can be lowered.
+ self.use_tiling = False
+
+ # The minimal tile height and width for spatial tiling to be used
+ self.tile_sample_min_height = 256
+ self.tile_sample_min_width = 256
+
+ # The minimal distance between two spatial tiles
+ self.tile_sample_stride_height = 192
+ self.tile_sample_stride_width = 192
+
+ self.tile_overlap_factor = 0.25
+
+ def enable_tiling(
+ self,
+ tile_sample_min_height: Optional[int] = None,
+ tile_sample_min_width: Optional[int] = None,
+ tile_sample_stride_height: Optional[float] = None,
+ tile_sample_stride_width: Optional[float] = None,
+ tile_overlap_factor: Optional[float] = None,
+ ) -> None:
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+
+ Args:
+ tile_sample_min_height (`int`, *optional*):
+ The minimum height required for a sample to be separated into tiles across the height dimension.
+ tile_sample_min_width (`int`, *optional*):
+ The minimum width required for a sample to be separated into tiles across the width dimension.
+ tile_sample_stride_height (`int`, *optional*):
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+ no tiling artifacts produced across the height dimension.
+ tile_sample_stride_width (`int`, *optional*):
+ The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+ artifacts produced across the width dimension.
+ """
+ self.use_tiling = True
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+ self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+ self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+ self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
+
+ def disable_tiling(self) -> None:
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_tiling = False
+
+ def enable_slicing(self) -> None:
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self) -> None:
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
+ _, _, _, height, width = x.shape
+
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+ return self.tiled_encode(x)
+
+ x = self.encoder(x)
+ return x
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.Tensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ r"""
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded videos. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self._encode(x)
+
+ posterior = DiagonalGaussianDistribution(h)
+
+ if not return_dict:
+ return (posterior,)
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.Tensor) -> torch.Tensor:
+ _, _, _, height, width = z.shape
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+ if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+ return self.tiled_decode(z)
+
+ dec = self.decoder(z)
+
+ return dec
+
+ @apply_forward_hook
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of images.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z)
+
+ if not return_dict:
+ return (decoded,)
+
+ return DecoderOutput(sample=decoded)
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+ y / blend_extent
+ )
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+ x / blend_extent
+ )
+ return b
+
+ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
+ x / blend_extent
+ )
+ return b
+
+ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+ r"""Encode a batch of images using a tiled encoder.
+
+ Args:
+ x (`torch.Tensor`): Input batch of videos.
+
+ Returns:
+ `torch.Tensor`:
+ The latent representation of the encoded videos.
+ """
+ _, _, _, height, width = x.shape
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ overlap_height = int(tile_latent_min_height * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
+ overlap_width = int(tile_latent_min_width * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
+ blend_height = int(tile_latent_min_height * self.tile_overlap_factor) # 8 * 0.25 = 2
+ blend_width = int(tile_latent_min_width * self.tile_overlap_factor) # 8 * 0.25 = 2
+ row_limit_height = tile_latent_min_height - blend_height # 8 - 2 = 6
+ row_limit_width = tile_latent_min_width - blend_width # 8 - 2 = 6
+
+ rows = []
+ for i in range(0, height, overlap_height):
+ row = []
+ for j in range(0, width, overlap_width):
+ tile = x[
+ :,
+ :,
+ :,
+ i : i + self.tile_sample_min_height,
+ j : j + self.tile_sample_min_width,
+ ]
+ tile = self.encoder(tile)
+ row.append(tile)
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+ result_rows.append(torch.cat(result_row, dim=-1))
+ moments = torch.cat(result_rows, dim=-2)
+
+ return moments
+
+ def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
+ r"""
+ Decode a batch of images using a tiled decoder.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+
+ _, _, _, height, width = z.shape
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ overlap_height = int(tile_latent_min_height * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
+ overlap_width = int(tile_latent_min_width * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
+ blend_height = int(tile_latent_min_height * self.tile_overlap_factor) # 256 * 0.25 = 64
+ blend_width = int(tile_latent_min_width * self.tile_overlap_factor) # 256 * 0.25 = 64
+ row_limit_height = tile_latent_min_height - blend_height # 256 - 64 = 192
+ row_limit_width = tile_latent_min_width - blend_width # 256 - 64 = 192
+
+ rows = []
+ for i in range(0, height, overlap_height):
+ row = []
+ for j in range(0, width, overlap_width):
+ tile = z[
+ :,
+ :,
+ :,
+ i : i + tile_latent_min_height,
+ j : j + tile_latent_min_width,
+ ]
+ decoded = self.decoder(tile)
+ row.append(decoded)
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+ result_rows.append(torch.cat(result_row, dim=-1))
+ dec = torch.cat(result_rows, dim=-2)
+
+ return dec
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Args:
+ sample (`torch.Tensor`): Input sample.
+ sample_posterior (`bool`, *optional*, defaults to `False`):
+ Whether to sample from the posterior.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z, return_dict=return_dict)
+ return dec
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py
new file mode 100644
index 000000000000..4b1beb74a3bc
--- /dev/null
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py
@@ -0,0 +1,967 @@
+# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import logging
+from ...utils.accelerate_utils import apply_forward_hook
+from ..activations import get_activation
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from .vae import DecoderOutput, DiagonalGaussianDistribution
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class HunyuanVideo15CausalConv3d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int, int]] = 3,
+ stride: Union[int, Tuple[int, int, int]] = 1,
+ padding: Union[int, Tuple[int, int, int]] = 0,
+ dilation: Union[int, Tuple[int, int, int]] = 1,
+ bias: bool = True,
+ pad_mode: str = "replicate",
+ ) -> None:
+ super().__init__()
+
+ kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
+
+ self.pad_mode = pad_mode
+ self.time_causal_padding = (
+ kernel_size[0] // 2,
+ kernel_size[0] // 2,
+ kernel_size[1] // 2,
+ kernel_size[1] // 2,
+ kernel_size[2] - 1,
+ 0,
+ )
+
+ self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode)
+ return self.conv(hidden_states)
+
+
+class HunyuanVideo15RMS_norm(nn.Module):
+ r"""
+ A custom RMS normalization layer.
+
+ Args:
+ dim (int): The number of dimensions to normalize over.
+ channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
+ Default is True.
+ images (bool, optional): Whether the input represents image data. Default is True.
+ bias (bool, optional): Whether to include a learnable bias term. Default is False.
+ """
+
+ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
+ super().__init__()
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
+
+ self.channel_first = channel_first
+ self.scale = dim**0.5
+ self.gamma = nn.Parameter(torch.ones(shape))
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
+
+ def forward(self, x):
+ return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
+
+
+class HunyuanVideo15AttnBlock(nn.Module):
+ def __init__(self, in_channels: int):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = HunyuanVideo15RMS_norm(in_channels, images=False)
+
+ self.to_q = nn.Conv3d(in_channels, in_channels, kernel_size=1)
+ self.to_k = nn.Conv3d(in_channels, in_channels, kernel_size=1)
+ self.to_v = nn.Conv3d(in_channels, in_channels, kernel_size=1)
+ self.proj_out = nn.Conv3d(in_channels, in_channels, kernel_size=1)
+
+ @staticmethod
+ def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
+ """Prepare a causal attention mask for 3D videos.
+
+ Args:
+ n_frame (int): Number of frames (temporal length).
+ n_hw (int): Product of height and width.
+ dtype: Desired mask dtype.
+ device: Device for the mask.
+ batch_size (int, optional): If set, expands for batch.
+
+ Returns:
+ torch.Tensor: Causal attention mask.
+ """
+ seq_len = n_frame * n_hw
+ mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
+ for i in range(seq_len):
+ i_frame = i // n_hw
+ mask[i, : (i_frame + 1) * n_hw] = 0
+ if batch_size is not None:
+ mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
+ return mask
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ identity = x
+
+ x = self.norm(x)
+
+ query = self.to_q(x)
+ key = self.to_k(x)
+ value = self.to_v(x)
+
+ batch_size, channels, frames, height, width = query.shape
+
+ query = query.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
+ key = key.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
+ value = value.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
+
+ attention_mask = self.prepare_causal_attention_mask(
+ frames, height * width, query.dtype, query.device, batch_size=batch_size
+ )
+
+ x = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
+
+ # batch_size, 1, frames * height * width, channels
+
+ x = x.squeeze(1).reshape(batch_size, frames, height, width, channels).permute(0, 4, 1, 2, 3)
+ x = self.proj_out(x)
+
+ return x + identity
+
+
+class HunyuanVideo15Upsample(nn.Module):
+ def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True):
+ super().__init__()
+ factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2
+ self.conv = HunyuanVideo15CausalConv3d(in_channels, out_channels * factor, kernel_size=3)
+
+ self.add_temporal_upsample = add_temporal_upsample
+ self.repeats = factor * out_channels // in_channels
+
+ @staticmethod
+ def _dcae_upsample_rearrange(tensor, r1=1, r2=2, r3=2):
+ """
+ Convert (b, r1*r2*r3*c, f, h, w) -> (b, c, r1*f, r2*h, r3*w)
+
+ Args:
+ tensor: Input tensor of shape (b, r1*r2*r3*c, f, h, w)
+ r1: temporal upsampling factor
+ r2: height upsampling factor
+ r3: width upsampling factor
+ """
+ b, packed_c, f, h, w = tensor.shape
+ factor = r1 * r2 * r3
+ c = packed_c // factor
+
+ tensor = tensor.view(b, r1, r2, r3, c, f, h, w)
+ tensor = tensor.permute(0, 4, 5, 1, 6, 2, 7, 3)
+ return tensor.reshape(b, c, f * r1, h * r2, w * r3)
+
+ def forward(self, x: torch.Tensor):
+ r1 = 2 if self.add_temporal_upsample else 1
+ h = self.conv(x)
+ if self.add_temporal_upsample:
+ h_first = h[:, :, :1, :, :]
+ h_first = self._dcae_upsample_rearrange(h_first, r1=1, r2=2, r3=2)
+ h_first = h_first[:, : h_first.shape[1] // 2]
+ h_next = h[:, :, 1:, :, :]
+ h_next = self._dcae_upsample_rearrange(h_next, r1=r1, r2=2, r3=2)
+ h = torch.cat([h_first, h_next], dim=2)
+
+ # shortcut computation
+ x_first = x[:, :, :1, :, :]
+ x_first = self._dcae_upsample_rearrange(x_first, r1=1, r2=2, r3=2)
+ x_first = x_first.repeat_interleave(repeats=self.repeats // 2, dim=1)
+
+ x_next = x[:, :, 1:, :, :]
+ x_next = self._dcae_upsample_rearrange(x_next, r1=r1, r2=2, r3=2)
+ x_next = x_next.repeat_interleave(repeats=self.repeats, dim=1)
+ shortcut = torch.cat([x_first, x_next], dim=2)
+
+ else:
+ h = self._dcae_upsample_rearrange(h, r1=r1, r2=2, r3=2)
+ shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
+ shortcut = self._dcae_upsample_rearrange(shortcut, r1=r1, r2=2, r3=2)
+ return h + shortcut
+
+
+class HunyuanVideo15Downsample(nn.Module):
+ def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
+ super().__init__()
+ factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2
+ self.conv = HunyuanVideo15CausalConv3d(in_channels, out_channels // factor, kernel_size=3)
+
+ self.add_temporal_downsample = add_temporal_downsample
+ self.group_size = factor * in_channels // out_channels
+
+ @staticmethod
+ def _dcae_downsample_rearrange(tensor, r1=1, r2=2, r3=2):
+ """
+ Convert (b, c, r1*f, r2*h, r3*w) -> (b, r1*r2*r3*c, f, h, w)
+
+ This packs spatial/temporal dimensions into channels (opposite of upsample)
+ """
+ b, c, packed_f, packed_h, packed_w = tensor.shape
+ f, h, w = packed_f // r1, packed_h // r2, packed_w // r3
+
+ tensor = tensor.view(b, c, f, r1, h, r2, w, r3)
+ tensor = tensor.permute(0, 3, 5, 7, 1, 2, 4, 6)
+ return tensor.reshape(b, r1 * r2 * r3 * c, f, h, w)
+
+ def forward(self, x: torch.Tensor):
+ r1 = 2 if self.add_temporal_downsample else 1
+ h = self.conv(x)
+ if self.add_temporal_downsample:
+ h_first = h[:, :, :1, :, :]
+ h_first = self._dcae_downsample_rearrange(h_first, r1=1, r2=2, r3=2)
+ h_first = torch.cat([h_first, h_first], dim=1)
+ h_next = h[:, :, 1:, :, :]
+ h_next = self._dcae_downsample_rearrange(h_next, r1=r1, r2=2, r3=2)
+ h = torch.cat([h_first, h_next], dim=2)
+
+ # shortcut computation
+ x_first = x[:, :, :1, :, :]
+ x_first = self._dcae_downsample_rearrange(x_first, r1=1, r2=2, r3=2)
+ B, C, T, H, W = x_first.shape
+ x_first = x_first.view(B, h.shape[1], self.group_size // 2, T, H, W).mean(dim=2)
+ x_next = x[:, :, 1:, :, :]
+ x_next = self._dcae_downsample_rearrange(x_next, r1=r1, r2=2, r3=2)
+ B, C, T, H, W = x_next.shape
+ x_next = x_next.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
+ shortcut = torch.cat([x_first, x_next], dim=2)
+ else:
+ h = self._dcae_downsample_rearrange(h, r1=r1, r2=2, r3=2)
+ shortcut = self._dcae_downsample_rearrange(x, r1=r1, r2=2, r3=2)
+ B, C, T, H, W = shortcut.shape
+ shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
+
+ return h + shortcut
+
+
+class HunyuanVideo15ResnetBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: Optional[int] = None,
+ non_linearity: str = "swish",
+ ) -> None:
+ super().__init__()
+ out_channels = out_channels or in_channels
+
+ self.nonlinearity = get_activation(non_linearity)
+
+ self.norm1 = HunyuanVideo15RMS_norm(in_channels, images=False)
+ self.conv1 = HunyuanVideo15CausalConv3d(in_channels, out_channels, kernel_size=3)
+
+ self.norm2 = HunyuanVideo15RMS_norm(out_channels, images=False)
+ self.conv2 = HunyuanVideo15CausalConv3d(out_channels, out_channels, kernel_size=3)
+
+ self.conv_shortcut = None
+ if in_channels != out_channels:
+ self.conv_shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ residual = hidden_states
+
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ residual = self.conv_shortcut(residual)
+
+ return hidden_states + residual
+
+
+class HunyuanVideo15MidBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ num_layers: int = 1,
+ add_attention: bool = True,
+ ) -> None:
+ super().__init__()
+ self.add_attention = add_attention
+
+ # There is always at least one resnet
+ resnets = [
+ HunyuanVideo15ResnetBlock(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ if self.add_attention:
+ attentions.append(HunyuanVideo15AttnBlock(in_channels))
+ else:
+ attentions.append(None)
+
+ resnets.append(
+ HunyuanVideo15ResnetBlock(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.resnets[0](hidden_states)
+
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if attn is not None:
+ hidden_states = attn(hidden_states)
+ hidden_states = resnet(hidden_states)
+
+ return hidden_states
+
+
+class HunyuanVideo15DownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ downsample_out_channels: Optional[int] = None,
+ add_temporal_downsample: int = True,
+ ) -> None:
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ HunyuanVideo15ResnetBlock(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if downsample_out_channels is not None:
+ self.downsamplers = nn.ModuleList(
+ [
+ HunyuanVideo15Downsample(
+ out_channels,
+ out_channels=downsample_out_channels,
+ add_temporal_downsample=add_temporal_downsample,
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class HunyuanVideo15UpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ upsample_out_channels: Optional[int] = None,
+ add_temporal_upsample: bool = True,
+ ) -> None:
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ resnets.append(
+ HunyuanVideo15ResnetBlock(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if upsample_out_channels is not None:
+ self.upsamplers = nn.ModuleList(
+ [
+ HunyuanVideo15Upsample(
+ out_channels,
+ out_channels=upsample_out_channels,
+ add_temporal_upsample=add_temporal_upsample,
+ )
+ ]
+ )
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for resnet in self.resnets:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
+
+ else:
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class HunyuanVideo15Encoder3D(nn.Module):
+ r"""
+ 3D vae encoder for HunyuanImageRefiner.
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 64,
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024),
+ layers_per_block: int = 2,
+ temporal_compression_ratio: int = 4,
+ spatial_compression_ratio: int = 16,
+ downsample_match_channel: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.group_size = block_out_channels[-1] // self.out_channels
+
+ self.conv_in = HunyuanVideo15CausalConv3d(in_channels, block_out_channels[0], kernel_size=3)
+ self.mid_block = None
+ self.down_blocks = nn.ModuleList([])
+
+ input_channel = block_out_channels[0]
+ for i in range(len(block_out_channels)):
+ add_spatial_downsample = i < np.log2(spatial_compression_ratio)
+ output_channel = block_out_channels[i]
+ if not add_spatial_downsample:
+ down_block = HunyuanVideo15DownBlock3D(
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ downsample_out_channels=None,
+ add_temporal_downsample=False,
+ )
+ input_channel = output_channel
+ else:
+ add_temporal_downsample = i >= np.log2(spatial_compression_ratio // temporal_compression_ratio)
+ downsample_out_channels = block_out_channels[i + 1] if downsample_match_channel else output_channel
+ down_block = HunyuanVideo15DownBlock3D(
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ downsample_out_channels=downsample_out_channels,
+ add_temporal_downsample=add_temporal_downsample,
+ )
+ input_channel = downsample_out_channels
+
+ self.down_blocks.append(down_block)
+
+ self.mid_block = HunyuanVideo15MidBlock(in_channels=block_out_channels[-1])
+
+ self.norm_out = HunyuanVideo15RMS_norm(block_out_channels[-1], images=False)
+ self.conv_act = nn.SiLU()
+ self.conv_out = HunyuanVideo15CausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.conv_in(hidden_states)
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for down_block in self.down_blocks:
+ hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
+
+ hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
+ else:
+ for down_block in self.down_blocks:
+ hidden_states = down_block(hidden_states)
+
+ hidden_states = self.mid_block(hidden_states)
+
+ batch_size, _, frame, height, width = hidden_states.shape
+ short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2)
+
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+
+ hidden_states += short_cut
+
+ return hidden_states
+
+
+class HunyuanVideo15Decoder3D(nn.Module):
+ r"""
+ Causal decoder for 3D video-like data used for HunyuanImage-1.5 Refiner.
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 32,
+ out_channels: int = 3,
+ block_out_channels: Tuple[int, ...] = (1024, 1024, 512, 256, 128),
+ layers_per_block: int = 2,
+ spatial_compression_ratio: int = 16,
+ temporal_compression_ratio: int = 4,
+ upsample_match_channel: bool = True,
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.repeat = block_out_channels[0] // self.in_channels
+
+ self.conv_in = HunyuanVideo15CausalConv3d(self.in_channels, block_out_channels[0], kernel_size=3)
+ self.up_blocks = nn.ModuleList([])
+
+ # mid
+ self.mid_block = HunyuanVideo15MidBlock(in_channels=block_out_channels[0])
+
+ # up
+ input_channel = block_out_channels[0]
+ for i in range(len(block_out_channels)):
+ output_channel = block_out_channels[i]
+
+ add_spatial_upsample = i < np.log2(spatial_compression_ratio)
+ add_temporal_upsample = i < np.log2(temporal_compression_ratio)
+ if add_spatial_upsample or add_temporal_upsample:
+ upsample_out_channels = block_out_channels[i + 1] if upsample_match_channel else output_channel
+ up_block = HunyuanVideo15UpBlock3D(
+ num_layers=self.layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ upsample_out_channels=upsample_out_channels,
+ add_temporal_upsample=add_temporal_upsample,
+ )
+ input_channel = upsample_out_channels
+ else:
+ up_block = HunyuanVideo15UpBlock3D(
+ num_layers=self.layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ upsample_out_channels=None,
+ add_temporal_upsample=False,
+ )
+ input_channel = output_channel
+
+ self.up_blocks.append(up_block)
+
+ # out
+ self.norm_out = HunyuanVideo15RMS_norm(block_out_channels[-1], images=False)
+ self.conv_act = nn.SiLU()
+ self.conv_out = HunyuanVideo15CausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.conv_in(hidden_states) + hidden_states.repeat_interleave(repeats=self.repeat, dim=1)
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
+
+ for up_block in self.up_blocks:
+ hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
+ else:
+ hidden_states = self.mid_block(hidden_states)
+
+ for up_block in self.up_blocks:
+ hidden_states = up_block(hidden_states)
+
+ # post-process
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+ return hidden_states
+
+
+class AutoencoderKLHunyuanVideo15(ModelMixin, ConfigMixin):
+ r"""
+ A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
+ HunyuanVideo-1.5.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ latent_channels: int = 32,
+ block_out_channels: Tuple[int] = (128, 256, 512, 1024, 1024),
+ layers_per_block: int = 2,
+ spatial_compression_ratio: int = 16,
+ temporal_compression_ratio: int = 4,
+ downsample_match_channel: bool = True,
+ upsample_match_channel: bool = True,
+ scaling_factor: float = 1.03682,
+ ) -> None:
+ super().__init__()
+
+ self.encoder = HunyuanVideo15Encoder3D(
+ in_channels=in_channels,
+ out_channels=latent_channels * 2,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ temporal_compression_ratio=temporal_compression_ratio,
+ spatial_compression_ratio=spatial_compression_ratio,
+ downsample_match_channel=downsample_match_channel,
+ )
+
+ self.decoder = HunyuanVideo15Decoder3D(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ block_out_channels=list(reversed(block_out_channels)),
+ layers_per_block=layers_per_block,
+ temporal_compression_ratio=temporal_compression_ratio,
+ spatial_compression_ratio=spatial_compression_ratio,
+ upsample_match_channel=upsample_match_channel,
+ )
+
+ self.spatial_compression_ratio = spatial_compression_ratio
+ self.temporal_compression_ratio = temporal_compression_ratio
+
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+ # to perform decoding of a single video latent at a time.
+ self.use_slicing = False
+
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+ # intermediate tiles together, the memory requirement can be lowered.
+ self.use_tiling = False
+
+ # The minimal tile height and width for spatial tiling to be used
+ self.tile_sample_min_height = 256
+ self.tile_sample_min_width = 256
+
+ # The minimal tile height and width in latent space
+ self.tile_latent_min_height = self.tile_sample_min_height // spatial_compression_ratio
+ self.tile_latent_min_width = self.tile_sample_min_width // spatial_compression_ratio
+ self.tile_overlap_factor = 0.25
+
+ def enable_tiling(
+ self,
+ tile_sample_min_height: Optional[int] = None,
+ tile_sample_min_width: Optional[int] = None,
+ tile_latent_min_height: Optional[int] = None,
+ tile_latent_min_width: Optional[int] = None,
+ tile_overlap_factor: Optional[float] = None,
+ ) -> None:
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+
+ Args:
+ tile_sample_min_height (`int`, *optional*):
+ The minimum height required for a sample to be separated into tiles across the height dimension.
+ tile_sample_min_width (`int`, *optional*):
+ The minimum width required for a sample to be separated into tiles across the width dimension.
+ tile_latent_min_height (`int`, *optional*):
+ The minimum height required for a latent to be separated into tiles across the height dimension.
+ tile_latent_min_width (`int`, *optional*):
+ The minimum width required for a latent to be separated into tiles across the width dimension.
+ """
+ self.use_tiling = True
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+ self.tile_latent_min_height = tile_latent_min_height or self.tile_latent_min_height
+ self.tile_latent_min_width = tile_latent_min_width or self.tile_latent_min_width
+ self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
+
+ def disable_tiling(self) -> None:
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_tiling = False
+
+ def enable_slicing(self) -> None:
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self) -> None:
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
+ _, _, _, height, width = x.shape
+
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+ return self.tiled_encode(x)
+
+ x = self.encoder(x)
+ return x
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.Tensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ r"""
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded videos. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self._encode(x)
+
+ posterior = DiagonalGaussianDistribution(h)
+
+ if not return_dict:
+ return (posterior,)
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.Tensor) -> torch.Tensor:
+ _, _, _, height, width = z.shape
+
+ if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
+ return self.tiled_decode(z)
+
+ dec = self.decoder(z)
+
+ return dec
+
+ @apply_forward_hook
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of images.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z)
+
+ if not return_dict:
+ return (decoded,)
+
+ return DecoderOutput(sample=decoded)
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+ y / blend_extent
+ )
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+ x / blend_extent
+ )
+ return b
+
+ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
+ x / blend_extent
+ )
+ return b
+
+ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+ r"""Encode a batch of images using a tiled encoder.
+
+ Args:
+ x (`torch.Tensor`): Input batch of videos.
+
+ Returns:
+ `torch.Tensor`:
+ The latent representation of the encoded videos.
+ """
+ _, _, _, height, width = x.shape
+
+ overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
+ overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
+ blend_height = int(self.tile_latent_min_height * self.tile_overlap_factor) # 8 * 0.25 = 2
+ blend_width = int(self.tile_latent_min_width * self.tile_overlap_factor) # 8 * 0.25 = 2
+ row_limit_height = self.tile_latent_min_height - blend_height # 8 - 2 = 6
+ row_limit_width = self.tile_latent_min_width - blend_width # 8 - 2 = 6
+
+ rows = []
+ for i in range(0, height, overlap_height):
+ row = []
+ for j in range(0, width, overlap_width):
+ tile = x[
+ :,
+ :,
+ :,
+ i : i + self.tile_sample_min_height,
+ j : j + self.tile_sample_min_width,
+ ]
+ tile = self.encoder(tile)
+ row.append(tile)
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+ result_rows.append(torch.cat(result_row, dim=-1))
+ moments = torch.cat(result_rows, dim=-2)
+
+ return moments
+
+ def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
+ r"""
+ Decode a batch of images using a tiled decoder.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+
+ _, _, _, height, width = z.shape
+
+ overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
+ overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
+ blend_height = int(self.tile_sample_min_height * self.tile_overlap_factor) # 256 * 0.25 = 64
+ blend_width = int(self.tile_sample_min_width * self.tile_overlap_factor) # 256 * 0.25 = 64
+ row_limit_height = self.tile_sample_min_height - blend_height # 256 - 64 = 192
+ row_limit_width = self.tile_sample_min_width - blend_width # 256 - 64 = 192
+
+ rows = []
+ for i in range(0, height, overlap_height):
+ row = []
+ for j in range(0, width, overlap_width):
+ tile = z[
+ :,
+ :,
+ :,
+ i : i + self.tile_latent_min_height,
+ j : j + self.tile_latent_min_width,
+ ]
+ decoded = self.decoder(tile)
+ row.append(decoded)
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+ result_rows.append(torch.cat(result_row, dim=-1))
+ dec = torch.cat(result_rows, dim=-2)
+
+ return dec
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Args:
+ sample (`torch.Tensor`): Input sample.
+ sample_posterior (`bool`, *optional*, defaults to `False`):
+ Whether to sample from the posterior.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z, return_dict=return_dict)
+ return dec
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
index 2b2f77a5509d..47f2081b7e45 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The Lightricks team and The HuggingFace Team.
+# Copyright 2025 The Lightricks team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -26,7 +26,7 @@
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm
-from .vae import DecoderOutput, DiagonalGaussianDistribution
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
class LTXVideoCausalConv3d(nn.Module):
@@ -1034,7 +1034,7 @@ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = No
return hidden_states
-class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[LTX](https://huggingface.co/Lightricks/LTX-Video).
@@ -1067,7 +1067,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
- Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper.
encoder_causal (`bool`, defaults to `True`):
Whether the encoder should behave causally (future frames depend only on past frames) or not.
decoder_causal (`bool`, defaults to `False`):
@@ -1219,27 +1219,6 @@ def enable_tiling(
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
- def disable_tiling(self) -> None:
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_tiling = False
-
- def enable_slicing(self) -> None:
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self) -> None:
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
@@ -1285,7 +1264,7 @@ def _decode(
) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
- tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_magvit.py b/src/diffusers/models/autoencoders/autoencoder_kl_magvit.py
index 7b53192033dc..97ca9d669264 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_magvit.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_magvit.py
@@ -26,7 +26,7 @@
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
-from .vae import DecoderOutput, DiagonalGaussianDistribution
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -428,7 +428,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
class EasyAnimateEncoder(nn.Module):
r"""
- Causal encoder for 3D video-like data used in [EasyAnimate](https://arxiv.org/abs/2405.18991).
+ Causal encoder for 3D video-like data used in [EasyAnimate](https://huggingface.co/papers/2405.18991).
"""
_supports_gradient_checkpointing = True
@@ -544,7 +544,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
class EasyAnimateDecoder(nn.Module):
r"""
- Causal decoder for 3D video-like data used in [EasyAnimate](https://arxiv.org/abs/2405.18991).
+ Causal decoder for 3D video-like data used in [EasyAnimate](https://huggingface.co/papers/2405.18991).
"""
_supports_gradient_checkpointing = True
@@ -663,10 +663,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states
-class AutoencoderKLMagvit(ModelMixin, ConfigMixin):
+class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. This
- model is used in [EasyAnimate](https://arxiv.org/abs/2405.18991).
+ model is used in [EasyAnimate](https://huggingface.co/papers/2405.18991).
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
@@ -805,27 +805,6 @@ def enable_tiling(
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
- def disable_tiling(self) -> None:
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_tiling = False
-
- def enable_slicing(self) -> None:
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self) -> None:
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
@apply_forward_hook
def _encode(
self, x: torch.Tensor, return_dict: bool = True
@@ -887,7 +866,7 @@ def encode(
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
- tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
if self.use_tiling and (z.shape[-1] > tile_latent_min_height or z.shape[-2] > tile_latent_min_width):
return self.tiled_decode(z, return_dict=return_dict)
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
index d69ec6252b00..7a64ac7de172 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The Mochi team and The HuggingFace Team.
+# Copyright 2025 The Mochi team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -27,7 +27,7 @@
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .autoencoder_kl_cogvideox import CogVideoXCausalConv3d
-from .vae import DecoderOutput, DiagonalGaussianDistribution
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -210,7 +210,7 @@ def forward(
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
resnet,
hidden_states,
- conv_cache=conv_cache.get(conv_cache_key),
+ conv_cache.get(conv_cache_key),
)
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -306,7 +306,7 @@ def forward(
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
- resnet, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
+ resnet, hidden_states, conv_cache.get(conv_cache_key)
)
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -382,7 +382,7 @@ def forward(
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
resnet,
hidden_states,
- conv_cache=conv_cache.get(conv_cache_key),
+ conv_cache.get(conv_cache_key),
)
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -497,6 +497,8 @@ def __init__(
self.norm_out = MochiChunkedGroupNorm3D(block_out_channels[-1])
self.proj_out = nn.Linear(block_out_channels[-1], 2 * out_channels, bias=False)
+ self.gradient_checkpointing = False
+
def forward(
self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
) -> torch.Tensor:
@@ -513,13 +515,13 @@ def forward(
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func(
- self.block_in, hidden_states, conv_cache=conv_cache.get("block_in")
+ self.block_in, hidden_states, conv_cache.get("block_in")
)
for i, down_block in enumerate(self.down_blocks):
conv_cache_key = f"down_block_{i}"
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
- down_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
+ down_block, hidden_states, conv_cache.get(conv_cache_key)
)
else:
hidden_states, new_conv_cache["block_in"] = self.block_in(
@@ -623,13 +625,13 @@ def forward(
# 1. Mid
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func(
- self.block_in, hidden_states, conv_cache=conv_cache.get("block_in")
+ self.block_in, hidden_states, conv_cache.get("block_in")
)
for i, up_block in enumerate(self.up_blocks):
conv_cache_key = f"up_block_{i}"
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
- up_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
+ up_block, hidden_states, conv_cache.get(conv_cache_key)
)
else:
hidden_states, new_conv_cache["block_in"] = self.block_in(
@@ -655,7 +657,7 @@ def forward(
return hidden_states, new_conv_cache
-class AutoencoderKLMochi(ModelMixin, ConfigMixin):
+class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[Mochi 1 preview](https://github.com/genmoai/models).
@@ -675,7 +677,7 @@ class AutoencoderKLMochi(ModelMixin, ConfigMixin):
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
- Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper.
"""
_supports_gradient_checkpointing = True
@@ -686,8 +688,8 @@ def __init__(
self,
in_channels: int = 15,
out_channels: int = 3,
- encoder_block_out_channels: Tuple[int] = (64, 128, 256, 384),
- decoder_block_out_channels: Tuple[int] = (128, 256, 512, 768),
+ encoder_block_out_channels: Tuple[int, ...] = (64, 128, 256, 384),
+ decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
latent_channels: int = 12,
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
act_fn: str = "silu",
@@ -816,27 +818,6 @@ def enable_tiling(
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
- def disable_tiling(self) -> None:
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_tiling = False
-
- def enable_slicing(self) -> None:
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self) -> None:
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
def _enable_framewise_encoding(self):
r"""
Enables the framewise VAE encoding implementation with past latent padding. By default, Diffusers uses the
@@ -907,7 +888,7 @@ def encode(
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
- tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z, return_dict=return_dict)
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
new file mode 100644
index 000000000000..618801dfb605
--- /dev/null
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
@@ -0,0 +1,1048 @@
+# Copyright 2025 The Qwen-Image Team, Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# We gratefully acknowledge the Wan Team for their outstanding contributions.
+# QwenImageVAE is further fine-tuned from the Wan Video VAE to achieve improved performance.
+# For more information about the Wan VAE, please refer to:
+# - GitHub: https://github.com/Wan-Video/Wan2.1
+# - Paper: https://huggingface.co/papers/2503.20314
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
+from ...utils import logging
+from ...utils.accelerate_utils import apply_forward_hook
+from ..activations import get_activation
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+CACHE_T = 2
+
+
+class QwenImageCausalConv3d(nn.Conv3d):
+ r"""
+ A custom 3D causal convolution layer with feature caching support.
+
+ This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
+ caching for efficient inference.
+
+ Args:
+ in_channels (int): Number of channels in the input image
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the convolving kernel
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int, int]],
+ stride: Union[int, Tuple[int, int, int]] = 1,
+ padding: Union[int, Tuple[int, int, int]] = 0,
+ ) -> None:
+ super().__init__(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ )
+
+ # Set up causal padding
+ self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
+ self.padding = (0, 0, 0)
+
+ def forward(self, x, cache_x=None):
+ padding = list(self._padding)
+ if cache_x is not None and self._padding[4] > 0:
+ cache_x = cache_x.to(x.device)
+ x = torch.cat([cache_x, x], dim=2)
+ padding[4] -= cache_x.shape[2]
+ x = F.pad(x, padding)
+ return super().forward(x)
+
+
+class QwenImageRMS_norm(nn.Module):
+ r"""
+ A custom RMS normalization layer.
+
+ Args:
+ dim (int): The number of dimensions to normalize over.
+ channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
+ Default is True.
+ images (bool, optional): Whether the input represents image data. Default is True.
+ bias (bool, optional): Whether to include a learnable bias term. Default is False.
+ """
+
+ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
+ super().__init__()
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
+
+ self.channel_first = channel_first
+ self.scale = dim**0.5
+ self.gamma = nn.Parameter(torch.ones(shape))
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
+
+ def forward(self, x):
+ return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
+
+
+class QwenImageUpsample(nn.Upsample):
+ r"""
+ Perform upsampling while ensuring the output tensor has the same data type as the input.
+
+ Args:
+ x (torch.Tensor): Input tensor to be upsampled.
+
+ Returns:
+ torch.Tensor: Upsampled tensor with the same data type as the input.
+ """
+
+ def forward(self, x):
+ return super().forward(x.float()).type_as(x)
+
+
+class QwenImageResample(nn.Module):
+ r"""
+ A custom resampling module for 2D and 3D data.
+
+ Args:
+ dim (int): The number of input/output channels.
+ mode (str): The resampling mode. Must be one of:
+ - 'none': No resampling (identity operation).
+ - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
+ - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
+ - 'downsample2d': 2D downsampling with zero-padding and convolution.
+ - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
+ """
+
+ def __init__(self, dim: int, mode: str) -> None:
+ super().__init__()
+ self.dim = dim
+ self.mode = mode
+
+ # layers
+ if mode == "upsample2d":
+ self.resample = nn.Sequential(
+ QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+ nn.Conv2d(dim, dim // 2, 3, padding=1),
+ )
+ elif mode == "upsample3d":
+ self.resample = nn.Sequential(
+ QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+ nn.Conv2d(dim, dim // 2, 3, padding=1),
+ )
+ self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
+
+ elif mode == "downsample2d":
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+ elif mode == "downsample3d":
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+ self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
+
+ else:
+ self.resample = nn.Identity()
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ b, c, t, h, w = x.size()
+ if self.mode == "upsample3d":
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = "Rep"
+ feat_idx[0] += 1
+ else:
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
+ # cache last frame of last two chunk
+ cache_x = torch.cat(
+ [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
+ )
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
+ cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
+ if feat_cache[idx] == "Rep":
+ x = self.time_conv(x)
+ else:
+ x = self.time_conv(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+
+ x = x.reshape(b, 2, c, t, h, w)
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
+ x = x.reshape(b, c, t * 2, h, w)
+ t = x.shape[2]
+ x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+ x = self.resample(x)
+ x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
+
+ if self.mode == "downsample3d":
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = x.clone()
+ feat_idx[0] += 1
+ else:
+ cache_x = x[:, :, -1:, :, :].clone()
+ x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ return x
+
+
+class QwenImageResidualBlock(nn.Module):
+ r"""
+ A custom residual block module.
+
+ Args:
+ in_dim (int): Number of input channels.
+ out_dim (int): Number of output channels.
+ dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
+ non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ dropout: float = 0.0,
+ non_linearity: str = "silu",
+ ) -> None:
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+ self.nonlinearity = get_activation(non_linearity)
+
+ # layers
+ self.norm1 = QwenImageRMS_norm(in_dim, images=False)
+ self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1)
+ self.norm2 = QwenImageRMS_norm(out_dim, images=False)
+ self.dropout = nn.Dropout(dropout)
+ self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1)
+ self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ # Apply shortcut connection
+ h = self.conv_shortcut(x)
+
+ # First normalization and activation
+ x = self.norm1(x)
+ x = self.nonlinearity(x)
+
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+ x = self.conv1(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv1(x)
+
+ # Second normalization and activation
+ x = self.norm2(x)
+ x = self.nonlinearity(x)
+
+ # Dropout
+ x = self.dropout(x)
+
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+ x = self.conv2(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv2(x)
+
+ # Add residual connection
+ return x + h
+
+
+class QwenImageAttentionBlock(nn.Module):
+ r"""
+ Causal self-attention with a single head.
+
+ Args:
+ dim (int): The number of channels in the input tensor.
+ """
+
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ # layers
+ self.norm = QwenImageRMS_norm(dim)
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
+ self.proj = nn.Conv2d(dim, dim, 1)
+
+ def forward(self, x):
+ identity = x
+ batch_size, channels, time, height, width = x.size()
+
+ x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
+ x = self.norm(x)
+
+ # compute query, key, value
+ qkv = self.to_qkv(x)
+ qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
+ qkv = qkv.permute(0, 1, 3, 2).contiguous()
+ q, k, v = qkv.chunk(3, dim=-1)
+
+ # apply attention
+ x = F.scaled_dot_product_attention(q, k, v)
+
+ x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
+
+ # output projection
+ x = self.proj(x)
+
+ # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
+ x = x.view(batch_size, time, channels, height, width)
+ x = x.permute(0, 2, 1, 3, 4)
+
+ return x + identity
+
+
+class QwenImageMidBlock(nn.Module):
+ """
+ Middle block for QwenImageVAE encoder and decoder.
+
+ Args:
+ dim (int): Number of input/output channels.
+ dropout (float): Dropout rate.
+ non_linearity (str): Type of non-linearity to use.
+ """
+
+ def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
+ super().__init__()
+ self.dim = dim
+
+ # Create the components
+ resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)]
+ attentions = []
+ for _ in range(num_layers):
+ attentions.append(QwenImageAttentionBlock(dim))
+ resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity))
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ # First residual block
+ x = self.resnets[0](x, feat_cache, feat_idx)
+
+ # Process through attention and residual blocks
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if attn is not None:
+ x = attn(x)
+
+ x = resnet(x, feat_cache, feat_idx)
+
+ return x
+
+
+class QwenImageEncoder3d(nn.Module):
+ r"""
+ A 3D encoder module.
+
+ Args:
+ dim (int): The base number of channels in the first layer.
+ z_dim (int): The dimensionality of the latent space.
+ dim_mult (list of int): Multipliers for the number of channels in each block.
+ num_res_blocks (int): Number of residual blocks in each block.
+ attn_scales (list of float): Scales at which to apply attention mechanisms.
+ temperal_downsample (list of bool): Whether to downsample temporally in each block.
+ dropout (float): Dropout rate for the dropout layers.
+ non_linearity (str): Type of non-linearity to use.
+ """
+
+ def __init__(
+ self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_downsample=[True, True, False],
+ dropout=0.0,
+ non_linearity: str = "silu",
+ ):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_downsample = temperal_downsample
+ self.nonlinearity = get_activation(non_linearity)
+
+ # dimensions
+ dims = [dim * u for u in [1] + dim_mult]
+ scale = 1.0
+
+ # init block
+ self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1)
+
+ # downsample blocks
+ self.down_blocks = nn.ModuleList([])
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ # residual (+attention) blocks
+ for _ in range(num_res_blocks):
+ self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout))
+ if scale in attn_scales:
+ self.down_blocks.append(QwenImageAttentionBlock(out_dim))
+ in_dim = out_dim
+
+ # downsample block
+ if i != len(dim_mult) - 1:
+ mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
+ self.down_blocks.append(QwenImageResample(out_dim, mode=mode))
+ scale /= 2.0
+
+ # middle blocks
+ self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1)
+
+ # output blocks
+ self.norm_out = QwenImageRMS_norm(out_dim, images=False)
+ self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ x = self.conv_in(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv_in(x)
+
+ ## downsamples
+ for layer in self.down_blocks:
+ if feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## middle
+ x = self.mid_block(x, feat_cache, feat_idx)
+
+ ## head
+ x = self.norm_out(x)
+ x = self.nonlinearity(x)
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ x = self.conv_out(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv_out(x)
+ return x
+
+
+class QwenImageUpBlock(nn.Module):
+ """
+ A block that handles upsampling for the QwenImageVAE decoder.
+
+ Args:
+ in_dim (int): Input dimension
+ out_dim (int): Output dimension
+ num_res_blocks (int): Number of residual blocks
+ dropout (float): Dropout rate
+ upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
+ non_linearity (str): Type of non-linearity to use
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ num_res_blocks: int,
+ dropout: float = 0.0,
+ upsample_mode: Optional[str] = None,
+ non_linearity: str = "silu",
+ ):
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+
+ # Create layers list
+ resnets = []
+ # Add residual blocks and attention if needed
+ current_dim = in_dim
+ for _ in range(num_res_blocks + 1):
+ resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity))
+ current_dim = out_dim
+
+ self.resnets = nn.ModuleList(resnets)
+
+ # Add upsampling layer if needed
+ self.upsamplers = None
+ if upsample_mode is not None:
+ self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)])
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ """
+ Forward pass through the upsampling block.
+
+ Args:
+ x (torch.Tensor): Input tensor
+ feat_cache (list, optional): Feature cache for causal convolutions
+ feat_idx (list, optional): Feature index for cache management
+
+ Returns:
+ torch.Tensor: Output tensor
+ """
+ for resnet in self.resnets:
+ if feat_cache is not None:
+ x = resnet(x, feat_cache, feat_idx)
+ else:
+ x = resnet(x)
+
+ if self.upsamplers is not None:
+ if feat_cache is not None:
+ x = self.upsamplers[0](x, feat_cache, feat_idx)
+ else:
+ x = self.upsamplers[0](x)
+ return x
+
+
+class QwenImageDecoder3d(nn.Module):
+ r"""
+ A 3D decoder module.
+
+ Args:
+ dim (int): The base number of channels in the first layer.
+ z_dim (int): The dimensionality of the latent space.
+ dim_mult (list of int): Multipliers for the number of channels in each block.
+ num_res_blocks (int): Number of residual blocks in each block.
+ attn_scales (list of float): Scales at which to apply attention mechanisms.
+ temperal_upsample (list of bool): Whether to upsample temporally in each block.
+ dropout (float): Dropout rate for the dropout layers.
+ non_linearity (str): Type of non-linearity to use.
+ """
+
+ def __init__(
+ self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_upsample=[False, True, True],
+ dropout=0.0,
+ non_linearity: str = "silu",
+ ):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_upsample = temperal_upsample
+
+ self.nonlinearity = get_activation(non_linearity)
+
+ # dimensions
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
+ scale = 1.0 / 2 ** (len(dim_mult) - 2)
+
+ # init block
+ self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1)
+
+ # middle blocks
+ self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1)
+
+ # upsample blocks
+ self.up_blocks = nn.ModuleList([])
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ # residual (+attention) blocks
+ if i > 0:
+ in_dim = in_dim // 2
+
+ # Determine if we need upsampling
+ upsample_mode = None
+ if i != len(dim_mult) - 1:
+ upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
+
+ # Create and add the upsampling block
+ up_block = QwenImageUpBlock(
+ in_dim=in_dim,
+ out_dim=out_dim,
+ num_res_blocks=num_res_blocks,
+ dropout=dropout,
+ upsample_mode=upsample_mode,
+ non_linearity=non_linearity,
+ )
+ self.up_blocks.append(up_block)
+
+ # Update scale for next iteration
+ if upsample_mode is not None:
+ scale *= 2.0
+
+ # output blocks
+ self.norm_out = QwenImageRMS_norm(out_dim, images=False)
+ self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ ## conv1
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ x = self.conv_in(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv_in(x)
+
+ ## middle
+ x = self.mid_block(x, feat_cache, feat_idx)
+
+ ## upsamples
+ for up_block in self.up_blocks:
+ x = up_block(x, feat_cache, feat_idx)
+
+ ## head
+ x = self.norm_out(x)
+ x = self.nonlinearity(x)
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ x = self.conv_out(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv_out(x)
+ return x
+
+
+class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
+ r"""
+ A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+ """
+
+ _supports_gradient_checkpointing = False
+
+ # fmt: off
+ @register_to_config
+ def __init__(
+ self,
+ base_dim: int = 96,
+ z_dim: int = 16,
+ dim_mult: Tuple[int, ...] = (1, 2, 4, 4),
+ num_res_blocks: int = 2,
+ attn_scales: List[float] = [],
+ temperal_downsample: List[bool] = [False, True, True],
+ dropout: float = 0.0,
+ latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
+ latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
+ ) -> None:
+ # fmt: on
+ super().__init__()
+
+ self.z_dim = z_dim
+ self.temperal_downsample = temperal_downsample
+ self.temperal_upsample = temperal_downsample[::-1]
+
+ self.encoder = QwenImageEncoder3d(
+ base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
+ )
+ self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
+ self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)
+
+ self.decoder = QwenImageDecoder3d(
+ base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
+ )
+
+ self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
+
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+ # to perform decoding of a single video latent at a time.
+ self.use_slicing = False
+
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+ # intermediate tiles together, the memory requirement can be lowered.
+ self.use_tiling = False
+
+ # The minimal tile height and width for spatial tiling to be used
+ self.tile_sample_min_height = 256
+ self.tile_sample_min_width = 256
+
+ # The minimal distance between two spatial tiles
+ self.tile_sample_stride_height = 192
+ self.tile_sample_stride_width = 192
+
+ # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
+ self._cached_conv_counts = {
+ "decoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.modules())
+ if self.decoder is not None
+ else 0,
+ "encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules())
+ if self.encoder is not None
+ else 0,
+ }
+
+ def enable_tiling(
+ self,
+ tile_sample_min_height: Optional[int] = None,
+ tile_sample_min_width: Optional[int] = None,
+ tile_sample_stride_height: Optional[float] = None,
+ tile_sample_stride_width: Optional[float] = None,
+ ) -> None:
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+
+ Args:
+ tile_sample_min_height (`int`, *optional*):
+ The minimum height required for a sample to be separated into tiles across the height dimension.
+ tile_sample_min_width (`int`, *optional*):
+ The minimum width required for a sample to be separated into tiles across the width dimension.
+ tile_sample_stride_height (`int`, *optional*):
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+ no tiling artifacts produced across the height dimension.
+ tile_sample_stride_width (`int`, *optional*):
+ The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+ artifacts produced across the width dimension.
+ """
+ self.use_tiling = True
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+ self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+ self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+
+ def clear_cache(self):
+ def _count_conv3d(model):
+ count = 0
+ for m in model.modules():
+ if isinstance(m, QwenImageCausalConv3d):
+ count += 1
+ return count
+
+ self._conv_num = _count_conv3d(self.decoder)
+ self._conv_idx = [0]
+ self._feat_map = [None] * self._conv_num
+ # cache encode
+ self._enc_conv_num = _count_conv3d(self.encoder)
+ self._enc_conv_idx = [0]
+ self._enc_feat_map = [None] * self._enc_conv_num
+
+ def _encode(self, x: torch.Tensor):
+ _, _, num_frame, height, width = x.shape
+
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+ return self.tiled_encode(x)
+
+ self.clear_cache()
+ iter_ = 1 + (num_frame - 1) // 4
+ for i in range(iter_):
+ self._enc_conv_idx = [0]
+ if i == 0:
+ out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+ else:
+ out_ = self.encoder(
+ x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
+ feat_cache=self._enc_feat_map,
+ feat_idx=self._enc_conv_idx,
+ )
+ out = torch.cat([out, out_], 2)
+
+ enc = self.quant_conv(out)
+ self.clear_cache()
+ return enc
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.Tensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ r"""
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded videos. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self._encode(x)
+ posterior = DiagonalGaussianDistribution(h)
+
+ if not return_dict:
+ return (posterior,)
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.Tensor, return_dict: bool = True):
+ _, _, num_frame, height, width = z.shape
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+ if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+ return self.tiled_decode(z, return_dict=return_dict)
+
+ self.clear_cache()
+ x = self.post_quant_conv(z)
+ for i in range(num_frame):
+ self._conv_idx = [0]
+ if i == 0:
+ out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+ else:
+ out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+ out = torch.cat([out, out_], 2)
+
+ out = torch.clamp(out, min=-1.0, max=1.0)
+ self.clear_cache()
+ if not return_dict:
+ return (out,)
+
+ return DecoderOutput(sample=out)
+
+ @apply_forward_hook
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of images.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z).sample
+
+ if not return_dict:
+ return (decoded,)
+ return DecoderOutput(sample=decoded)
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+ y / blend_extent
+ )
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+ x / blend_extent
+ )
+ return b
+
+ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
+ r"""Encode a batch of images using a tiled encoder.
+
+ Args:
+ x (`torch.Tensor`): Input batch of videos.
+
+ Returns:
+ `torch.Tensor`:
+ The latent representation of the encoded videos.
+ """
+ _, _, num_frames, height, width = x.shape
+ latent_height = height // self.spatial_compression_ratio
+ latent_width = width // self.spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+ blend_height = tile_latent_min_height - tile_latent_stride_height
+ blend_width = tile_latent_min_width - tile_latent_stride_width
+
+ # Split x into overlapping tiles and encode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, height, self.tile_sample_stride_height):
+ row = []
+ for j in range(0, width, self.tile_sample_stride_width):
+ self.clear_cache()
+ time = []
+ frame_range = 1 + (num_frames - 1) // 4
+ for k in range(frame_range):
+ self._enc_conv_idx = [0]
+ if k == 0:
+ tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+ else:
+ tile = x[
+ :,
+ :,
+ 1 + 4 * (k - 1) : 1 + 4 * k,
+ i : i + self.tile_sample_min_height,
+ j : j + self.tile_sample_min_width,
+ ]
+ tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+ tile = self.quant_conv(tile)
+ time.append(tile)
+ row.append(torch.cat(time, dim=2))
+ rows.append(row)
+ self.clear_cache()
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+ result_rows.append(torch.cat(result_row, dim=-1))
+
+ enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+ return enc
+
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of images using a tiled decoder.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ _, _, num_frames, height, width = z.shape
+ sample_height = height * self.spatial_compression_ratio
+ sample_width = width * self.spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+ blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+ blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+ # Split z into overlapping tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, height, tile_latent_stride_height):
+ row = []
+ for j in range(0, width, tile_latent_stride_width):
+ self.clear_cache()
+ time = []
+ for k in range(num_frames):
+ self._conv_idx = [0]
+ tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+ tile = self.post_quant_conv(tile)
+ decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
+ time.append(decoded)
+ row.append(torch.cat(time, dim=2))
+ rows.append(row)
+ self.clear_cache()
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+ result_rows.append(torch.cat(result_row, dim=-1))
+
+ dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+ if not return_dict:
+ return (dec,)
+ return DecoderOutput(sample=dec)
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ """
+ Args:
+ sample (`torch.Tensor`): Input sample.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z, return_dict=return_dict)
+ return dec
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py
index 5a72cd395196..7a307b1eacd8 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,18 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils.accelerate_utils import apply_forward_hook
-from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
+from ..attention import AttentionMixin
+from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttnProcessor
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
-from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder
class TemporalDecoder(nn.Module):
@@ -31,7 +32,7 @@ def __init__(
self,
in_channels: int = 4,
out_channels: int = 3,
- block_out_channels: Tuple[int] = (128, 256, 512, 512),
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: int = 2,
):
super().__init__()
@@ -135,7 +136,7 @@ def forward(
return sample
-class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
+class AutoencoderKLTemporalDecoder(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
@@ -158,11 +159,11 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
- Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper.
force_upcast (`bool`, *optional*, default to `True`):
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
- can be fine-tuned / trained to a lower range without loosing too much precision in which case
- `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
+ can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast`
+ can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
"""
_supports_gradient_checkpointing = True
@@ -172,8 +173,8 @@ def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
- down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
- block_out_channels: Tuple[int] = (64,),
+ down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
+ block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: int = 1,
latent_channels: int = 4,
sample_size: int = 32,
@@ -202,66 +203,6 @@ def __init__(
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
index fafb1fe867e3..761dff2dc61a 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -17,7 +17,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
@@ -26,7 +25,7 @@
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
-from .vae import DecoderOutput, DiagonalGaussianDistribution
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -34,6 +33,103 @@
CACHE_T = 2
+class AvgDown3D(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ factor_t,
+ factor_s=1,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.factor_t = factor_t
+ self.factor_s = factor_s
+ self.factor = self.factor_t * self.factor_s * self.factor_s
+
+ assert in_channels * self.factor % out_channels == 0
+ self.group_size = in_channels * self.factor // out_channels
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
+ pad = (0, 0, 0, 0, pad_t, 0)
+ x = F.pad(x, pad)
+ B, C, T, H, W = x.shape
+ x = x.view(
+ B,
+ C,
+ T // self.factor_t,
+ self.factor_t,
+ H // self.factor_s,
+ self.factor_s,
+ W // self.factor_s,
+ self.factor_s,
+ )
+ x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
+ x = x.view(
+ B,
+ C * self.factor,
+ T // self.factor_t,
+ H // self.factor_s,
+ W // self.factor_s,
+ )
+ x = x.view(
+ B,
+ self.out_channels,
+ self.group_size,
+ T // self.factor_t,
+ H // self.factor_s,
+ W // self.factor_s,
+ )
+ x = x.mean(dim=2)
+ return x
+
+
+class DupUp3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ factor_t,
+ factor_s=1,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ self.factor_t = factor_t
+ self.factor_s = factor_s
+ self.factor = self.factor_t * self.factor_s * self.factor_s
+
+ assert out_channels * self.factor % in_channels == 0
+ self.repeats = out_channels * self.factor // in_channels
+
+ def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
+ x = x.repeat_interleave(self.repeats, dim=1)
+ x = x.view(
+ x.size(0),
+ self.out_channels,
+ self.factor_t,
+ self.factor_s,
+ self.factor_s,
+ x.size(2),
+ x.size(3),
+ x.size(4),
+ )
+ x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
+ x = x.view(
+ x.size(0),
+ self.out_channels,
+ x.size(2) * self.factor_t,
+ x.size(4) * self.factor_s,
+ x.size(6) * self.factor_s,
+ )
+ if first_chunk:
+ x = x[:, :, self.factor_t - 1 :, :, :]
+ return x
+
+
class WanCausalConv3d(nn.Conv3d):
r"""
A custom 3D causal convolution layer with feature caching support.
@@ -134,19 +230,25 @@ class WanResample(nn.Module):
- 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
"""
- def __init__(self, dim: int, mode: str) -> None:
+ def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
super().__init__()
self.dim = dim
self.mode = mode
+ # default to dim //2
+ if upsample_out_dim is None:
+ upsample_out_dim = dim // 2
+
# layers
if mode == "upsample2d":
self.resample = nn.Sequential(
- WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
+ WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+ nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
)
elif mode == "upsample3d":
self.resample = nn.Sequential(
- WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
+ WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+ nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
)
self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
@@ -351,18 +453,54 @@ def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu",
def forward(self, x, feat_cache=None, feat_idx=[0]):
# First residual block
- x = self.resnets[0](x, feat_cache, feat_idx)
+ x = self.resnets[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
# Process through attention and residual blocks
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
x = attn(x)
- x = resnet(x, feat_cache, feat_idx)
+ x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
return x
+class WanResidualDownBlock(nn.Module):
+ def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample=False, down_flag=False):
+ super().__init__()
+
+ # Shortcut path with downsample
+ self.avg_shortcut = AvgDown3D(
+ in_dim,
+ out_dim,
+ factor_t=2 if temperal_downsample else 1,
+ factor_s=2 if down_flag else 1,
+ )
+
+ # Main path with residual blocks and downsample
+ resnets = []
+ for _ in range(num_res_blocks):
+ resnets.append(WanResidualBlock(in_dim, out_dim, dropout))
+ in_dim = out_dim
+ self.resnets = nn.ModuleList(resnets)
+
+ # Add the final downsample block
+ if down_flag:
+ mode = "downsample3d" if temperal_downsample else "downsample2d"
+ self.downsampler = WanResample(out_dim, mode=mode)
+ else:
+ self.downsampler = None
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ x_copy = x.clone()
+ for resnet in self.resnets:
+ x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
+ if self.downsampler is not None:
+ x = self.downsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
+
+ return x + self.avg_shortcut(x_copy)
+
+
class WanEncoder3d(nn.Module):
r"""
A 3D encoder module.
@@ -380,6 +518,7 @@ class WanEncoder3d(nn.Module):
def __init__(
self,
+ in_channels: int = 3,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
@@ -388,6 +527,7 @@ def __init__(
temperal_downsample=[True, True, False],
dropout=0.0,
non_linearity: str = "silu",
+ is_residual: bool = False, # wan 2.2 vae use a residual downblock
):
super().__init__()
self.dim = dim
@@ -403,23 +543,35 @@ def __init__(
scale = 1.0
# init block
- self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)
+ self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1)
# downsample blocks
self.down_blocks = nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
- for _ in range(num_res_blocks):
- self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
- if scale in attn_scales:
- self.down_blocks.append(WanAttentionBlock(out_dim))
- in_dim = out_dim
-
- # downsample block
- if i != len(dim_mult) - 1:
- mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
- self.down_blocks.append(WanResample(out_dim, mode=mode))
- scale /= 2.0
+ if is_residual:
+ self.down_blocks.append(
+ WanResidualDownBlock(
+ in_dim,
+ out_dim,
+ dropout,
+ num_res_blocks,
+ temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False,
+ down_flag=i != len(dim_mult) - 1,
+ )
+ )
+ else:
+ for _ in range(num_res_blocks):
+ self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
+ if scale in attn_scales:
+ self.down_blocks.append(WanAttentionBlock(out_dim))
+ in_dim = out_dim
+
+ # downsample block
+ if i != len(dim_mult) - 1:
+ mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
+ self.down_blocks.append(WanResample(out_dim, mode=mode))
+ scale /= 2.0
# middle blocks
self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
@@ -446,12 +598,12 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
## downsamples
for layer in self.down_blocks:
if feat_cache is not None:
- x = layer(x, feat_cache, feat_idx)
+ x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = layer(x)
## middle
- x = self.mid_block(x, feat_cache, feat_idx)
+ x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
## head
x = self.norm_out(x)
@@ -467,6 +619,95 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
feat_idx[0] += 1
else:
x = self.conv_out(x)
+
+ return x
+
+
+class WanResidualUpBlock(nn.Module):
+ """
+ A block that handles upsampling for the WanVAE decoder.
+
+ Args:
+ in_dim (int): Input dimension
+ out_dim (int): Output dimension
+ num_res_blocks (int): Number of residual blocks
+ dropout (float): Dropout rate
+ temperal_upsample (bool): Whether to upsample on temporal dimension
+ up_flag (bool): Whether to upsample or not
+ non_linearity (str): Type of non-linearity to use
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ num_res_blocks: int,
+ dropout: float = 0.0,
+ temperal_upsample: bool = False,
+ up_flag: bool = False,
+ non_linearity: str = "silu",
+ ):
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+
+ if up_flag:
+ self.avg_shortcut = DupUp3D(
+ in_dim,
+ out_dim,
+ factor_t=2 if temperal_upsample else 1,
+ factor_s=2,
+ )
+ else:
+ self.avg_shortcut = None
+
+ # create residual blocks
+ resnets = []
+ current_dim = in_dim
+ for _ in range(num_res_blocks + 1):
+ resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
+ current_dim = out_dim
+
+ self.resnets = nn.ModuleList(resnets)
+
+ # Add upsampling layer if needed
+ if up_flag:
+ upsample_mode = "upsample3d" if temperal_upsample else "upsample2d"
+ self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim)
+ else:
+ self.upsampler = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
+ """
+ Forward pass through the upsampling block.
+
+ Args:
+ x (torch.Tensor): Input tensor
+ feat_cache (list, optional): Feature cache for causal convolutions
+ feat_idx (list, optional): Feature index for cache management
+
+ Returns:
+ torch.Tensor: Output tensor
+ """
+ x_copy = x.clone()
+
+ for resnet in self.resnets:
+ if feat_cache is not None:
+ x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
+ else:
+ x = resnet(x)
+
+ if self.upsampler is not None:
+ if feat_cache is not None:
+ x = self.upsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
+ else:
+ x = self.upsampler(x)
+
+ if self.avg_shortcut is not None:
+ x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)
+
return x
@@ -513,7 +754,7 @@ def __init__(
self.gradient_checkpointing = False
- def forward(self, x, feat_cache=None, feat_idx=[0]):
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
"""
Forward pass through the upsampling block.
@@ -527,13 +768,13 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
"""
for resnet in self.resnets:
if feat_cache is not None:
- x = resnet(x, feat_cache, feat_idx)
+ x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = resnet(x)
if self.upsamplers is not None:
if feat_cache is not None:
- x = self.upsamplers[0](x, feat_cache, feat_idx)
+ x = self.upsamplers[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = self.upsamplers[0](x)
return x
@@ -564,6 +805,8 @@ def __init__(
temperal_upsample=[False, True, True],
dropout=0.0,
non_linearity: str = "silu",
+ out_channels: int = 3,
+ is_residual: bool = False,
):
super().__init__()
self.dim = dim
@@ -577,7 +820,6 @@ def __init__(
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
- scale = 1.0 / 2 ** (len(dim_mult) - 2)
# init block
self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
@@ -589,36 +831,47 @@ def __init__(
self.up_blocks = nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
- if i > 0:
+ if i > 0 and not is_residual:
+ # wan vae 2.1
in_dim = in_dim // 2
- # Determine if we need upsampling
+ # determine if we need upsampling
+ up_flag = i != len(dim_mult) - 1
+ # determine upsampling mode, if not upsampling, set to None
upsample_mode = None
- if i != len(dim_mult) - 1:
- upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
-
+ if up_flag and temperal_upsample[i]:
+ upsample_mode = "upsample3d"
+ elif up_flag:
+ upsample_mode = "upsample2d"
# Create and add the upsampling block
- up_block = WanUpBlock(
- in_dim=in_dim,
- out_dim=out_dim,
- num_res_blocks=num_res_blocks,
- dropout=dropout,
- upsample_mode=upsample_mode,
- non_linearity=non_linearity,
- )
+ if is_residual:
+ up_block = WanResidualUpBlock(
+ in_dim=in_dim,
+ out_dim=out_dim,
+ num_res_blocks=num_res_blocks,
+ dropout=dropout,
+ temperal_upsample=temperal_upsample[i] if up_flag else False,
+ up_flag=up_flag,
+ non_linearity=non_linearity,
+ )
+ else:
+ up_block = WanUpBlock(
+ in_dim=in_dim,
+ out_dim=out_dim,
+ num_res_blocks=num_res_blocks,
+ dropout=dropout,
+ upsample_mode=upsample_mode,
+ non_linearity=non_linearity,
+ )
self.up_blocks.append(up_block)
- # Update scale for next iteration
- if upsample_mode is not None:
- scale *= 2.0
-
# output blocks
self.norm_out = WanRMS_norm(out_dim, images=False)
- self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
+ self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1)
self.gradient_checkpointing = False
- def forward(self, x, feat_cache=None, feat_idx=[0]):
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
@@ -633,11 +886,11 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
x = self.conv_in(x)
## middle
- x = self.mid_block(x, feat_cache, feat_idx)
+ x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
## upsamples
for up_block in self.up_blocks:
- x = up_block(x, feat_cache, feat_idx)
+ x = up_block(x, feat_cache=feat_cache, feat_idx=feat_idx, first_chunk=first_chunk)
## head
x = self.norm_out(x)
@@ -656,7 +909,50 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
return x
-class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+def patchify(x, patch_size):
+ if patch_size == 1:
+ return x
+
+ if x.dim() != 5:
+ raise ValueError(f"Invalid input shape: {x.shape}")
+ # x shape: [batch_size, channels, frames, height, width]
+ batch_size, channels, frames, height, width = x.shape
+
+ # Ensure height and width are divisible by patch_size
+ if height % patch_size != 0 or width % patch_size != 0:
+ raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
+
+ # Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
+ x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
+
+ # Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
+ x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
+ x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
+
+ return x
+
+
+def unpatchify(x, patch_size):
+ if patch_size == 1:
+ return x
+
+ if x.dim() != 5:
+ raise ValueError(f"Invalid input shape: {x.shape}")
+ # x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
+ batch_size, c_patches, frames, height, width = x.shape
+ channels = c_patches // (patch_size * patch_size)
+
+ # Reshape to [b, c, patch_size, patch_size, f, h, w]
+ x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
+
+ # Rearrange to [b, c, f, h * patch_size, w * patch_size]
+ x = x.permute(0, 1, 4, 5, 3, 6, 2).contiguous()
+ x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
+
+ return x
+
+
+class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
Introduced in [Wan 2.1].
@@ -666,13 +962,18 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
"""
_supports_gradient_checkpointing = False
+ _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"]
+ # keys toignore when AlignDeviceHook moves inputs/outputs between devices
+ # these are shared mutable state modified in-place
+ _skip_keys = ["feat_cache", "feat_idx"]
@register_to_config
def __init__(
self,
base_dim: int = 96,
+ decoder_base_dim: Optional[int] = None,
z_dim: int = 16,
- dim_mult: Tuple[int] = [1, 2, 4, 4],
+ dim_mult: List[int] = [1, 2, 4, 4],
num_res_blocks: int = 2,
attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True],
@@ -713,6 +1014,12 @@ def __init__(
2.8251,
1.9160,
],
+ is_residual: bool = False,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ patch_size: Optional[int] = None,
+ scale_factor_temporal: Optional[int] = 4,
+ scale_factor_spatial: Optional[int] = 8,
) -> None:
super().__init__()
@@ -720,37 +1027,115 @@ def __init__(
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
+ if decoder_base_dim is None:
+ decoder_base_dim = base_dim
+
self.encoder = WanEncoder3d(
- base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
+ in_channels=in_channels,
+ dim=base_dim,
+ z_dim=z_dim * 2,
+ dim_mult=dim_mult,
+ num_res_blocks=num_res_blocks,
+ attn_scales=attn_scales,
+ temperal_downsample=temperal_downsample,
+ dropout=dropout,
+ is_residual=is_residual,
)
self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
self.decoder = WanDecoder3d(
- base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
+ dim=decoder_base_dim,
+ z_dim=z_dim,
+ dim_mult=dim_mult,
+ num_res_blocks=num_res_blocks,
+ attn_scales=attn_scales,
+ temperal_upsample=self.temperal_upsample,
+ dropout=dropout,
+ out_channels=out_channels,
+ is_residual=is_residual,
)
+ self.spatial_compression_ratio = scale_factor_spatial
+
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+ # to perform decoding of a single video latent at a time.
+ self.use_slicing = False
+
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+ # intermediate tiles together, the memory requirement can be lowered.
+ self.use_tiling = False
+
+ # The minimal tile height and width for spatial tiling to be used
+ self.tile_sample_min_height = 256
+ self.tile_sample_min_width = 256
+
+ # The minimal distance between two spatial tiles
+ self.tile_sample_stride_height = 192
+ self.tile_sample_stride_width = 192
+
+ # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
+ self._cached_conv_counts = {
+ "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
+ if self.decoder is not None
+ else 0,
+ "encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules())
+ if self.encoder is not None
+ else 0,
+ }
+
+ def enable_tiling(
+ self,
+ tile_sample_min_height: Optional[int] = None,
+ tile_sample_min_width: Optional[int] = None,
+ tile_sample_stride_height: Optional[float] = None,
+ tile_sample_stride_width: Optional[float] = None,
+ ) -> None:
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+
+ Args:
+ tile_sample_min_height (`int`, *optional*):
+ The minimum height required for a sample to be separated into tiles across the height dimension.
+ tile_sample_min_width (`int`, *optional*):
+ The minimum width required for a sample to be separated into tiles across the width dimension.
+ tile_sample_stride_height (`int`, *optional*):
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+ no tiling artifacts produced across the height dimension.
+ tile_sample_stride_width (`int`, *optional*):
+ The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+ artifacts produced across the width dimension.
+ """
+ self.use_tiling = True
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+ self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+ self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+
def clear_cache(self):
- def _count_conv3d(model):
- count = 0
- for m in model.modules():
- if isinstance(m, WanCausalConv3d):
- count += 1
- return count
-
- self._conv_num = _count_conv3d(self.decoder)
+ # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
+ self._conv_num = self._cached_conv_counts["decoder"]
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# cache encode
- self._enc_conv_num = _count_conv3d(self.encoder)
+ self._enc_conv_num = self._cached_conv_counts["encoder"]
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
- def _encode(self, x: torch.Tensor) -> torch.Tensor:
+ def _encode(self, x: torch.Tensor):
+ _, _, num_frame, height, width = x.shape
+
self.clear_cache()
- ## cache
- t = x.shape[2]
- iter_ = 1 + (t - 1) // 4
+ if self.config.patch_size is not None:
+ x = patchify(x, patch_size=self.config.patch_size)
+
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+ return self.tiled_encode(x)
+
+ iter_ = 1 + (num_frame - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
@@ -764,8 +1149,6 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
out = torch.cat([out, out_], 2)
enc = self.quant_conv(out)
- mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
- enc = torch.cat([mu, logvar], dim=1)
self.clear_cache()
return enc
@@ -785,26 +1168,42 @@ def encode(
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
- h = self._encode(x)
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
+
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
- def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
- self.clear_cache()
+ def _decode(self, z: torch.Tensor, return_dict: bool = True):
+ _, _, num_frame, height, width = z.shape
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+ if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+ return self.tiled_decode(z, return_dict=return_dict)
- iter_ = z.shape[2]
+ self.clear_cache()
x = self.post_quant_conv(z)
- for i in range(iter_):
+ for i in range(num_frame):
self._conv_idx = [0]
if i == 0:
- out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+ out = self.decoder(
+ x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True
+ )
else:
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
+ if self.config.patch_size is not None:
+ out = unpatchify(out, patch_size=self.config.patch_size)
+
out = torch.clamp(out, min=-1.0, max=1.0)
+
self.clear_cache()
if not return_dict:
return (out,)
@@ -826,12 +1225,182 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
- decoded = self._decode(z).sample
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z).sample
+
if not return_dict:
return (decoded,)
-
return DecoderOutput(sample=decoded)
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+ y / blend_extent
+ )
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+ x / blend_extent
+ )
+ return b
+
+ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
+ r"""Encode a batch of images using a tiled encoder.
+
+ Args:
+ x (`torch.Tensor`): Input batch of videos.
+
+ Returns:
+ `torch.Tensor`:
+ The latent representation of the encoded videos.
+ """
+
+ _, _, num_frames, height, width = x.shape
+ encode_spatial_compression_ratio = self.spatial_compression_ratio
+ if self.config.patch_size is not None:
+ assert encode_spatial_compression_ratio % self.config.patch_size == 0
+ encode_spatial_compression_ratio = self.spatial_compression_ratio // self.config.patch_size
+
+ latent_height = height // encode_spatial_compression_ratio
+ latent_width = width // encode_spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // encode_spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // encode_spatial_compression_ratio
+ tile_latent_stride_height = self.tile_sample_stride_height // encode_spatial_compression_ratio
+ tile_latent_stride_width = self.tile_sample_stride_width // encode_spatial_compression_ratio
+
+ blend_height = tile_latent_min_height - tile_latent_stride_height
+ blend_width = tile_latent_min_width - tile_latent_stride_width
+
+ # Split x into overlapping tiles and encode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, height, self.tile_sample_stride_height):
+ row = []
+ for j in range(0, width, self.tile_sample_stride_width):
+ self.clear_cache()
+ time = []
+ frame_range = 1 + (num_frames - 1) // 4
+ for k in range(frame_range):
+ self._enc_conv_idx = [0]
+ if k == 0:
+ tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+ else:
+ tile = x[
+ :,
+ :,
+ 1 + 4 * (k - 1) : 1 + 4 * k,
+ i : i + self.tile_sample_min_height,
+ j : j + self.tile_sample_min_width,
+ ]
+ tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+ tile = self.quant_conv(tile)
+ time.append(tile)
+ row.append(torch.cat(time, dim=2))
+ rows.append(row)
+ self.clear_cache()
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+ result_rows.append(torch.cat(result_row, dim=-1))
+
+ enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+ return enc
+
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of images using a tiled decoder.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ _, _, num_frames, height, width = z.shape
+ sample_height = height * self.spatial_compression_ratio
+ sample_width = width * self.spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+ tile_sample_stride_height = self.tile_sample_stride_height
+ tile_sample_stride_width = self.tile_sample_stride_width
+ if self.config.patch_size is not None:
+ sample_height = sample_height // self.config.patch_size
+ sample_width = sample_width // self.config.patch_size
+ tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size
+ tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size
+ blend_height = self.tile_sample_min_height // self.config.patch_size - tile_sample_stride_height
+ blend_width = self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width
+ else:
+ blend_height = self.tile_sample_min_height - tile_sample_stride_height
+ blend_width = self.tile_sample_min_width - tile_sample_stride_width
+
+ # Split z into overlapping tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, height, tile_latent_stride_height):
+ row = []
+ for j in range(0, width, tile_latent_stride_width):
+ self.clear_cache()
+ time = []
+ for k in range(num_frames):
+ self._conv_idx = [0]
+ tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+ tile = self.post_quant_conv(tile)
+ decoded = self.decoder(
+ tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0)
+ )
+ time.append(decoded)
+ row.append(torch.cat(time, dim=2))
+ rows.append(row)
+ self.clear_cache()
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width])
+ result_rows.append(torch.cat(result_row, dim=-1))
+ dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+ if self.config.patch_size is not None:
+ dec = unpatchify(dec, patch_size=self.config.patch_size)
+
+ dec = torch.clamp(dec, min=-1.0, max=1.0)
+
+ if not return_dict:
+ return (dec,)
+ return DecoderOutput(sample=dec)
+
def forward(
self,
sample: torch.Tensor,
@@ -847,6 +1416,7 @@ def forward(
"""
x = sample
posterior = self.encode(x).latent_dist
+
if sample_posterior:
z = posterior.sample(generator=generator)
else:
diff --git a/src/diffusers/models/autoencoders/autoencoder_oobleck.py b/src/diffusers/models/autoencoders/autoencoder_oobleck.py
index a8c2a2fd3840..d83264559209 100644
--- a/src/diffusers/models/autoencoders/autoencoder_oobleck.py
+++ b/src/diffusers/models/autoencoders/autoencoder_oobleck.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,6 +25,7 @@
from ...utils.accelerate_utils import apply_forward_hook
from ...utils.torch_utils import randn_tensor
from ..modeling_utils import ModelMixin
+from .vae import AutoencoderMixin
class Snake1d(nn.Module):
@@ -291,7 +292,7 @@ def forward(self, hidden_state):
return hidden_state
-class AutoencoderOobleck(ModelMixin, ConfigMixin):
+class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
An autoencoder for encoding waveforms into latents and decoding latent representations into waveforms. First
introduced in Stable Audio.
@@ -356,20 +357,6 @@ def __init__(
self.use_slicing = False
- def enable_slicing(self):
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self):
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
diff --git a/src/diffusers/models/autoencoders/autoencoder_tiny.py b/src/diffusers/models/autoencoders/autoencoder_tiny.py
index 7ed727c55c37..b9ac713d7392 100644
--- a/src/diffusers/models/autoencoders/autoencoder_tiny.py
+++ b/src/diffusers/models/autoencoders/autoencoder_tiny.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Ollin Boer Bohan and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Ollin Boer Bohan and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@
from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_utils import ModelMixin
-from .vae import DecoderOutput, DecoderTiny, EncoderTiny
+from .vae import AutoencoderMixin, DecoderOutput, DecoderTiny, EncoderTiny
@dataclass
@@ -38,7 +38,7 @@ class AutoencoderTinyOutput(BaseOutput):
latents: torch.Tensor
-class AutoencoderTiny(ModelMixin, ConfigMixin):
+class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A tiny distilled VAE model for encoding images into latents and decoding latent representations into images.
@@ -83,8 +83,8 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
- Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. For this Autoencoder,
- however, no such scaling factor was used, hence the value of 1.0 as the default.
+ Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper. For this
+ Autoencoder, however, no such scaling factor was used, hence the value of 1.0 as the default.
force_upcast (`bool`, *optional*, default to `False`):
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
can be fine-tuned / trained to a lower range without losing too much precision, in which case
@@ -162,35 +162,6 @@ def unscale_latents(self, x: torch.Tensor) -> torch.Tensor:
"""[0, 1] -> raw latents"""
return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
- def enable_slicing(self) -> None:
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self) -> None:
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
- def enable_tiling(self, use_tiling: bool = True) -> None:
- r"""
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
- compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
- processing larger images.
- """
- self.use_tiling = use_tiling
-
- def disable_tiling(self) -> None:
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.enable_tiling(False)
-
def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
diff --git a/src/diffusers/models/autoencoders/consistency_decoder_vae.py b/src/diffusers/models/autoencoders/consistency_decoder_vae.py
index a0b3309dc522..db9404f4ac70 100644
--- a/src/diffusers/models/autoencoders/consistency_decoder_vae.py
+++ b/src/diffusers/models/autoencoders/consistency_decoder_vae.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
@@ -23,16 +23,16 @@
from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ...utils.torch_utils import randn_tensor
+from ..attention import AttentionMixin
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
- AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from ..modeling_utils import ModelMixin
from ..unets.unet_2d import UNet2DModel
-from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder
@dataclass
@@ -49,7 +49,7 @@ class ConsistencyDecoderVAEOutput(BaseOutput):
latent_dist: "DiagonalGaussianDistribution"
-class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
+class ConsistencyDecoderVAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin):
r"""
The consistency decoder used with DALL-E 3.
@@ -167,99 +167,6 @@ def __init__(
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
- # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling
- def enable_tiling(self, use_tiling: bool = True):
- r"""
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
- compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
- processing larger images.
- """
- self.use_tiling = use_tiling
-
- # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_tiling
- def disable_tiling(self):
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.enable_tiling(False)
-
- # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_slicing
- def enable_slicing(self):
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_slicing
- def disable_slicing(self):
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
diff --git a/src/diffusers/models/autoencoders/vae.py b/src/diffusers/models/autoencoders/vae.py
index 72e0acda3afe..9c6031a988f9 100644
--- a/src/diffusers/models/autoencoders/vae.py
+++ b/src/diffusers/models/autoencoders/vae.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -255,7 +255,7 @@ def __init__(
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
- prev_output_channel=None,
+ prev_output_channel=prev_output_channel,
add_upsample=not is_final_block,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
@@ -286,11 +286,9 @@ def forward(
sample = self.conv_in(sample)
- upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
if torch.is_grad_enabled() and self.gradient_checkpointing:
# middle
sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
- sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
@@ -298,7 +296,6 @@ def forward(
else:
# middle
sample = self.mid_block(sample, latent_embeds)
- sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
@@ -744,6 +741,17 @@ def mode(self) -> torch.Tensor:
return self.mean
+class IdentityDistribution(object):
+ def __init__(self, parameters: torch.Tensor):
+ self.parameters = parameters
+
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
+ return self.parameters
+
+ def mode(self) -> torch.Tensor:
+ return self.parameters
+
+
class EncoderTiny(nn.Module):
r"""
The `EncoderTiny` layer is a simpler version of the `Encoder` layer.
@@ -883,3 +891,38 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# scale image from [0, 1] to [-1, 1] to match diffusers convention
return x.mul(2).sub(1)
+
+
+class AutoencoderMixin:
+ def enable_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ if not hasattr(self, "use_tiling"):
+ raise NotImplementedError(f"Tiling doesn't seem to be implemented for {self.__class__.__name__}.")
+ self.use_tiling = True
+
+ def disable_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_tiling = False
+
+ def enable_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ if not hasattr(self, "use_slicing"):
+ raise NotImplementedError(f"Slicing doesn't seem to be implemented for {self.__class__.__name__}.")
+ self.use_slicing = True
+
+ def disable_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
diff --git a/src/diffusers/models/autoencoders/vq_model.py b/src/diffusers/models/autoencoders/vq_model.py
index 84215389bf6a..82436473dfc6 100644
--- a/src/diffusers/models/autoencoders/vq_model.py
+++ b/src/diffusers/models/autoencoders/vq_model.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,6 +22,7 @@
from ...utils.accelerate_utils import apply_forward_hook
from ..autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
from ..modeling_utils import ModelMixin
+from .vae import AutoencoderMixin
@dataclass
@@ -37,7 +38,7 @@ class VQEncoderOutput(BaseOutput):
latents: torch.Tensor
-class VQModel(ModelMixin, ConfigMixin):
+class VQModel(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VQ-VAE model for decoding latent representations.
@@ -66,7 +67,7 @@ class VQModel(ModelMixin, ConfigMixin):
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
- Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper.
norm_type (`str`, *optional*, defaults to `"group"`):
Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
"""
diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py
index 79bd8dc0b254..f4ad1af278f5 100644
--- a/src/diffusers/models/cache_utils.py
+++ b/src/diffusers/models/cache_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from contextlib import contextmanager
+
from ..utils.logging import get_logger
@@ -25,6 +27,7 @@ class CacheMixin:
Supported caching techniques:
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
- [FasterCache](https://huggingface.co/papers/2410.19355)
+ - [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching)
"""
_cache_config = None
@@ -62,9 +65,13 @@ def enable_cache(self, config) -> None:
from ..hooks import (
FasterCacheConfig,
+ FirstBlockCacheConfig,
PyramidAttentionBroadcastConfig,
+ TaylorSeerCacheConfig,
apply_faster_cache,
+ apply_first_block_cache,
apply_pyramid_attention_broadcast,
+ apply_taylorseer_cache,
)
if self.is_cache_enabled:
@@ -72,31 +79,47 @@ def enable_cache(self, config) -> None:
f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
)
- if isinstance(config, PyramidAttentionBroadcastConfig):
- apply_pyramid_attention_broadcast(self, config)
- elif isinstance(config, FasterCacheConfig):
+ if isinstance(config, FasterCacheConfig):
apply_faster_cache(self, config)
+ elif isinstance(config, FirstBlockCacheConfig):
+ apply_first_block_cache(self, config)
+ elif isinstance(config, PyramidAttentionBroadcastConfig):
+ apply_pyramid_attention_broadcast(self, config)
+ elif isinstance(config, TaylorSeerCacheConfig):
+ apply_taylorseer_cache(self, config)
else:
raise ValueError(f"Cache config {type(config)} is not supported.")
self._cache_config = config
def disable_cache(self) -> None:
- from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
+ from ..hooks import (
+ FasterCacheConfig,
+ FirstBlockCacheConfig,
+ HookRegistry,
+ PyramidAttentionBroadcastConfig,
+ TaylorSeerCacheConfig,
+ )
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
+ from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
+ from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK
if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
return
- if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
- registry = HookRegistry.check_if_exists_or_initialize(self)
- registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
- elif isinstance(self._cache_config, FasterCacheConfig):
- registry = HookRegistry.check_if_exists_or_initialize(self)
+ registry = HookRegistry.check_if_exists_or_initialize(self)
+ if isinstance(self._cache_config, FasterCacheConfig):
registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
+ elif isinstance(self._cache_config, FirstBlockCacheConfig):
+ registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
+ registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
+ elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
+ registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
+ elif isinstance(self._cache_config, TaylorSeerCacheConfig):
+ registry.remove_hook(_TAYLORSEER_CACHE_HOOK, recurse=True)
else:
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
@@ -106,3 +129,15 @@ def _reset_stateful_cache(self, recurse: bool = True) -> None:
from ..hooks import HookRegistry
HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
+
+ @contextmanager
+ def cache_context(self, name: str):
+ r"""Context manager that provides additional methods for cache management."""
+ from ..hooks import HookRegistry
+
+ registry = HookRegistry.check_if_exists_or_initialize(self)
+ registry._set_context(name)
+
+ yield
+
+ registry._set_context(None)
diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py
index b9ebab818be7..c18bd8751dcb 100644
--- a/src/diffusers/models/controlnet.py
+++ b/src/diffusers/models/controlnet.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/models/controlnet_flux.py b/src/diffusers/models/controlnet_flux.py
index 2035deb1062d..e82748436d86 100644
--- a/src/diffusers/models/controlnet_flux.py
+++ b/src/diffusers/models/controlnet_flux.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnet_sd3.py
index 0f7246c6c6d4..d239ad4eb3e8 100644
--- a/src/diffusers/models/controlnet_sd3.py
+++ b/src/diffusers/models/controlnet_sd3.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
+# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/models/controlnet_sparsectrl.py b/src/diffusers/models/controlnet_sparsectrl.py
index 8fdaa21bef11..5c67af4fe9c1 100644
--- a/src/diffusers/models/controlnet_sparsectrl.py
+++ b/src/diffusers/models/controlnet_sparsectrl.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/models/controlnets/__init__.py b/src/diffusers/models/controlnets/__init__.py
index 621de4329868..2fb18348afdd 100644
--- a/src/diffusers/models/controlnets/__init__.py
+++ b/src/diffusers/models/controlnets/__init__.py
@@ -12,12 +12,9 @@
HunyuanDiT2DControlNetModel,
HunyuanDiT2DMultiControlNetModel,
)
+ from .controlnet_qwenimage import QwenImageControlNetModel, QwenImageMultiControlNetModel
from .controlnet_sana import SanaControlNetModel
- from .controlnet_sd3 import (
- SD3ControlNetModel,
- SD3ControlNetOutput,
- SD3MultiControlNetModel,
- )
+ from .controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel
from .controlnet_sparsectrl import (
SparseControlNetConditioningEmbedding,
SparseControlNetModel,
diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py
index 7a6ca886caed..5c89c9267db4 100644
--- a/src/diffusers/models/controlnets/controlnet.py
+++ b/src/diffusers/models/controlnets/controlnet.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,10 +21,10 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import BaseOutput, logging
+from ..attention import AttentionMixin
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
- AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
@@ -63,8 +63,8 @@ class ControlNetOutput(BaseOutput):
class ControlNetConditioningEmbedding(nn.Module):
"""
- Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
- [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
+ Quoting from https://huggingface.co/papers/2302.05543: "Stable Diffusion uses a pre-processing method similar to
+ VQ-GAN [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
@@ -106,7 +106,7 @@ def forward(self, conditioning):
return embedding
-class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, FromOriginalModelMixin):
"""
A ControlNet model.
@@ -515,66 +515,6 @@ def from_unet(
return controlnet
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
diff --git a/src/diffusers/models/controlnets/controlnet_flax.py b/src/diffusers/models/controlnets/controlnet_flax.py
index ab8d9b5f8cbb..f7a8b98fa2f0 100644
--- a/src/diffusers/models/controlnets/controlnet_flax.py
+++ b/src/diffusers/models/controlnets/controlnet_flax.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,7 +20,7 @@
from flax.core.frozen_dict import FrozenDict
from ...configuration_utils import ConfigMixin, flax_register_to_config
-from ...utils import BaseOutput
+from ...utils import BaseOutput, logging
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from ..modeling_flax_utils import FlaxModelMixin
from ..unets.unet_2d_blocks_flax import (
@@ -30,6 +30,9 @@
)
+logger = logging.get_logger(__name__)
+
+
@flax.struct.dataclass
class FlaxControlNetOutput(BaseOutput):
"""
@@ -50,6 +53,11 @@ class FlaxControlNetConditioningEmbedding(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
self.conv_in = nn.Conv(
self.block_out_channels[0],
kernel_size=(3, 3),
@@ -184,6 +192,11 @@ def init_weights(self, rng: jax.Array) -> FrozenDict:
return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
def setup(self) -> None:
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
block_out_channels = self.block_out_channels
time_embed_dim = block_out_channels[0] * 4
diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py
index 51c34b7fe965..639a8ad7390a 100644
--- a/src/diffusers/models/controlnets/controlnet_flux.py
+++ b/src/diffusers/models/controlnets/controlnet_flux.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,12 +20,12 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
-from ...models.attention_processor import AttentionProcessor
-from ...models.modeling_utils import ModelMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import AttentionMixin
from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
from ..transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
@@ -38,7 +38,7 @@ class FluxControlNetOutput(BaseOutput):
controlnet_single_block_samples: Tuple[torch.Tensor]
-class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class FluxControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin):
_supports_gradient_checkpointing = True
@register_to_config
@@ -118,66 +118,6 @@ def __init__(
self.gradient_checkpointing = False
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self):
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
@classmethod
def from_transformer(
cls,
@@ -298,15 +238,6 @@ def forward(
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
- if self.union:
- # union mode
- if controlnet_mode is None:
- raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
- # union mode emb
- controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
- encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
- txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
-
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
@@ -320,6 +251,15 @@ def forward(
)
img_ids = img_ids[0]
+ if self.union:
+ # union mode
+ if controlnet_mode is None:
+ raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
+ # union mode emb
+ controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
+ encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
+ txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
+
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
@@ -343,25 +283,25 @@ def forward(
)
block_samples = block_samples + (hidden_states,)
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
single_block_samples = ()
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
- hidden_states = self._gradient_checkpointing_func(
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
+ encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
- hidden_states = block(
+ encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
- single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
+ single_block_samples = single_block_samples + (hidden_states,)
# controlnet block
controlnet_block_samples = ()
@@ -430,7 +370,7 @@ def forward(
) -> Union[FluxControlNetOutput, Tuple]:
# ControlNet-Union with multiple conditions
# only load one ControlNet for saving memories
- if len(self.nets) == 1 and self.nets[0].union:
+ if len(self.nets) == 1:
controlnet = self.nets[0]
for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
@@ -454,17 +394,18 @@ def forward(
control_block_samples = block_samples
control_single_block_samples = single_block_samples
else:
- control_block_samples = [
- control_block_sample + block_sample
- for control_block_sample, block_sample in zip(control_block_samples, block_samples)
- ]
-
- control_single_block_samples = [
- control_single_block_sample + block_sample
- for control_single_block_sample, block_sample in zip(
- control_single_block_samples, single_block_samples
- )
- ]
+ if block_samples is not None and control_block_samples is not None:
+ control_block_samples = [
+ control_block_sample + block_sample
+ for control_block_sample, block_sample in zip(control_block_samples, block_samples)
+ ]
+ if single_block_samples is not None and control_single_block_samples is not None:
+ control_single_block_samples = [
+ control_single_block_sample + block_sample
+ for control_single_block_sample, block_sample in zip(
+ control_single_block_samples, single_block_samples
+ )
+ ]
# Regular Multi-ControlNets
# load all ControlNets into memories
diff --git a/src/diffusers/models/controlnets/controlnet_hunyuan.py b/src/diffusers/models/controlnets/controlnet_hunyuan.py
index fade44def4cd..d17d5692aa40 100644
--- a/src/diffusers/models/controlnets/controlnet_hunyuan.py
+++ b/src/diffusers/models/controlnets/controlnet_hunyuan.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
+# Copyright 2025 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -103,7 +103,7 @@ def __init__(
activation_fn=activation_fn,
ff_inner_dim=int(self.inner_dim * mlp_ratio),
cross_attention_dim=cross_attention_dim,
- qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
+ qk_norm=True, # See https://huggingface.co/papers/2302.05442 for details.
skip=False, # always False as it is the first half of the model
)
for layer in range(transformer_num_layers // 2 - 1)
diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py
new file mode 100644
index 000000000000..86971271788f
--- /dev/null
+++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py
@@ -0,0 +1,301 @@
+# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import AttentionMixin
+from ..cache_utils import CacheMixin
+from ..controlnets.controlnet import zero_module
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..transformers.transformer_qwenimage import (
+ QwenEmbedRope,
+ QwenImageTransformerBlock,
+ QwenTimestepProjEmbeddings,
+ RMSNorm,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class QwenImageControlNetOutput(BaseOutput):
+ controlnet_block_samples: Tuple[torch.Tensor]
+
+
+class QwenImageControlNetModel(
+ ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin
+):
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_channels: int = 64,
+ out_channels: Optional[int] = 16,
+ num_layers: int = 60,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 24,
+ joint_attention_dim: int = 3584,
+ axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
+ extra_condition_channels: int = 0, # for controlnet-inpainting
+ ):
+ super().__init__()
+ self.out_channels = out_channels or in_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
+
+ self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
+
+ self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
+
+ self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
+
+ self.img_in = nn.Linear(in_channels, self.inner_dim)
+ self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ QwenImageTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # controlnet_blocks
+ self.controlnet_blocks = nn.ModuleList([])
+ for _ in range(len(self.transformer_blocks)):
+ self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
+ self.controlnet_x_embedder = zero_module(
+ torch.nn.Linear(in_channels + extra_condition_channels, self.inner_dim)
+ )
+
+ self.gradient_checkpointing = False
+
+ @classmethod
+ def from_transformer(
+ cls,
+ transformer,
+ num_layers: int = 5,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 24,
+ load_weights_from_transformer=True,
+ extra_condition_channels: int = 0,
+ ):
+ config = dict(transformer.config)
+ config["num_layers"] = num_layers
+ config["attention_head_dim"] = attention_head_dim
+ config["num_attention_heads"] = num_attention_heads
+ config["extra_condition_channels"] = extra_condition_channels
+
+ controlnet = cls.from_config(config)
+
+ if load_weights_from_transformer:
+ controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
+ controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
+ controlnet.img_in.load_state_dict(transformer.img_in.state_dict())
+ controlnet.txt_in.load_state_dict(transformer.txt_in.state_dict())
+ controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
+ controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
+
+ return controlnet
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ controlnet_cond: torch.Tensor,
+ conditioning_scale: float = 1.0,
+ encoder_hidden_states: torch.Tensor = None,
+ encoder_hidden_states_mask: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_shapes: Optional[List[Tuple[int, int, int]]] = None,
+ txt_seq_lens: Optional[List[int]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ """
+ The [`FluxTransformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+ Input `hidden_states`.
+ controlnet_cond (`torch.Tensor`):
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ The scale factor for ControlNet outputs.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ from the embeddings of input conditions.
+ timestep ( `torch.LongTensor`):
+ Used to indicate denoising step.
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+ A list of tensors that if specified are added to the residuals of transformer blocks.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ if joint_attention_kwargs is not None:
+ joint_attention_kwargs = joint_attention_kwargs.copy()
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+ hidden_states = self.img_in(hidden_states)
+
+ # add
+ hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
+
+ temb = self.time_text_embed(timestep, hidden_states)
+
+ image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+
+ timestep = timestep.to(hidden_states.dtype)
+ encoder_hidden_states = self.txt_norm(encoder_hidden_states)
+ encoder_hidden_states = self.txt_in(encoder_hidden_states)
+
+ block_samples = ()
+ for index_block, block in enumerate(self.transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ encoder_hidden_states_mask,
+ temb,
+ image_rotary_emb,
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+ block_samples = block_samples + (hidden_states,)
+
+ # controlnet block
+ controlnet_block_samples = ()
+ for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
+ block_sample = controlnet_block(block_sample)
+ controlnet_block_samples = controlnet_block_samples + (block_sample,)
+
+ # scaling
+ controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
+ controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return controlnet_block_samples
+
+ return QwenImageControlNetOutput(
+ controlnet_block_samples=controlnet_block_samples,
+ )
+
+
+class QwenImageMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+ r"""
+ `QwenImageMultiControlNetModel` wrapper class for Multi-QwenImageControlNetModel
+
+ This module is a wrapper for multiple instances of the `QwenImageControlNetModel`. The `forward()` API is designed
+ to be compatible with `QwenImageControlNetModel`.
+
+ Args:
+ controlnets (`List[QwenImageControlNetModel]`):
+ Provides additional conditioning to the unet during the denoising process. You must set multiple
+ `QwenImageControlNetModel` as a list.
+ """
+
+ def __init__(self, controlnets):
+ super().__init__()
+ self.nets = nn.ModuleList(controlnets)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ controlnet_cond: List[torch.tensor],
+ conditioning_scale: List[float],
+ encoder_hidden_states: torch.Tensor = None,
+ encoder_hidden_states_mask: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_shapes: Optional[List[Tuple[int, int, int]]] = None,
+ txt_seq_lens: Optional[List[int]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[QwenImageControlNetOutput, Tuple]:
+ # ControlNet-Union with multiple conditions
+ # only load one ControlNet for saving memories
+ if len(self.nets) == 1:
+ controlnet = self.nets[0]
+
+ for i, (image, scale) in enumerate(zip(controlnet_cond, conditioning_scale)):
+ block_samples = controlnet(
+ hidden_states=hidden_states,
+ controlnet_cond=image,
+ conditioning_scale=scale,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
+ timestep=timestep,
+ img_shapes=img_shapes,
+ txt_seq_lens=txt_seq_lens,
+ joint_attention_kwargs=joint_attention_kwargs,
+ return_dict=return_dict,
+ )
+
+ # merge samples
+ if i == 0:
+ control_block_samples = block_samples
+ else:
+ if block_samples is not None and control_block_samples is not None:
+ control_block_samples = [
+ control_block_sample + block_sample
+ for control_block_sample, block_sample in zip(control_block_samples, block_samples)
+ ]
+ else:
+ raise ValueError("QwenImageMultiControlNetModel only supports a single controlnet-union now.")
+
+ return control_block_samples
diff --git a/src/diffusers/models/controlnets/controlnet_sana.py b/src/diffusers/models/controlnets/controlnet_sana.py
index ed521adbedda..c71a8b326635 100644
--- a/src/diffusers/models/controlnets/controlnet_sana.py
+++ b/src/diffusers/models/controlnets/controlnet_sana.py
@@ -21,7 +21,7 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
-from ..attention_processor import AttentionProcessor
+from ..attention import AttentionMixin
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -38,7 +38,7 @@ class SanaControlNetOutput(BaseOutput):
controlnet_block_samples: Tuple[torch.Tensor]
-class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class SanaControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
_skip_layerwise_casting_patterns = ["patch_embed", "norm"]
@@ -117,66 +117,6 @@ def __init__(
self.gradient_checkpointing = False
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
def forward(
self,
hidden_states: torch.Tensor,
diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py
index 91ce76fe75a9..08b86ff344eb 100644
--- a/src/diffusers/models/controlnets/controlnet_sd3.py
+++ b/src/diffusers/models/controlnets/controlnet_sd3.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
+# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,8 +22,8 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ..attention import JointTransformerBlock
-from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
+from ..attention import AttentionMixin, JointTransformerBlock
+from ..attention_processor import Attention, FusedJointAttnProcessor2_0
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -39,7 +39,7 @@ class SD3ControlNetOutput(BaseOutput):
controlnet_block_samples: Tuple[torch.Tensor]
-class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+class SD3ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
r"""
ControlNet model for [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
@@ -204,77 +204,13 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int
for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
# Copied from diffusers.models.transformers.transformer_sd3.SD3Transformer2DModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -294,11 +230,7 @@ def fuse_qkv_projections(self):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/controlnets/controlnet_sparsectrl.py b/src/diffusers/models/controlnets/controlnet_sparsectrl.py
index 25348ce606d6..8e7faf2d44b0 100644
--- a/src/diffusers/models/controlnets/controlnet_sparsectrl.py
+++ b/src/diffusers/models/controlnets/controlnet_sparsectrl.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,10 +22,10 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import BaseOutput, logging
+from ..attention import AttentionMixin
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
- AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
@@ -93,10 +93,10 @@ def forward(self, conditioning: torch.Tensor) -> torch.Tensor:
return embedding
-class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class SparseControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, FromOriginalModelMixin):
"""
A SparseControlNet model as described in [SparseCtrl: Adding Sparse Controls to Text-to-Video Diffusion
- Models](https://arxiv.org/abs/2311.16933).
+ Models](https://huggingface.co/papers/2311.16933).
Args:
in_channels (`int`, defaults to 4):
@@ -448,66 +448,6 @@ def from_unet(
return controlnet
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
diff --git a/src/diffusers/models/controlnets/controlnet_union.py b/src/diffusers/models/controlnets/controlnet_union.py
index 26cb86718a21..b4ee6536ca2f 100644
--- a/src/diffusers/models/controlnets/controlnet_union.py
+++ b/src/diffusers/models/controlnets/controlnet_union.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,10 +19,10 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import logging
+from ..attention import AttentionMixin
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
- AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
@@ -81,7 +81,7 @@ def forward(self, x: torch.Tensor):
return x
-class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class ControlNetUnionModel(ModelMixin, AttentionMixin, ConfigMixin, FromOriginalModelMixin):
"""
A ControlNetUnion model.
@@ -455,66 +455,6 @@ def from_unet(
return controlnet
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
@@ -752,7 +692,7 @@ def forward(
condition = self.controlnet_cond_embedding(cond)
feat_seq = torch.mean(condition, dim=(2, 3))
feat_seq = feat_seq + self.task_embedding[control_idx]
- if from_multi:
+ if from_multi or len(control_type_idx) == 1:
inputs.append(feat_seq.unsqueeze(1))
condition_list.append(condition)
else:
@@ -772,7 +712,7 @@ def forward(
for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
alpha = self.spatial_ch_projs(x[:, idx])
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
- if from_multi:
+ if from_multi or len(control_type_idx) == 1:
controlnet_cond_fuser += condition + alpha
else:
controlnet_cond_fuser += condition + alpha * scale
@@ -819,11 +759,11 @@ def forward(
# 6. scaling
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
- if from_multi:
+ if from_multi or len(control_type_idx) == 1:
scales = scales * conditioning_scale[0]
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
- elif from_multi:
+ elif from_multi or len(control_type_idx) == 1:
down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]
diff --git a/src/diffusers/models/controlnets/controlnet_xs.py b/src/diffusers/models/controlnets/controlnet_xs.py
index 608be6b70277..119492b0fac4 100644
--- a/src/diffusers/models/controlnets/controlnet_xs.py
+++ b/src/diffusers/models/controlnets/controlnet_xs.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,17 +16,16 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
-import torch.utils.checkpoint
from torch import Tensor, nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput, logging
from ...utils.torch_utils import apply_freeu
+from ..attention import AttentionMixin
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
- AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
FusedAttnProcessor2_0,
@@ -242,7 +241,7 @@ def get_up_block_adapter(
return UpBlockControlNetXSAdapter(ctrl_to_base=nn.ModuleList(ctrl_to_base))
-class ControlNetXSAdapter(ModelMixin, ConfigMixin):
+class ControlNetXSAdapter(ModelMixin, AttentionMixin, ConfigMixin):
r"""
A `ControlNetXSAdapter` model. To use it, pass it into a `UNetControlNetXSModel` (together with a
`UNet2DConditionModel` base model).
@@ -294,14 +293,14 @@ def __init__(
self,
conditioning_channels: int = 3,
conditioning_channel_order: str = "rgb",
- conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
+ conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
time_embedding_mix: float = 1.0,
learn_time_embedding: bool = False,
num_attention_heads: Union[int, Tuple[int]] = 4,
- block_out_channels: Tuple[int] = (4, 8, 16, 16),
- base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ block_out_channels: Tuple[int, ...] = (4, 8, 16, 16),
+ base_block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
cross_attention_dim: int = 1024,
- down_block_types: Tuple[str] = (
+ down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
@@ -437,7 +436,7 @@ def from_unet(
time_embedding_mix: int = 1.0,
conditioning_channels: int = 3,
conditioning_channel_order: str = "rgb",
- conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
+ conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
):
r"""
Instantiate a [`ControlNetXSAdapter`] from a [`UNet2DConditionModel`].
@@ -509,7 +508,7 @@ def forward(self, *args, **kwargs):
)
-class UNetControlNetXSModel(ModelMixin, ConfigMixin):
+class UNetControlNetXSModel(ModelMixin, AttentionMixin, ConfigMixin):
r"""
A UNet fused with a ControlNet-XS adapter model
@@ -530,14 +529,19 @@ def __init__(
self,
# unet configs
sample_size: Optional[int] = 96,
- down_block_types: Tuple[str] = (
+ down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
- up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ up_block_types: Tuple[str, ...] = (
+ "UpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D",
+ ),
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
norm_num_groups: Optional[int] = 32,
cross_attention_dim: Union[int, Tuple[int]] = 1024,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
@@ -551,10 +555,10 @@ def __init__(
# additional controlnet configs
time_embedding_mix: float = 1.0,
ctrl_conditioning_channels: int = 3,
- ctrl_conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
+ ctrl_conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
ctrl_conditioning_channel_order: str = "rgb",
ctrl_learn_time_embedding: bool = False,
- ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16),
+ ctrl_block_out_channels: Tuple[int, ...] = (4, 8, 16, 16),
ctrl_num_attention_heads: Union[int, Tuple[int]] = 4,
ctrl_max_norm_num_groups: int = 32,
):
@@ -734,17 +738,17 @@ def from_unet(
unet (`UNet2DConditionModel`):
The UNet model we want to control.
controlnet (`ControlNetXSAdapter`):
- The ConntrolNet-XS adapter with which the UNet will be fused. If none is given, a new ConntrolNet-XS
+ The ControlNet-XS adapter with which the UNet will be fused. If none is given, a new ControlNet-XS
adapter will be created.
size_ratio (float, *optional*, defaults to `None`):
- Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
+ Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
ctrl_block_out_channels (`List[int]`, *optional*, defaults to `None`):
- Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details,
+ Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details,
where this parameter is called `block_out_channels`.
time_embedding_mix (`float`, *optional*, defaults to None):
- Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
+ Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
ctrl_optional_kwargs (`Dict`, *optional*, defaults to `None`):
- Passed to the `init` of the new controlent if no controlent was given.
+ Passed to the `init` of the new controlnet if no controlnet was given.
"""
if controlnet is None:
controlnet = ControlNetXSAdapter.from_unet(
@@ -864,66 +868,6 @@ def freeze_unet_params(self) -> None:
for u in self.up_blocks:
u.freeze_base_params()
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
@@ -942,7 +886,7 @@ def set_default_attn_processor(self):
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
- r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+ r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497.
The suffixes after the scaling factors represent the stage blocks where they are being applied.
@@ -980,11 +924,7 @@ def fuse_qkv_projections(self):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -1004,11 +944,7 @@ def fuse_qkv_projections(self):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/controlnets/multicontrolnet.py b/src/diffusers/models/controlnets/multicontrolnet.py
index 44bfcc1b82a9..87a952294997 100644
--- a/src/diffusers/models/controlnets/multicontrolnet.py
+++ b/src/diffusers/models/controlnets/multicontrolnet.py
@@ -4,9 +4,9 @@
import torch
from torch import nn
-from ...models.controlnets.controlnet import ControlNetModel, ControlNetOutput
-from ...models.modeling_utils import ModelMixin
from ...utils import logging
+from ..controlnets.controlnet import ControlNetModel, ControlNetOutput
+from ..modeling_utils import ModelMixin
logger = logging.get_logger(__name__)
@@ -130,9 +130,8 @@ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]
A path to a *directory* containing model weights saved using
[`~models.controlnets.multicontrolnet.MultiControlNetModel.save_pretrained`], e.g.,
`./my_model_directory/controlnet`.
- torch_dtype (`str` or `torch.dtype`, *optional*):
- Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
- will be automatically derived from the model's weights.
+ torch_dtype (`torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model under this dtype.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
diff --git a/src/diffusers/models/controlnets/multicontrolnet_union.py b/src/diffusers/models/controlnets/multicontrolnet_union.py
index 427e05b19110..d5506dc186e3 100644
--- a/src/diffusers/models/controlnets/multicontrolnet_union.py
+++ b/src/diffusers/models/controlnets/multicontrolnet_union.py
@@ -4,10 +4,10 @@
import torch
from torch import nn
-from ...models.controlnets.controlnet import ControlNetOutput
-from ...models.controlnets.controlnet_union import ControlNetUnionModel
-from ...models.modeling_utils import ModelMixin
from ...utils import logging
+from ..controlnets.controlnet import ControlNetOutput
+from ..controlnets.controlnet_union import ControlNetUnionModel
+from ..modeling_utils import ModelMixin
logger = logging.get_logger(__name__)
@@ -143,9 +143,8 @@ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]
A path to a *directory* containing model weights saved using
[`~models.controlnets.multicontrolnet.MultiControlNetUnionModel.save_pretrained`], e.g.,
`./my_model_directory/controlnet`.
- torch_dtype (`str` or `torch.dtype`, *optional*):
- Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
- will be automatically derived from the model's weights.
+ torch_dtype (`torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model under this dtype.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
diff --git a/src/diffusers/models/downsampling.py b/src/diffusers/models/downsampling.py
index 3ac8953e3dcc..505816422b2a 100644
--- a/src/diffusers/models/downsampling.py
+++ b/src/diffusers/models/downsampling.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -286,7 +286,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
class CogVideoXDownsample3D(nn.Module):
- # Todo: Wait for paper relase.
+ # Todo: Wait for paper release.
r"""
A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index b1e14ca6a7fe..37fc412adcc3 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,7 +31,7 @@ def get_timestep_embedding(
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
-):
+) -> torch.Tensor:
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
@@ -97,7 +97,7 @@ def get_3d_sincos_pos_embed(
The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
spatial dimensions (height and width).
temporal_size (`int`):
- The temporal dimension of postional embeddings (number of frames).
+ The temporal dimension of positional embeddings (number of frames).
spatial_interpolation_scale (`float`, defaults to 1.0):
Scale factor for spatial grid interpolation.
temporal_interpolation_scale (`float`, defaults to 1.0):
@@ -169,7 +169,7 @@ def _get_3d_sincos_pos_embed_np(
The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
spatial dimensions (height and width).
temporal_size (`int`):
- The temporal dimension of postional embeddings (number of frames).
+ The temporal dimension of positional embeddings (number of frames).
spatial_interpolation_scale (`float`, defaults to 1.0):
Scale factor for spatial grid interpolation.
temporal_interpolation_scale (`float`, defaults to 1.0):
@@ -319,13 +319,17 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
return emb
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False, dtype=None):
"""
This function generates 1D positional embeddings from a grid.
Args:
embed_dim (`int`): The embedding dimension `D`
pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
+ output_type (`str`, *optional*, defaults to `"np"`): Output type. Use `"pt"` for PyTorch tensors.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip sine and cosine embeddings.
+ dtype (`torch.dtype`, *optional*): Data type for frequency calculations. If `None`, defaults to
+ `torch.float32` on MPS devices (which don't support `torch.float64`) and `torch.float64` on other devices.
Returns:
`torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
@@ -341,7 +345,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
- omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
+ # Auto-detect appropriate dtype if not specified
+ if dtype is None:
+ dtype = torch.float32 if pos.device.type == "mps" else torch.float64
+
+ omega = torch.arange(embed_dim // 2, device=pos.device, dtype=dtype)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
@@ -352,6 +360,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
emb_cos = torch.cos(out) # (M, D/2)
emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D)
+
+ # flip sine and cosine embeddings
+ if flip_sin_to_cos:
+ emb = torch.cat([emb[:, embed_dim // 2 :], emb[:, : embed_dim // 2]], dim=1)
+
return emb
@@ -1149,9 +1162,7 @@ def get_1d_rotary_pos_embed(
theta = theta * ntk_factor
freqs = (
- 1.0
- / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
- / linear_factor
+ 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor
) # [D/2]
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
is_npu = freqs.device.type == "npu"
@@ -1178,6 +1189,7 @@ def apply_rotary_emb(
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
use_real: bool = True,
use_real_unbind_dim: int = -1,
+ sequence_dim: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
@@ -1195,17 +1207,24 @@ def apply_rotary_emb(
"""
if use_real:
cos, sin = freqs_cis # [S, D]
- cos = cos[None, None]
- sin = sin[None, None]
+ if sequence_dim == 2:
+ cos = cos[None, None, :, :]
+ sin = sin[None, None, :, :]
+ elif sequence_dim == 1:
+ cos = cos[None, :, None, :]
+ sin = sin[None, :, None, :]
+ else:
+ raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
+
cos, sin = cos.to(x.device), sin.to(x.device)
if use_real_unbind_dim == -1:
# Used for flux, cogvideox, hunyuan-dit
- x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
- # Used for Stable Audio, OmniGen and CogView4
- x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
+ # Used for Stable Audio, OmniGen, CogView4 and Cosmos
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
@@ -1240,37 +1259,6 @@ def apply_1d_rope(tokens, pos, cos, sin):
return x
-class FluxPosEmbed(nn.Module):
- # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
- def __init__(self, theta: int, axes_dim: List[int]):
- super().__init__()
- self.theta = theta
- self.axes_dim = axes_dim
-
- def forward(self, ids: torch.Tensor) -> torch.Tensor:
- n_axes = ids.shape[-1]
- cos_out = []
- sin_out = []
- pos = ids.float()
- is_mps = ids.device.type == "mps"
- is_npu = ids.device.type == "npu"
- freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
- for i in range(n_axes):
- cos, sin = get_1d_rotary_pos_embed(
- self.axes_dim[i],
- pos[:, i],
- theta=self.theta,
- repeat_interleave_real=True,
- use_real=True,
- freqs_dtype=freqs_dtype,
- )
- cos_out.append(cos)
- sin_out.append(sin)
- freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
- freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
- return freqs_cos, freqs_sin
-
-
class TimestepEmbedding(nn.Module):
def __init__(
self,
@@ -1327,7 +1315,7 @@ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shif
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
- def forward(self, timesteps):
+ def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
@@ -1401,7 +1389,7 @@ class ImagePositionalEmbeddings(nn.Module):
Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
height and width of the latent space.
- For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
+ For more details, see figure 10 of the dall-e paper: https://huggingface.co/papers/2102.12092
For VQ-diffusion:
@@ -2621,3 +2609,13 @@ def forward(self, image_embeds: List[torch.Tensor]):
projected_image_embeds.append(image_embed)
return projected_image_embeds
+
+
+class FluxPosEmbed(nn.Module):
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. Please import it from `diffusers.models.transformers.transformer_flux`."
+ deprecate("FluxPosEmbed", "1.0.0", deprecation_message)
+
+ from .transformers.transformer_flux import FluxPosEmbed
+
+ return FluxPosEmbed(*args, **kwargs)
diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py
index 92b5a6c35883..3790905e583c 100644
--- a/src/diffusers/models/embeddings_flax.py
+++ b/src/diffusers/models/embeddings_flax.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,6 +16,11 @@
import flax.linen as nn
import jax.numpy as jnp
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
def get_sinusoidal_embeddings(
timesteps: jnp.ndarray,
@@ -76,6 +81,11 @@ class FlaxTimestepEmbedding(nn.Module):
The data type for the embedding parameters.
"""
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
time_embed_dim: int = 32
dtype: jnp.dtype = jnp.float32
@@ -89,7 +99,7 @@ def __call__(self, temb):
class FlaxTimesteps(nn.Module):
r"""
- Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239
+ Wrapper Module for sinusoidal Time step Embeddings as described in https://huggingface.co/papers/2006.11239
Args:
dim (`int`, *optional*, defaults to `32`):
@@ -104,6 +114,11 @@ class FlaxTimesteps(nn.Module):
flip_sin_to_cos: bool = False
freq_shift: float = 1
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
@nn.compact
def __call__(self, timesteps):
return get_sinusoidal_embeddings(
diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py
index 4e9e0c07ca75..85d61d6d7cdf 100644
--- a/src/diffusers/models/lora.py
+++ b/src/diffusers/models/lora.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -38,7 +38,7 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-def text_encoder_attn_modules(text_encoder):
+def text_encoder_attn_modules(text_encoder: nn.Module):
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
@@ -52,7 +52,7 @@ def text_encoder_attn_modules(text_encoder):
return attn_modules
-def text_encoder_mlp_modules(text_encoder):
+def text_encoder_mlp_modules(text_encoder: nn.Module):
mlp_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index 741f7075d76d..8b48ba6b4873 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -14,11 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import functools
import importlib
import inspect
import os
from array import array
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, List, Optional, Union
from zipfile import is_zipfile
@@ -30,6 +32,7 @@
from ..quantizers import DiffusersQuantizer
from ..utils import (
+ DEFAULT_HF_PARALLEL_LOADING_WORKERS,
GGUF_FILE_EXTENSION,
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_FILE_EXTENSION,
@@ -38,6 +41,7 @@
_get_model_file,
deprecate,
is_accelerate_available,
+ is_accelerate_version,
is_gguf_available,
is_torch_available,
is_torch_version,
@@ -108,9 +112,6 @@ def _determine_device_map(
device_map_kwargs["max_memory"] = max_memory
device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
- if hf_quantizer is not None:
- hf_quantizer.validate_environment(device_map=device_map)
-
return device_map
@@ -205,7 +206,7 @@ def load_state_dict(
) from e
except (UnicodeDecodeError, ValueError):
raise OSError(
- f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
+ f"Unable to load weights from checkpoint file for '{checkpoint_file}' at '{checkpoint_file}'. "
)
@@ -252,6 +253,10 @@ def load_model_dict_into_meta(
param = param.to(dtype)
set_module_kwargs["dtype"] = dtype
+ if is_accelerate_version(">", "1.8.1"):
+ set_module_kwargs["non_blocking"] = True
+ set_module_kwargs["clear_cache"] = False
+
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
@@ -304,6 +309,161 @@ def load_model_dict_into_meta(
return offload_index, state_dict_index
+def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
+ """
+ Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
+ checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
+ parameters.
+
+ """
+ if model_to_load.device.type == "meta":
+ return False
+
+ if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
+ return False
+
+ # Some models explicitly do not support param buffer assignment
+ if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
+ logger.debug(
+ f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
+ )
+ return False
+
+ # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
+ first_key = next(iter(model_to_load.state_dict().keys()))
+ if start_prefix + first_key in state_dict:
+ return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
+
+ return False
+
+
+def _load_shard_file(
+ shard_file,
+ model,
+ model_state_dict,
+ device_map=None,
+ dtype=None,
+ hf_quantizer=None,
+ keep_in_fp32_modules=None,
+ dduf_entries=None,
+ loaded_keys=None,
+ unexpected_keys=None,
+ offload_index=None,
+ offload_folder=None,
+ state_dict_index=None,
+ state_dict_folder=None,
+ ignore_mismatched_sizes=False,
+ low_cpu_mem_usage=False,
+):
+ state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
+ mismatched_keys = _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ loaded_keys,
+ ignore_mismatched_sizes,
+ )
+ error_msgs = []
+ if low_cpu_mem_usage:
+ offload_index, state_dict_index = load_model_dict_into_meta(
+ model,
+ state_dict,
+ device_map=device_map,
+ dtype=dtype,
+ hf_quantizer=hf_quantizer,
+ keep_in_fp32_modules=keep_in_fp32_modules,
+ unexpected_keys=unexpected_keys,
+ offload_folder=offload_folder,
+ offload_index=offload_index,
+ state_dict_index=state_dict_index,
+ state_dict_folder=state_dict_folder,
+ )
+ else:
+ assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
+
+ error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+ return offload_index, state_dict_index, mismatched_keys, error_msgs
+
+
+def _load_shard_files_with_threadpool(
+ shard_files,
+ model,
+ model_state_dict,
+ device_map=None,
+ dtype=None,
+ hf_quantizer=None,
+ keep_in_fp32_modules=None,
+ dduf_entries=None,
+ loaded_keys=None,
+ unexpected_keys=None,
+ offload_index=None,
+ offload_folder=None,
+ state_dict_index=None,
+ state_dict_folder=None,
+ ignore_mismatched_sizes=False,
+ low_cpu_mem_usage=False,
+):
+ # Do not spawn anymore workers than you need
+ num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)
+
+ logger.info(f"Loading model weights in parallel with {num_workers} workers...")
+
+ error_msgs = []
+ mismatched_keys = []
+
+ load_one = functools.partial(
+ _load_shard_file,
+ model=model,
+ model_state_dict=model_state_dict,
+ device_map=device_map,
+ dtype=dtype,
+ hf_quantizer=hf_quantizer,
+ keep_in_fp32_modules=keep_in_fp32_modules,
+ dduf_entries=dduf_entries,
+ loaded_keys=loaded_keys,
+ unexpected_keys=unexpected_keys,
+ offload_index=offload_index,
+ offload_folder=offload_folder,
+ state_dict_index=state_dict_index,
+ state_dict_folder=state_dict_folder,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ with ThreadPoolExecutor(max_workers=num_workers) as executor:
+ with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
+ futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
+ for future in as_completed(futures):
+ result = future.result()
+ offload_index, state_dict_index, _mismatched_keys, _error_msgs = result
+ error_msgs += _error_msgs
+ mismatched_keys += _mismatched_keys
+ pbar.update(1)
+
+ return offload_index, state_dict_index, mismatched_keys, error_msgs
+
+
+def _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ loaded_keys,
+ ignore_mismatched_sizes,
+):
+ mismatched_keys = []
+ if ignore_mismatched_sizes:
+ for checkpoint_key in loaded_keys:
+ model_key = checkpoint_key
+ # If the checkpoint is sharded, we may not have the key here.
+ if checkpoint_key not in state_dict:
+ continue
+
+ if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+ mismatched_keys.append(
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+ )
+ del state_dict[checkpoint_key]
+ return mismatched_keys
+
+
def _load_state_dict_into_model(
model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
) -> List[str]:
@@ -520,3 +680,72 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
return parsed_parameters
+
+
+def _find_mismatched_keys(state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes):
+ mismatched_keys = []
+ if not ignore_mismatched_sizes:
+ return mismatched_keys
+ for checkpoint_key in loaded_keys:
+ model_key = checkpoint_key
+ # If the checkpoint is sharded, we may not have the key here.
+ if checkpoint_key not in state_dict:
+ continue
+
+ if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+ mismatched_keys.append(
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+ )
+ del state_dict[checkpoint_key]
+ return mismatched_keys
+
+
+def _expand_device_map(device_map, param_names):
+ """
+ Expand a device map to return the correspondence parameter name to device.
+ """
+ new_device_map = {}
+ for module, device in device_map.items():
+ new_device_map.update(
+ {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
+ )
+ return new_device_map
+
+
+# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
+def _caching_allocator_warmup(
+ model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype, hf_quantizer: Optional[DiffusersQuantizer]
+) -> None:
+ """
+ This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
+ device. It allows to have one large call to Malloc, instead of recursively calling it later when loading the model,
+ which is actually the loading speed bottleneck. Calling this function allows to cut the model loading time by a
+ very large margin.
+ """
+ factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
+
+ # Keep only accelerator devices
+ accelerator_device_map = {
+ param: torch.device(device)
+ for param, device in expanded_device_map.items()
+ if str(device) not in ["cpu", "disk"]
+ }
+ if not accelerator_device_map:
+ return
+
+ elements_per_device = defaultdict(int)
+ for param_name, device in accelerator_device_map.items():
+ try:
+ p = model.get_parameter(param_name)
+ except AttributeError:
+ try:
+ p = model.get_buffer(param_name)
+ except AttributeError:
+ raise AttributeError(f"Parameter or buffer with name={param_name} not found in model")
+ # TODO: account for TP when needed.
+ elements_per_device[device] += p.numel()
+
+ # This will kick off the caching allocator to avoid having to Malloc afterwards
+ for device, elem_count in elements_per_device.items():
+ warmup_elems = max(1, elem_count // factor)
+ _ = torch.empty(warmup_elems, dtype=dtype, device=device, requires_grad=False)
diff --git a/src/diffusers/models/modeling_flax_utils.py b/src/diffusers/models/modeling_flax_utils.py
index 52f004f6f93f..3f060993190f 100644
--- a/src/diffusers/models/modeling_flax_utils.py
+++ b/src/diffusers/models/modeling_flax_utils.py
@@ -26,11 +26,11 @@
from huggingface_hub import create_repo, hf_hub_download
from huggingface_hub.utils import (
EntryNotFoundError,
+ HfHubHTTPError,
RepositoryNotFoundError,
RevisionNotFoundError,
validate_hf_hub_args,
)
-from requests import HTTPError
from .. import __version__, is_torch_available
from ..utils import (
@@ -113,14 +113,14 @@ def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
>>> from diffusers import FlaxUNet2DConditionModel
>>> # load model
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
>>> params = model.to_bf16(params)
>>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
>>> # then pass the mask as follows
>>> from flax import traverse_util
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> flat_params = traverse_util.flatten_dict(params)
>>> mask = {
... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
@@ -149,7 +149,7 @@ def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
>>> from diffusers import FlaxUNet2DConditionModel
>>> # Download model and configuration from huggingface.co
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> # By default, the model params will be in fp32, to illustrate the use of this method,
>>> # we'll first cast to fp16 and back to fp32
>>> params = model.to_f16(params)
@@ -179,14 +179,14 @@ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
>>> from diffusers import FlaxUNet2DConditionModel
>>> # load model
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> # By default, the model params will be in fp32, to cast these to float16
>>> params = model.to_fp16(params)
>>> # If you want don't want to cast certain parameters (for example layer norm bias and scale)
>>> # then pass the mask as follows
>>> from flax import traverse_util
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> flat_params = traverse_util.flatten_dict(params)
>>> mask = {
... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
@@ -216,8 +216,8 @@ def from_pretrained(
pretrained_model_name_or_path (`str` or `os.PathLike`):
Can be either:
- - A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model
- hosted on the Hub.
+ - A string, the *model id* (for example `stable-diffusion-v1-5/stable-diffusion-v1-5`) of a
+ pretrained model hosted on the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
using [`~FlaxModelMixin.save_pretrained`].
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
@@ -227,15 +227,9 @@ def from_pretrained(
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given `dtype`.
-
-
- This only specifies the dtype of the *computation* and does not influence the dtype of model
- parameters.
-
- If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and
- [`~FlaxModelMixin.to_bf16`].
-
-
+ > [!TIP] > This only specifies the dtype of the *computation* and does not influence the dtype of model
+ > parameters. > > If you wish to change the dtype of the model parameters, see
+ [`~FlaxModelMixin.to_fp16`] and > [`~FlaxModelMixin.to_bf16`].
model_args (sequence of positional arguments, *optional*):
All remaining positional arguments are passed to the underlying model's `__init__` method.
@@ -277,7 +271,7 @@ def from_pretrained(
>>> from diffusers import FlaxUNet2DConditionModel
>>> # Download model and configuration from huggingface.co and cache.
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
```
@@ -285,11 +279,15 @@ def from_pretrained(
If you get the error message below, you need to finetune the weights for your downstream task:
```bash
- Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
"""
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
@@ -369,8 +367,7 @@ def from_pretrained(
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
- "token having permission to this repo with `token` or log in with `huggingface-cli "
- "login`."
+ "token having permission to this repo with `token` or log in with `hf auth login`."
)
except RevisionNotFoundError:
raise EnvironmentError(
@@ -382,7 +379,7 @@ def from_pretrained(
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}."
)
- except HTTPError as err:
+ except HfHubHTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
f"{err}"
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 19ac868cdae0..41da95d3a2a2 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -15,6 +15,7 @@
# limitations under the License.
import copy
+import functools
import inspect
import itertools
import json
@@ -42,6 +43,7 @@
from ..utils import (
CONFIG_NAME,
FLAX_WEIGHTS_NAME,
+ HF_ENABLE_PARALLEL_LOADING,
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
@@ -62,12 +64,16 @@
load_or_create_model_card,
populate_model_card,
)
+from ..utils.torch_utils import empty_device_cache
+from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig
from .model_loading_utils import (
+ _caching_allocator_warmup,
_determine_device_map,
+ _expand_device_map,
_fetch_index_file,
_fetch_index_file_legacy,
- _load_state_dict_into_model,
- load_model_dict_into_meta,
+ _load_shard_file,
+ _load_shard_files_with_threadpool,
load_state_dict,
)
@@ -168,7 +174,11 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
for name, param in parameter.named_parameters():
last_dtype = param.dtype
- if parameter._keep_in_fp32_modules and any(m in name for m in parameter._keep_in_fp32_modules):
+ if (
+ hasattr(parameter, "_keep_in_fp32_modules")
+ and parameter._keep_in_fp32_modules
+ and any(m in name for m in parameter._keep_in_fp32_modules)
+ ):
continue
if param.is_floating_point():
@@ -200,34 +210,6 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
return last_tuple[1].dtype
-def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
- """
- Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
- checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
- parameters.
-
- """
- if model_to_load.device.type == "meta":
- return False
-
- if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
- return False
-
- # Some models explicitly do not support param buffer assignment
- if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
- logger.debug(
- f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
- )
- return False
-
- # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
- first_key = next(iter(model_to_load.state_dict().keys()))
- if start_prefix + first_key in state_dict:
- return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
-
- return False
-
-
@contextmanager
def no_init_weights():
"""
@@ -266,6 +248,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_keep_in_fp32_modules = None
_skip_layerwise_casting_patterns = None
_supports_group_offloading = True
+ _repeated_blocks = []
+ _parallel_config = None
+ _cp_plan = None
+ _skip_keys = None
def __init__(self):
super().__init__()
@@ -418,12 +404,8 @@ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Call
When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
inference. Speed up during training is not guaranteed.
-
-
- ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
- precedent.
-
-
+ > [!WARNING] > ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient
+ attention takes > precedent.
Parameters:
attention_op (`Callable`, *optional*):
@@ -546,7 +528,11 @@ def enable_group_offload(
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
+ record_stream: bool = False,
low_cpu_mem_usage=False,
+ offload_to_disk_path: Optional[str] = None,
+ block_modules: Optional[str] = None,
+ exclude_kwargs: Optional[str] = None,
) -> None:
r"""
Activates group offloading for the current model.
@@ -586,17 +572,82 @@ def enable_group_offload(
f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please "
f"open an issue at https://github.com/huggingface/diffusers/issues."
)
+
apply_group_offloading(
- self,
- onload_device,
- offload_device,
- offload_type,
- num_blocks_per_group,
- non_blocking,
- use_stream,
+ module=self,
+ onload_device=onload_device,
+ offload_device=offload_device,
+ offload_type=offload_type,
+ num_blocks_per_group=num_blocks_per_group,
+ non_blocking=non_blocking,
+ use_stream=use_stream,
+ record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
+ offload_to_disk_path=offload_to_disk_path,
+ block_modules=block_modules,
+ exclude_kwargs=exclude_kwargs,
+ )
+
+ def set_attention_backend(self, backend: str) -> None:
+ """
+ Set the attention backend for the model.
+
+ Args:
+ backend (`str`):
+ The name of the backend to set. Must be one of the available backends defined in
+ `AttentionBackendName`. Available backends can be found in
+ `diffusers.attention_dispatch.AttentionBackendName`. Defaults to torch native scaled dot product
+ attention as backend.
+ """
+ from .attention import AttentionModuleMixin
+ from .attention_dispatch import (
+ AttentionBackendName,
+ _check_attention_backend_requirements,
+ _maybe_download_kernel_for_backend,
)
+ # TODO: the following will not be required when everything is refactored to AttentionModuleMixin
+ from .attention_processor import Attention, MochiAttention
+
+ logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
+
+ backend = backend.lower()
+ available_backends = {x.value for x in AttentionBackendName.__members__.values()}
+ if backend not in available_backends:
+ raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
+
+ backend = AttentionBackendName(backend)
+ _check_attention_backend_requirements(backend)
+ _maybe_download_kernel_for_backend(backend)
+
+ attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+ for module in self.modules():
+ if not isinstance(module, attention_classes):
+ continue
+ processor = module.processor
+ if processor is None or not hasattr(processor, "_attention_backend"):
+ continue
+ processor._attention_backend = backend
+
+ def reset_attention_backend(self) -> None:
+ """
+ Resets the attention backend for the model. Following calls to `forward` will use the environment default, if
+ set, or the torch native scaled dot product attention.
+ """
+ from .attention import AttentionModuleMixin
+ from .attention_processor import Attention, MochiAttention
+
+ logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
+
+ attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+ for module in self.modules():
+ if not isinstance(module, attention_classes):
+ continue
+ processor = module.processor
+ if processor is None or not hasattr(processor, "_attention_backend"):
+ continue
+ processor._attention_backend = None
+
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
@@ -785,9 +836,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
- torch_dtype (`str` or `torch.dtype`, *optional*):
- Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
- dtype is automatically derived from the model's weights.
+ torch_dtype (`torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@@ -813,14 +863,43 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
information.
- device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+ device_map (`Union[int, str, torch.device]` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be defined for each
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
same device. Defaults to `None`, meaning that the model will be loaded on CPU.
+ Examples:
+
+ ```py
+ >>> from diffusers import AutoModel
+ >>> import torch
+
+ >>> # This works.
+ >>> model = AutoModel.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="cuda"
+ ... )
+ >>> # This also works (integer accelerator device ID).
+ >>> model = AutoModel.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map=0
+ ... )
+ >>> # Specifying a supported offloading strategy like "auto" also works.
+ >>> model = AutoModel.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="auto"
+ ... )
+ >>> # Specifying a dictionary as `device_map` also works.
+ >>> model = AutoModel.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0",
+ ... subfolder="unet",
+ ... device_map={"": torch.device("cuda")},
+ ... )
+ ```
+
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
more information about each option see [designing a device
- map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ map](https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap). You
+ can also refer to the [Diffusers-specific
+ documentation](https://huggingface.co/docs/diffusers/main/en/training/distributed_inference#model-sharding)
+ for more concrete examples.
max_memory (`Dict`, *optional*):
A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
each GPU and the available CPU RAM if unset.
@@ -846,27 +925,23 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
-
-
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
- `huggingface-cli login`. You can also activate the special
- ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
+ > [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
+ with `hf > auth login`. You can also activate the special >
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a >
firewalled environment.
-
-
Example:
```py
from diffusers import UNet2DConditionModel
- unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+ unet = UNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
```
If you get the error message below, you need to finetune the weights for your downstream task:
```bash
- Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
@@ -892,6 +967,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
quantization_config = kwargs.pop("quantization_config", None)
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
disable_mmap = kwargs.pop("disable_mmap", False)
+ parallel_config: Optional[Union[ParallelConfig, ContextParallelConfig]] = kwargs.pop("parallel_config", None)
+
+ is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
+ if is_parallel_loading_enabled and not low_cpu_mem_usage:
+ raise NotImplementedError("Parallel loading is not supported when not using `low_cpu_mem_usage`.")
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
@@ -1228,6 +1308,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
dduf_entries=dduf_entries,
+ is_parallel_loading_enabled=is_parallel_loading_enabled,
)
loading_info = {
"missing_keys": missing_keys,
@@ -1267,6 +1348,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
+ if parallel_config is not None:
+ model.enable_parallelism(config=parallel_config)
+
if output_loading_info:
return model, loading_info
@@ -1372,6 +1456,126 @@ def float(self, *args):
else:
return super().float(*args)
+ def compile_repeated_blocks(self, *args, **kwargs):
+ """
+ Compiles *only* the frequently repeated sub-modules of a model (e.g. the Transformer layers) instead of
+ compiling the entire model. This technique—often called **regional compilation** (see the PyTorch recipe
+ https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) can reduce end-to-end compile time
+ substantially, while preserving the runtime speed-ups you would expect from a full `torch.compile`.
+
+ The set of sub-modules to compile is discovered by the presence of **`_repeated_blocks`** attribute in the
+ model definition. Define this attribute on your model subclass as a list/tuple of class names (strings). Every
+ module whose class name matches will be compiled.
+
+ Once discovered, each matching sub-module is compiled by calling `submodule.compile(*args, **kwargs)`. Any
+ positional or keyword arguments you supply to `compile_repeated_blocks` are forwarded verbatim to
+ `torch.compile`.
+ """
+ repeated_blocks = getattr(self, "_repeated_blocks", None)
+
+ if not repeated_blocks:
+ raise ValueError(
+ "`_repeated_blocks` attribute is empty. "
+ f"Set `_repeated_blocks` for the class `{self.__class__.__name__}` to benefit from faster compilation. "
+ )
+ has_compiled_region = False
+ for submod in self.modules():
+ if submod.__class__.__name__ in repeated_blocks:
+ submod.compile(*args, **kwargs)
+ has_compiled_region = True
+
+ if not has_compiled_region:
+ raise ValueError(
+ f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
+ )
+
+ def enable_parallelism(
+ self,
+ *,
+ config: Union[ParallelConfig, ContextParallelConfig],
+ cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
+ ):
+ logger.warning(
+ "`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
+ )
+
+ if not torch.distributed.is_available() and not torch.distributed.is_initialized():
+ raise RuntimeError(
+ "torch.distributed must be available and initialized before calling `enable_parallelism`."
+ )
+
+ from ..hooks.context_parallel import apply_context_parallel
+ from .attention import AttentionModuleMixin
+ from .attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
+ from .attention_processor import Attention, MochiAttention
+
+ if isinstance(config, ContextParallelConfig):
+ config = ParallelConfig(context_parallel_config=config)
+
+ rank = torch.distributed.get_rank()
+ world_size = torch.distributed.get_world_size()
+ device_type = torch._C._get_accelerator().type
+ device_module = torch.get_device_module(device_type)
+ device = torch.device(device_type, rank % device_module.device_count())
+
+ attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+
+ if config.context_parallel_config is not None:
+ for module in self.modules():
+ if not isinstance(module, attention_classes):
+ continue
+
+ processor = module.processor
+ if processor is None or not hasattr(processor, "_attention_backend"):
+ continue
+
+ attention_backend = processor._attention_backend
+ if attention_backend is None:
+ attention_backend, _ = _AttentionBackendRegistry.get_active_backend()
+ else:
+ attention_backend = AttentionBackendName(attention_backend)
+
+ if not _AttentionBackendRegistry._is_context_parallel_available(attention_backend):
+ compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
+ raise ValueError(
+ f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
+ f"is using backend '{attention_backend.value}' which does not support context parallelism. "
+ f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
+ f"calling `enable_parallelism()`."
+ )
+
+ # All modules use the same attention processor and backend. We don't need to
+ # iterate over all modules after checking the first processor
+ break
+
+ mesh = None
+ if config.context_parallel_config is not None:
+ cp_config = config.context_parallel_config
+ mesh = torch.distributed.device_mesh.init_device_mesh(
+ device_type=device_type,
+ mesh_shape=cp_config.mesh_shape,
+ mesh_dim_names=cp_config.mesh_dim_names,
+ )
+
+ config.setup(rank, world_size, device, mesh=mesh)
+ self._parallel_config = config
+
+ for module in self.modules():
+ if not isinstance(module, attention_classes):
+ continue
+ processor = module.processor
+ if processor is None or not hasattr(processor, "_parallel_config"):
+ continue
+ processor._parallel_config = config
+
+ if config.context_parallel_config is not None:
+ if cp_plan is None and self._cp_plan is None:
+ raise ValueError(
+ "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
+ )
+ cp_plan = cp_plan if cp_plan is not None else self._cp_plan
+ apply_context_parallel(self, config.context_parallel_config, cp_plan)
+
@classmethod
def _load_pretrained_model(
cls,
@@ -1386,10 +1590,11 @@ def _load_pretrained_model(
low_cpu_mem_usage: bool = True,
dtype: Optional[Union[str, torch.dtype]] = None,
keep_in_fp32_modules: Optional[List[str]] = None,
- device_map: Dict[str, Union[int, str, torch.device]] = None,
+ device_map: Union[str, int, torch.device, Dict[str, Union[int, str, torch.device]]] = None,
offload_state_dict: Optional[bool] = None,
offload_folder: Optional[Union[str, os.PathLike]] = None,
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
+ is_parallel_loading_enabled: Optional[bool] = False,
):
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
@@ -1404,8 +1609,6 @@ def _load_pretrained_model(
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
mismatched_keys = []
-
- assign_to_params_buffers = None
error_msgs = []
# Deal with offload
@@ -1416,80 +1619,67 @@ def _load_pretrained_model(
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
" offers the weights in this format."
)
- if offload_folder is not None:
+ else:
os.makedirs(offload_folder, exist_ok=True)
if offload_state_dict is None:
offload_state_dict = True
+ # If a device map has been used, we can speedup the load time by warming up the device caching allocator.
+ # If we don't warmup, each tensor allocation on device calls to the allocator for memory (effectively, a
+ # lot of individual calls to device malloc). We can, however, preallocate the memory required by the
+ # tensors using their expected shape and not performing any initialization of the memory (empty data).
+ # When the actual device allocations happen, the allocator already has a pool of unused device memory
+ # that it can re-use for faster loading of the model.
+ if device_map is not None:
+ expanded_device_map = _expand_device_map(device_map, expected_keys)
+ _caching_allocator_warmup(model, expanded_device_map, dtype, hf_quantizer)
+
offload_index = {} if device_map is not None and "disk" in device_map.values() else None
+ state_dict_folder, state_dict_index = None, None
if offload_state_dict:
state_dict_folder = tempfile.mkdtemp()
state_dict_index = {}
- else:
- state_dict_folder = None
- state_dict_index = None
if state_dict is not None:
# load_state_dict will manage the case where we pass a dict instead of a file
# if state dict is not None, it means that we don't need to read the files from resolved_model_file also
resolved_model_file = [state_dict]
- if len(resolved_model_file) > 1:
- resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
-
- for shard_file in resolved_model_file:
- state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
-
- def _find_mismatched_keys(
- state_dict,
- model_state_dict,
- loaded_keys,
- ignore_mismatched_sizes,
- ):
- mismatched_keys = []
- if ignore_mismatched_sizes:
- for checkpoint_key in loaded_keys:
- model_key = checkpoint_key
- # If the checkpoint is sharded, we may not have the key here.
- if checkpoint_key not in state_dict:
- continue
-
- if (
- model_key in model_state_dict
- and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
- ):
- mismatched_keys.append(
- (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
- )
- del state_dict[checkpoint_key]
- return mismatched_keys
-
- mismatched_keys += _find_mismatched_keys(
- state_dict,
- model_state_dict,
- loaded_keys,
- ignore_mismatched_sizes,
- )
+ # Prepare the loading function sharing the attributes shared between them.
+ load_fn = functools.partial(
+ _load_shard_files_with_threadpool if is_parallel_loading_enabled else _load_shard_file,
+ model=model,
+ model_state_dict=model_state_dict,
+ device_map=device_map,
+ dtype=dtype,
+ hf_quantizer=hf_quantizer,
+ keep_in_fp32_modules=keep_in_fp32_modules,
+ dduf_entries=dduf_entries,
+ loaded_keys=loaded_keys,
+ unexpected_keys=unexpected_keys,
+ offload_index=offload_index,
+ offload_folder=offload_folder,
+ state_dict_index=state_dict_index,
+ state_dict_folder=state_dict_folder,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
- if low_cpu_mem_usage:
- offload_index, state_dict_index = load_model_dict_into_meta(
- model,
- state_dict,
- device_map=device_map,
- dtype=dtype,
- hf_quantizer=hf_quantizer,
- keep_in_fp32_modules=keep_in_fp32_modules,
- unexpected_keys=unexpected_keys,
- offload_folder=offload_folder,
- offload_index=offload_index,
- state_dict_index=state_dict_index,
- state_dict_folder=state_dict_folder,
- )
- else:
- if assign_to_params_buffers is None:
- assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
+ if is_parallel_loading_enabled:
+ offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(resolved_model_file)
+ error_msgs += _error_msgs
+ mismatched_keys += _mismatched_keys
+ else:
+ shard_files = resolved_model_file
+ if len(resolved_model_file) > 1:
+ shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
+
+ for shard_file in shard_files:
+ offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file)
+ error_msgs += _error_msgs
+ mismatched_keys += _mismatched_keys
- error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+ empty_device_cache()
if offload_index is not None and len(offload_index) > 0:
save_offload_index(offload_index, offload_folder)
@@ -1642,7 +1832,7 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool
```py
from diffusers import UNet2DConditionModel
- model_id = "runwayml/stable-diffusion-v1-5"
+ model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
unet.num_parameters(only_trainable=True)
859520964
@@ -1826,4 +2016,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# resolve remapping
remapped_class = _fetch_remapped_cls_from_config(config, cls)
- return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
+ if remapped_class is cls:
+ return super(LegacyModelMixin, remapped_class).from_pretrained(
+ pretrained_model_name_or_path, **kwargs_copy
+ )
+ else:
+ return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py
index 962ce435bdb7..ae2a6298f5f7 100644
--- a/src/diffusers/models/normalization.py
+++ b/src/diffusers/models/normalization.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -237,7 +237,7 @@ class AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
- As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
+ As proposed in PixArt-Alpha (see: https://huggingface.co/papers/2310.00426; Section 2.3).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
@@ -510,7 +510,7 @@ def forward(self, input):
class RMSNorm(nn.Module):
r"""
- RMS Norm as introduced in https://arxiv.org/abs/1910.07467 by Zhang et al.
+ RMS Norm as introduced in https://huggingface.co/papers/1910.07467 by Zhang et al.
Args:
dim (`int`): Number of dimensions to use for `weights`. Only effective when `elementwise_affine` is True.
@@ -600,7 +600,7 @@ def forward(self, hidden_states):
class GlobalResponseNorm(nn.Module):
r"""
- Global response normalization as introduced in ConvNeXt-v2 (https://arxiv.org/abs/2301.00808).
+ Global response normalization as introduced in ConvNeXt-v2 (https://huggingface.co/papers/2301.00808).
Args:
dim (`int`): Number of dimensions to use for the `gamma` and `beta`.
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 260b4b8929b0..c0b4ad40055a 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -1,5 +1,5 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
-# `TemporalConvLayer` Copyright 2024 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+# `TemporalConvLayer` Copyright 2025 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/models/resnet_flax.py b/src/diffusers/models/resnet_flax.py
index f8bb4788d9a5..9bedaa9a36b6 100644
--- a/src/diffusers/models/resnet_flax.py
+++ b/src/diffusers/models/resnet_flax.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,12 +15,22 @@
import jax
import jax.numpy as jnp
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
class FlaxUpsample2D(nn.Module):
out_channels: int
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
self.conv = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
@@ -45,6 +55,11 @@ class FlaxDownsample2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
self.conv = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
@@ -68,6 +83,11 @@ class FlaxResnetBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
out_channels = self.in_channels if self.out_channels is None else self.out_channels
self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index 5392935da02b..a42f6b2716e1 100755
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -17,15 +17,34 @@
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
from .transformer_allegro import AllegroTransformer3DModel
+ from .transformer_bria import BriaTransformer2DModel
+ from .transformer_bria_fibo import BriaFiboTransformer2DModel
+ from .transformer_chroma import ChromaTransformer2DModel
+ from .transformer_chronoedit import ChronoEditTransformer3DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_cogview4 import CogView4Transformer2DModel
+ from .transformer_cosmos import CosmosTransformer3DModel
from .transformer_easyanimate import EasyAnimateTransformer3DModel
from .transformer_flux import FluxTransformer2DModel
+ from .transformer_flux2 import Flux2Transformer2DModel
+ from .transformer_hidream_image import HiDreamImageTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
+ from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel
+ from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
+ from .transformer_hunyuanimage import HunyuanImageTransformer2DModel
+ from .transformer_kandinsky import Kandinsky5Transformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_omnigen import OmniGenTransformer2DModel
+ from .transformer_ovis_image import OvisImageTransformer2DModel
+ from .transformer_prx import PRXTransformer2DModel
+ from .transformer_qwenimage import QwenImageTransformer2DModel
+ from .transformer_sana_video import SanaVideoTransformer3DModel
from .transformer_sd3 import SD3Transformer2DModel
+ from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
from .transformer_temporal import TransformerTemporalModel
from .transformer_wan import WanTransformer3DModel
+ from .transformer_wan_animate import WanAnimateTransformer3DModel
+ from .transformer_wan_vace import WanVACETransformer3DModel
+ from .transformer_z_image import ZImageTransformer2DModel
diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py
index 4938ed23c506..e3732662e408 100644
--- a/src/diffusers/models/transformers/auraflow_transformer_2d.py
+++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -1,4 +1,4 @@
-# Copyright 2024 AuraFlow Authors, The HuggingFace Team. All rights reserved.
+# Copyright 2025 AuraFlow Authors, The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,19 +13,19 @@
# limitations under the License.
-from typing import Dict, Union
+from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FromOriginalModelMixin
-from ...utils import logging
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import AttentionMixin
from ..attention_processor import (
Attention,
- AttentionProcessor,
AuraFlowAttnProcessor2_0,
FusedAuraFlowAttnProcessor2_0,
)
@@ -74,17 +74,25 @@ def pe_selection_index_based_on_dim(self, h, w):
# PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected
# because original input are in flattened format, we have to flatten this 2d grid as well.
h_p, w_p = h // self.patch_size, w // self.patch_size
- original_pe_indexes = torch.arange(self.pos_embed.shape[1])
h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
- original_pe_indexes = original_pe_indexes.view(h_max, w_max)
+
+ # Calculate the top-left corner indices for the centered patch grid
starth = h_max // 2 - h_p // 2
- endh = starth + h_p
startw = w_max // 2 - w_p // 2
- endw = startw + w_p
- original_pe_indexes = original_pe_indexes[starth:endh, startw:endw]
- return original_pe_indexes.flatten()
- def forward(self, latent):
+ # Generate the row and column indices for the desired patch grid
+ rows = torch.arange(starth, starth + h_p, device=self.pos_embed.device)
+ cols = torch.arange(startw, startw + w_p, device=self.pos_embed.device)
+
+ # Create a 2D grid of indices
+ row_indices, col_indices = torch.meshgrid(rows, cols, indexing="ij")
+
+ # Convert the 2D grid indices to flattened 1D indices
+ selected_indices = (row_indices * w_max + col_indices).flatten()
+
+ return selected_indices
+
+ def forward(self, latent) -> torch.Tensor:
batch_size, num_channels, height, width = latent.size()
latent = latent.view(
batch_size,
@@ -160,14 +168,20 @@ def __init__(self, dim, num_attention_heads, attention_head_dim):
self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
self.ff = AuraFlowFeedForward(dim, dim * 4)
- def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor):
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: torch.FloatTensor,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> torch.Tensor:
residual = hidden_states
+ attention_kwargs = attention_kwargs or {}
# Norm + Projection.
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
# Attention.
- attn_output = self.attn(hidden_states=norm_hidden_states)
+ attn_output = self.attn(hidden_states=norm_hidden_states, **attention_kwargs)
# Process attention outputs for the `hidden_states`.
hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
@@ -223,10 +237,15 @@ def __init__(self, dim, num_attention_heads, attention_head_dim):
self.ff_context = AuraFlowFeedForward(dim, dim * 4)
def forward(
- self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
- ):
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor,
+ temb: torch.FloatTensor,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
residual = hidden_states
residual_context = encoder_hidden_states
+ attention_kwargs = attention_kwargs or {}
# Norm + Projection.
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
@@ -236,7 +255,9 @@ def forward(
# Attention.
attn_output, context_attn_output = self.attn(
- hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ **attention_kwargs,
)
# Process attention outputs for the `hidden_states`.
@@ -254,7 +275,7 @@ def forward(
return encoder_hidden_states, hidden_states
-class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AuraFlowTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
r"""
A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
@@ -262,17 +283,17 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
sample_size (`int`): The width of the latent images. This is fixed during training since
it is used to learn a number of position embeddings.
patch_size (`int`): Patch size to turn the input data into small patches.
- in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input.
num_mmdit_layers (`int`, *optional*, defaults to 4): The number of layers of MMDiT Transformer blocks to use.
- num_single_dit_layers (`int`, *optional*, defaults to 4):
+ num_single_dit_layers (`int`, *optional*, defaults to 32):
The number of layers of Transformer blocks to use. These blocks use concatenated image and text
representations.
- attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
- num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 256): The number of channels in each head.
+ num_attention_heads (`int`, *optional*, defaults to 12): The number of heads to use for multi-head attention.
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
- out_channels (`int`, defaults to 16): Number of output channels.
- pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
+ out_channels (`int`, defaults to 4): Number of output channels.
+ pos_embed_max_size (`int`, defaults to 1024): Maximum positions to embed from the image latents.
"""
_no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
@@ -338,83 +359,19 @@ def __init__(
self.norm_out = AuraFlowPreFinalBlock(self.inner_dim, self.inner_dim)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
- # https://arxiv.org/abs/2309.16588
+ # https://huggingface.co/papers/2309.16588
# prevents artifacts in the attention maps
self.register_tokens = nn.Parameter(torch.randn(1, 8, self.inner_dim) * 0.02)
self.gradient_checkpointing = False
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedAuraFlowAttnProcessor2_0
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -434,11 +391,7 @@ def fuse_qkv_projections(self):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
@@ -449,8 +402,24 @@ def forward(
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
timestep: torch.LongTensor = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
height, width = hidden_states.shape[-2:]
# Apply patch embedding, timestep embedding, and project the caption embeddings.
@@ -474,7 +443,10 @@ def forward(
else:
encoder_hidden_states, hidden_states = block(
- hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ attention_kwargs=attention_kwargs,
)
# Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
@@ -491,7 +463,9 @@ def forward(
)
else:
- combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb)
+ combined_hidden_states = block(
+ hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs
+ )
hidden_states = combined_hidden_states[:, encoder_seq_len:]
@@ -512,6 +486,10 @@ def forward(
shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
)
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
if not return_dict:
return (output,)
diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
index 6b4f38dc04a1..14b38cd46c52 100644
--- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py
+++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -22,8 +22,8 @@
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import Attention, FeedForward
-from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
+from ..attention import Attention, AttentionMixin, FeedForward
+from ..attention_processor import CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
from ..cache_utils import CacheMixin
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
@@ -122,7 +122,7 @@ def forward(
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_length = encoder_hidden_states.size(1)
attention_kwargs = attention_kwargs or {}
@@ -157,7 +157,7 @@ def forward(
return hidden_states, encoder_hidden_states
-class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
+class CogVideoXTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
"""
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
@@ -331,77 +331,13 @@ def __init__(
self.gradient_checkpointing = False
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -421,11 +357,7 @@ def fuse_qkv_projections(self):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
@@ -441,7 +373,7 @@ def forward(
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
- ):
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
diff --git a/src/diffusers/models/transformers/consisid_transformer_3d.py b/src/diffusers/models/transformers/consisid_transformer_3d.py
index f312553e4c05..be20b0a3eacf 100644
--- a/src/diffusers/models/transformers/consisid_transformer_3d.py
+++ b/src/diffusers/models/transformers/consisid_transformer_3d.py
@@ -1,4 +1,4 @@
-# Copyright 2024 ConsisID Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 ConsisID Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,8 +22,8 @@
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import Attention, FeedForward
-from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0
+from ..attention import Attention, AttentionMixin, FeedForward
+from ..attention_processor import CogVideoXAttnProcessor2_0
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -315,7 +315,7 @@ def forward(
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
@@ -348,7 +348,7 @@ def forward(
return hidden_states, encoder_hidden_states
-class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class ConsisIDTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin):
"""
A Transformer model for video-like data in [ConsisID](https://github.com/PKU-YuanGroup/ConsisID).
@@ -620,66 +620,6 @@ def _init_face_inputs(self):
]
)
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -691,7 +631,7 @@ def forward(
id_cond: Optional[torch.Tensor] = None,
id_vit_hidden: Optional[torch.Tensor] = None,
return_dict: bool = True,
- ):
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
diff --git a/src/diffusers/models/transformers/dit_transformer_2d.py b/src/diffusers/models/transformers/dit_transformer_2d.py
index cdc0738050e4..68f6f769436e 100644
--- a/src/diffusers/models/transformers/dit_transformer_2d.py
+++ b/src/diffusers/models/transformers/dit_transformer_2d.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -30,7 +30,7 @@
class DiTTransformer2DModel(ModelMixin, ConfigMixin):
r"""
- A 2D Transformer model as introduced in DiT (https://arxiv.org/abs/2212.09748).
+ A 2D Transformer model as introduced in DiT (https://huggingface.co/papers/2212.09748).
Parameters:
num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
diff --git a/src/diffusers/models/transformers/dual_transformer_2d.py b/src/diffusers/models/transformers/dual_transformer_2d.py
index 1c48c4e3db79..24eed2168229 100644
--- a/src/diffusers/models/transformers/dual_transformer_2d.py
+++ b/src/diffusers/models/transformers/dual_transformer_2d.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py
index 550cc6d9d1e5..cecb675b32b7 100644
--- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py
+++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
+# Copyright 2025 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, Optional, Union
+from typing import Optional
import torch
from torch import nn
@@ -19,8 +19,8 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import Attention, AttentionProcessor, FusedHunyuanAttnProcessor2_0, HunyuanAttnProcessor2_0
+from ..attention import AttentionMixin, FeedForward
+from ..attention_processor import Attention, FusedHunyuanAttnProcessor2_0, HunyuanAttnProcessor2_0
from ..embeddings import (
HunyuanCombinedTimestepTextSizeStyleEmbedding,
PatchEmbed,
@@ -200,7 +200,7 @@ def forward(
return hidden_states
-class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
+class HunyuanDiT2DModel(ModelMixin, AttentionMixin, ConfigMixin):
"""
HunYuanDiT: Diffusion model with a Transformer backbone.
@@ -308,7 +308,7 @@ def __init__(
activation_fn=activation_fn,
ff_inner_dim=int(self.inner_dim * mlp_ratio),
cross_attention_dim=cross_attention_dim,
- qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
+ qk_norm=True, # See https://huggingface.co/papers/2302.05442 for details.
skip=layer > num_layers // 2,
)
for layer in range(num_layers)
@@ -324,11 +324,7 @@ def fuse_qkv_projections(self):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -348,76 +344,12 @@ def fuse_qkv_projections(self):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
diff --git a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py
index 132c258455ea..990c90512e39 100644
--- a/src/diffusers/models/transformers/latte_transformer_3d.py
+++ b/src/diffusers/models/transformers/latte_transformer_3d.py
@@ -1,4 +1,4 @@
-# Copyright 2024 the Latte Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 the Latte Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,10 +18,9 @@
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
-from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
from ..attention import BasicTransformerBlock
from ..cache_utils import CacheMixin
-from ..embeddings import PatchEmbed
+from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle
@@ -31,7 +30,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
_supports_gradient_checkpointing = True
"""
- A 3D Transformer model for video-like data, paper: https://arxiv.org/abs/2401.03048, offical code:
+ A 3D Transformer model for video-like data, paper: https://huggingface.co/papers/2401.03048, official code:
https://github.com/Vchitect/Latte
Parameters:
@@ -217,7 +216,7 @@ def forward(
)
num_patches = height * width
- hidden_states = self.pos_embed(hidden_states) # alrady add positional embeddings
+ hidden_states = self.pos_embed(hidden_states) # already add positional embeddings
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
timestep, embedded_timestep = self.adaln_single(
diff --git a/src/diffusers/models/transformers/lumina_nextdit2d.py b/src/diffusers/models/transformers/lumina_nextdit2d.py
index 320950866c4a..bed5e69c2d36 100644
--- a/src/diffusers/models/transformers/lumina_nextdit2d.py
+++ b/src/diffusers/models/transformers/lumina_nextdit2d.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -43,7 +43,7 @@ class LuminaNextDiTBlock(nn.Module):
num_kv_heads (`int`):
Number of attention heads in key and value features (if using GQA), or set to None for the same as query.
multiple_of (`int`): The number of multiple of ffn layer.
- ffn_dim_multiplier (`float`): The multipier factor of ffn layer dimension.
+ ffn_dim_multiplier (`float`): The multiplier factor of ffn layer dimension.
norm_eps (`float`): The eps for norm layer.
qk_norm (`bool`): normalization for query and key.
cross_attention_dim (`int`): Cross attention embedding dimension of the input text prompt hidden_states.
@@ -124,7 +124,7 @@ def forward(
encoder_mask: torch.Tensor,
temb: torch.Tensor,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- ):
+ ) -> torch.Tensor:
"""
Perform a forward pass through the LuminaNextDiTBlock.
@@ -297,7 +297,7 @@ def forward(
image_rotary_emb: torch.Tensor,
cross_attention_kwargs: Dict[str, Any] = None,
return_dict=True,
- ) -> torch.Tensor:
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
"""
Forward pass of LuminaNextDiT.
diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py
index 8e290074a018..072670ee0c30 100644
--- a/src/diffusers/models/transformers/pixart_transformer_2d.py
+++ b/src/diffusers/models/transformers/pixart_transformer_2d.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,15 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional
import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
-from ..attention import BasicTransformerBlock
-from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
+from ..attention import AttentionMixin, BasicTransformerBlock
+from ..attention_processor import Attention, AttnProcessor, FusedAttnProcessor2_0
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -29,10 +29,10 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
+class PixArtTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin):
r"""
- A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426,
- https://arxiv.org/abs/2403.04692).
+ A 2D Transformer model as introduced in PixArt family of models (https://huggingface.co/papers/2310.00426,
+ https://huggingface.co/papers/2403.04692).
Parameters:
num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
@@ -184,66 +184,6 @@ def __init__(
in_features=self.config.caption_channels, hidden_size=self.inner_dim
)
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
@@ -258,11 +198,7 @@ def fuse_qkv_projections(self):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -282,11 +218,7 @@ def fuse_qkv_projections(self):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/transformers/prior_transformer.py b/src/diffusers/models/transformers/prior_transformer.py
index 24d4e4d3d76f..757bb436040f 100644
--- a/src/diffusers/models/transformers/prior_transformer.py
+++ b/src/diffusers/models/transformers/prior_transformer.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
-from typing import Dict, Optional, Union
+from typing import Optional, Union
import torch
import torch.nn.functional as F
@@ -8,11 +8,10 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import BaseOutput
-from ..attention import BasicTransformerBlock
+from ..attention import AttentionMixin, BasicTransformerBlock
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
- AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
@@ -33,7 +32,7 @@ class PriorTransformerOutput(BaseOutput):
predicted_image_embedding: torch.Tensor
-class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
+class PriorTransformer(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
"""
A Prior Transformer model.
@@ -61,7 +60,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
product between the text embedding and image embedding as proposed in the unclip paper
- https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
+ https://huggingface.co/papers/2204.06125 If it is `None`, no additional embeddings will be prepended.
time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
If None, will be set to `num_attention_heads * attention_head_dim`
embedding_proj_dim (`int`, *optional*, default to None):
@@ -166,66 +165,6 @@ def __init__(
self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py
index 54e996e13d42..69e57ad6e429 100644
--- a/src/diffusers/models/transformers/sana_transformer.py
+++ b/src/diffusers/models/transformers/sana_transformer.py
@@ -21,9 +21,9 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import AttentionMixin
from ..attention_processor import (
Attention,
- AttentionProcessor,
SanaLinearAttnProcessor2_0,
)
from ..embeddings import (
@@ -338,9 +338,7 @@ def forward(
return hidden_states
-class SanaTransformer2DModel(
- ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
-):
+class SanaTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
r"""
A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.
@@ -469,72 +467,6 @@ def __init__(
self.gradient_checkpointing = False
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(
- name: str,
- module: torch.nn.Module,
- processors: Dict[str, AttentionProcessor],
- ):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(
- self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
- ):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -626,12 +558,8 @@ def forward(
post_patch_height,
post_patch_width,
)
- if controlnet_block_samples is not None and 0 < index_block <= len(
- controlnet_block_samples
- ):
- hidden_states = (
- hidden_states + controlnet_block_samples[index_block - 1]
- )
+ if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
+ hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
else:
for index_block, block in enumerate(self.transformer_blocks):
@@ -644,12 +572,8 @@ def forward(
post_patch_height,
post_patch_width,
)
- if controlnet_block_samples is not None and 0 < index_block <= len(
- controlnet_block_samples
- ):
- hidden_states = (
- hidden_states + controlnet_block_samples[index_block - 1]
- )
+ if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
+ hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
# 3. Normalization
hidden_states = self.norm_out(
diff --git a/src/diffusers/models/transformers/stable_audio_transformer.py b/src/diffusers/models/transformers/stable_audio_transformer.py
index d81b6447adb0..2c3b6b5df91d 100644
--- a/src/diffusers/models/transformers/stable_audio_transformer.py
+++ b/src/diffusers/models/transformers/stable_audio_transformer.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Stability AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,24 +13,19 @@
# limitations under the License.
-from typing import Dict, Optional, Union
+from typing import Optional, Union
import numpy as np
import torch
import torch.nn as nn
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
-from ...models.attention import FeedForward
-from ...models.attention_processor import (
- Attention,
- AttentionProcessor,
- StableAudioAttnProcessor2_0,
-)
-from ...models.modeling_utils import ModelMixin
-from ...models.transformers.transformer_2d import Transformer2DModelOutput
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import AttentionMixin, FeedForward
+from ..attention_processor import Attention, StableAudioAttnProcessor2_0
+from ..modeling_utils import ModelMixin
+from ..transformers.transformer_2d import Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -187,7 +182,7 @@ def forward(
return hidden_states
-class StableAudioDiTModel(ModelMixin, ConfigMixin):
+class StableAudioDiTModel(ModelMixin, AttentionMixin, ConfigMixin):
"""
The Diffusion Transformer model introduced in Stable Audio.
@@ -279,66 +274,6 @@ def __init__(
self.gradient_checkpointing = False
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
# Copied from diffusers.models.transformers.hunyuan_transformer_2d.HunyuanDiT2DModel.set_default_attn_processor with Hunyuan->StableAudio
def set_default_attn_processor(self):
"""
diff --git a/src/diffusers/models/transformers/t5_film_transformer.py b/src/diffusers/models/transformers/t5_film_transformer.py
index 1dea37a25910..7a9608735e32 100644
--- a/src/diffusers/models/transformers/t5_film_transformer.py
+++ b/src/diffusers/models/transformers/t5_film_transformer.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -390,7 +390,7 @@ def __init__(self, hidden_size: int, eps: float = 1e-6):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
- # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
+ # Square Layer Normalization https://huggingface.co/papers/1910.07467 thus variance is calculated
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
# half-precision inputs is done in fp32
@@ -407,7 +407,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
class NewGELUActivation(nn.Module):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
- the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+ the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py
index a88ee6c9c9b8..67fe9a33109b 100644
--- a/src/diffusers/models/transformers/transformer_2d.py
+++ b/src/diffusers/models/transformers/transformer_2d.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -211,9 +211,9 @@ def _init_continuous_input(self, norm_type):
def _init_vectorized_inputs(self, norm_type):
assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
- assert (
- self.config.num_vector_embeds is not None
- ), "Transformer2DModel over discrete input must provide num_embed"
+ assert self.config.num_vector_embeds is not None, (
+ "Transformer2DModel over discrete input must provide num_embed"
+ )
self.height = self.config.sample_size
self.width = self.config.sample_size
diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py
index d5c93409c932..5fa59a71d977 100644
--- a/src/diffusers/models/transformers/transformer_allegro.py
+++ b/src/diffusers/models/transformers/transformer_allegro.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The RhymesAI and The HuggingFace Team.
+# Copyright 2025 The RhymesAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/src/diffusers/models/transformers/transformer_bria.py b/src/diffusers/models/transformers/transformer_bria.py
new file mode 100644
index 000000000000..d54679306e64
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_bria.py
@@ -0,0 +1,725 @@
+import inspect
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
+from ..cache_utils import CacheMixin
+from ..embeddings import TimestepEmbedding, apply_rotary_emb, get_timestep_embedding
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def _get_projections(attn: "BriaAttention", hidden_states, encoder_hidden_states=None):
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ encoder_query = encoder_key = encoder_value = None
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+ return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_fused_projections(attn: "BriaAttention", hidden_states, encoder_hidden_states=None):
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+
+ encoder_query = encoder_key = encoder_value = (None,)
+ if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
+ encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
+
+ return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_qkv_projections(attn: "BriaAttention", hidden_states, encoder_hidden_states=None):
+ if attn.fused_projections:
+ return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
+ return _get_projections(attn, hidden_states, encoder_hidden_states)
+
+
+def get_1d_rotary_pos_embed(
+ dim: int,
+ pos: Union[np.ndarray, int],
+ theta: float = 10000.0,
+ use_real=False,
+ linear_factor=1.0,
+ ntk_factor=1.0,
+ repeat_interleave_real=True,
+ freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
+):
+ """
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
+ index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
+ data type.
+
+ Args:
+ dim (`int`): Dimension of the frequency tensor.
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
+ theta (`float`, *optional*, defaults to 10000.0):
+ Scaling factor for frequency computation. Defaults to 10000.0.
+ use_real (`bool`, *optional*):
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+ linear_factor (`float`, *optional*, defaults to 1.0):
+ Scaling factor for the context extrapolation. Defaults to 1.0.
+ ntk_factor (`float`, *optional*, defaults to 1.0):
+ Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
+ repeat_interleave_real (`bool`, *optional*, defaults to `True`):
+ If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
+ Otherwise, they are concateanted with themselves.
+ freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
+ the dtype of the frequency tensor.
+ Returns:
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
+ """
+ assert dim % 2 == 0
+
+ if isinstance(pos, int):
+ pos = torch.arange(pos)
+ if isinstance(pos, np.ndarray):
+ pos = torch.from_numpy(pos) # type: ignore # [S]
+
+ theta = theta * ntk_factor
+ freqs = (
+ 1.0
+ / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
+ / linear_factor
+ ) # [D/2]
+ freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
+ if use_real and repeat_interleave_real:
+ # bria
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
+ return freqs_cos, freqs_sin
+ elif use_real:
+ # stable audio, allegro
+ freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
+ freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
+ return freqs_cos, freqs_sin
+ else:
+ # lumina
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
+ return freqs_cis
+
+
+class BriaAttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
+
+ def __call__(
+ self,
+ attn: "BriaAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+ attn, hidden_states, encoder_hidden_states
+ )
+
+ query = query.unflatten(-1, (attn.heads, -1))
+ key = key.unflatten(-1, (attn.heads, -1))
+ value = value.unflatten(-1, (attn.heads, -1))
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ if attn.added_kv_proj_dim is not None:
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+ encoder_query = attn.norm_added_q(encoder_query)
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([encoder_query, query], dim=1)
+ key = torch.cat([encoder_key, key], dim=1)
+ value = torch.cat([encoder_value, value], dim=1)
+
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+ )
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class BriaAttention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = BriaAttnProcessor
+ _available_processors = [
+ BriaAttnProcessor,
+ ]
+
+ def __init__(
+ self,
+ query_dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ added_proj_bias: Optional[bool] = True,
+ out_bias: bool = True,
+ eps: float = 1e-5,
+ out_dim: int = None,
+ context_pre_only: Optional[bool] = None,
+ pre_only: bool = False,
+ elementwise_affine: bool = True,
+ processor=None,
+ ):
+ super().__init__()
+
+ self.head_dim = dim_head
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.query_dim = query_dim
+ self.use_bias = bias
+ self.dropout = dropout
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.context_pre_only = context_pre_only
+ self.pre_only = pre_only
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.added_proj_bias = added_proj_bias
+
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+ if not self.pre_only:
+ self.to_out = torch.nn.ModuleList([])
+ self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(torch.nn.Dropout(dropout))
+
+ if added_kv_proj_dim is not None:
+ self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
+ self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
+ self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
+
+ if processor is None:
+ processor = self._default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
+ unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
+class BriaEmbedND(torch.nn.Module):
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+ def __init__(self, theta: int, axes_dim: List[int]):
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ n_axes = ids.shape[-1]
+ cos_out = []
+ sin_out = []
+ pos = ids.float()
+ is_mps = ids.device.type == "mps"
+ freqs_dtype = torch.float32 if is_mps else torch.float64
+ for i in range(n_axes):
+ cos, sin = get_1d_rotary_pos_embed(
+ self.axes_dim[i],
+ pos[:, i],
+ theta=self.theta,
+ repeat_interleave_real=True,
+ use_real=True,
+ freqs_dtype=freqs_dtype,
+ )
+ cos_out.append(cos)
+ sin_out.append(sin)
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+ return freqs_cos, freqs_sin
+
+
+class BriaTimesteps(nn.Module):
+ def __init__(
+ self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
+ ):
+ super().__init__()
+ self.num_channels = num_channels
+ self.flip_sin_to_cos = flip_sin_to_cos
+ self.downscale_freq_shift = downscale_freq_shift
+ self.scale = scale
+ self.time_theta = time_theta
+
+ def forward(self, timesteps):
+ t_emb = get_timestep_embedding(
+ timesteps,
+ self.num_channels,
+ flip_sin_to_cos=self.flip_sin_to_cos,
+ downscale_freq_shift=self.downscale_freq_shift,
+ scale=self.scale,
+ max_period=self.time_theta,
+ )
+ return t_emb
+
+
+class BriaTimestepProjEmbeddings(nn.Module):
+ def __init__(self, embedding_dim, time_theta):
+ super().__init__()
+
+ self.time_proj = BriaTimesteps(
+ num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
+ )
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ def forward(self, timestep, dtype):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D)
+ return timesteps_emb
+
+
+class BriaPosEmbed(torch.nn.Module):
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+ def __init__(self, theta: int, axes_dim: List[int]):
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ n_axes = ids.shape[-1]
+ cos_out = []
+ sin_out = []
+ pos = ids.float()
+ is_mps = ids.device.type == "mps"
+ freqs_dtype = torch.float32 if is_mps else torch.float64
+ for i in range(n_axes):
+ cos, sin = get_1d_rotary_pos_embed(
+ self.axes_dim[i],
+ pos[:, i],
+ theta=self.theta,
+ repeat_interleave_real=True,
+ use_real=True,
+ freqs_dtype=freqs_dtype,
+ )
+ cos_out.append(cos)
+ sin_out.append(sin)
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+ return freqs_cos, freqs_sin
+
+
+@maybe_allow_in_graph
+class BriaTransformerBlock(nn.Module):
+ def __init__(
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+ ):
+ super().__init__()
+
+ self.norm1 = AdaLayerNormZero(dim)
+ self.norm1_context = AdaLayerNormZero(dim)
+
+ self.attn = BriaAttention(
+ query_dim=dim,
+ added_kv_proj_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ context_pre_only=False,
+ bias=True,
+ processor=BriaAttnProcessor(),
+ eps=eps,
+ )
+
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ encoder_hidden_states, emb=temb
+ )
+ attention_kwargs = attention_kwargs or {}
+
+ # Attention.
+ attention_outputs = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **attention_kwargs,
+ )
+
+ if len(attention_outputs) == 2:
+ attn_output, context_attn_output = attention_outputs
+ elif len(attention_outputs) == 3:
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
+
+ # Process attention outputs for the `hidden_states`.
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = hidden_states + attn_output
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = hidden_states + ff_output
+ if len(attention_outputs) == 3:
+ hidden_states = hidden_states + ip_attn_output
+
+ # Process attention outputs for the `encoder_hidden_states`.
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+ if encoder_hidden_states.dtype == torch.float16:
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+ return encoder_hidden_states, hidden_states
+
+
+@maybe_allow_in_graph
+class BriaSingleTransformerBlock(nn.Module):
+ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
+ super().__init__()
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
+
+ self.norm = AdaLayerNormZeroSingle(dim)
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
+ self.act_mlp = nn.GELU(approximate="tanh")
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
+
+ processor = BriaAttnProcessor()
+
+ self.attn = BriaAttention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ bias=True,
+ processor=processor,
+ eps=1e-6,
+ pre_only=True,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ text_seq_len = encoder_hidden_states.shape[1]
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ residual = hidden_states
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+ attention_kwargs = attention_kwargs or {}
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **attention_kwargs,
+ )
+
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ gate = gate.unsqueeze(1)
+ hidden_states = gate * self.proj_out(hidden_states)
+ hidden_states = residual + hidden_states
+ if hidden_states.dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
+ return encoder_hidden_states, hidden_states
+
+
+class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+ """
+ The Transformer model introduced in Flux. Based on FluxPipeline with several changes:
+ - no pooled embeddings
+ - We use zero padding for prompts
+ - No guidance embedding since this is not a distilled version
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Parameters:
+ patch_size (`int`): Patch size to turn the input data into small patches.
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 1,
+ in_channels: int = 64,
+ num_layers: int = 19,
+ num_single_layers: int = 38,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 24,
+ joint_attention_dim: int = 4096,
+ pooled_projection_dim: int = None,
+ guidance_embeds: bool = False,
+ axes_dims_rope: List[int] = [16, 56, 56],
+ rope_theta=10000,
+ time_theta=10000,
+ ):
+ super().__init__()
+ self.out_channels = in_channels
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+
+ self.pos_embed = BriaEmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
+
+ self.time_embed = BriaTimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
+ if guidance_embeds:
+ self.guidance_embed = BriaTimestepProjEmbeddings(embedding_dim=self.inner_dim)
+
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BriaTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=self.config.num_attention_heads,
+ attention_head_dim=self.config.attention_head_dim,
+ )
+ for i in range(self.config.num_layers)
+ ]
+ )
+
+ self.single_transformer_blocks = nn.ModuleList(
+ [
+ BriaSingleTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=self.config.num_attention_heads,
+ attention_head_dim=self.config.attention_head_dim,
+ )
+ for i in range(self.config.num_single_layers)
+ ]
+ )
+
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ pooled_projections: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ controlnet_block_samples=None,
+ controlnet_single_block_samples=None,
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
+ """
+ The [`BriaTransformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ from the embeddings of input conditions.
+ timestep ( `torch.LongTensor`):
+ Used to indicate denoising step.
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+ A list of tensors that if specified are added to the residuals of transformer blocks.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+ hidden_states = self.x_embedder(hidden_states)
+
+ timestep = timestep.to(hidden_states.dtype)
+ if guidance is not None:
+ guidance = guidance.to(hidden_states.dtype)
+ else:
+ guidance = None
+
+ temb = self.time_embed(timestep, dtype=hidden_states.dtype)
+
+ if guidance:
+ temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)
+
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ if len(txt_ids.shape) == 3:
+ txt_ids = txt_ids[0]
+
+ if len(img_ids.shape) == 3:
+ img_ids = img_ids[0]
+
+ ids = torch.cat((txt_ids, img_ids), dim=0)
+ image_rotary_emb = self.pos_embed(ids)
+
+ for index_block, block in enumerate(self.transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ attention_kwargs,
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # controlnet residual
+ if controlnet_block_samples is not None:
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+
+ for index_block, block in enumerate(self.single_transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ attention_kwargs,
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # controlnet residual
+ if controlnet_single_block_samples is not None:
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ + controlnet_single_block_samples[index_block // interval_control]
+ )
+
+ hidden_states = self.norm_out(hidden_states, temb)
+ output = self.proj_out(hidden_states)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_bria_fibo.py b/src/diffusers/models/transformers/transformer_bria_fibo.py
new file mode 100644
index 000000000000..09f79619320d
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_bria_fibo.py
@@ -0,0 +1,655 @@
+# Copyright (c) Bria.ai. All rights reserved.
+#
+# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
+# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
+#
+# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
+# indicate if changes were made, and do not use the material for commercial purposes.
+#
+# See the license for further details.
+import inspect
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...models.attention_processor import Attention
+from ...models.embeddings import TimestepEmbedding, apply_rotary_emb, get_1d_rotary_pos_embed, get_timestep_embedding
+from ...models.modeling_outputs import Transformer2DModelOutput
+from ...models.modeling_utils import ModelMixin
+from ...models.transformers.transformer_bria import BriaAttnProcessor
+from ...utils import (
+ USE_PEFT_BACKEND,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def _get_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ encoder_query = encoder_key = encoder_value = None
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+ return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_fused_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+
+ encoder_query = encoder_key = encoder_value = (None,)
+ if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
+ encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
+
+ return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_qkv_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
+ if attn.fused_projections:
+ return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
+ return _get_projections(attn, hidden_states, encoder_hidden_states)
+
+
+# Copied from diffusers.models.transformers.transformer_flux.FluxAttnProcessor with FluxAttnProcessor->BriaFiboAttnProcessor, FluxAttention->BriaFiboAttention
+class BriaFiboAttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
+
+ def __call__(
+ self,
+ attn: "BriaFiboAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+ attn, hidden_states, encoder_hidden_states
+ )
+
+ query = query.unflatten(-1, (attn.heads, -1))
+ key = key.unflatten(-1, (attn.heads, -1))
+ value = value.unflatten(-1, (attn.heads, -1))
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ if attn.added_kv_proj_dim is not None:
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+ encoder_query = attn.norm_added_q(encoder_query)
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([encoder_query, query], dim=1)
+ key = torch.cat([encoder_key, key], dim=1)
+ value = torch.cat([encoder_value, value], dim=1)
+
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+ )
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+# Based on https://github.com/huggingface/diffusers/blob/55d49d4379007740af20629bb61aba9546c6b053/src/diffusers/models/transformers/transformer_flux.py
+class BriaFiboAttention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = BriaFiboAttnProcessor
+ _available_processors = [BriaFiboAttnProcessor]
+
+ def __init__(
+ self,
+ query_dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ added_proj_bias: Optional[bool] = True,
+ out_bias: bool = True,
+ eps: float = 1e-5,
+ out_dim: int = None,
+ context_pre_only: Optional[bool] = None,
+ pre_only: bool = False,
+ elementwise_affine: bool = True,
+ processor=None,
+ ):
+ super().__init__()
+
+ self.head_dim = dim_head
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.query_dim = query_dim
+ self.use_bias = bias
+ self.dropout = dropout
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.context_pre_only = context_pre_only
+ self.pre_only = pre_only
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.added_proj_bias = added_proj_bias
+
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+ if not self.pre_only:
+ self.to_out = torch.nn.ModuleList([])
+ self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(torch.nn.Dropout(dropout))
+
+ if added_kv_proj_dim is not None:
+ self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
+ self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
+ self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
+
+ if processor is None:
+ processor = self._default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
+ unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
+class BriaFiboEmbedND(torch.nn.Module):
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+ def __init__(self, theta: int, axes_dim: List[int]):
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ n_axes = ids.shape[-1]
+ cos_out = []
+ sin_out = []
+ pos = ids.float()
+ is_mps = ids.device.type == "mps"
+ freqs_dtype = torch.float32 if is_mps else torch.float64
+ for i in range(n_axes):
+ cos, sin = get_1d_rotary_pos_embed(
+ self.axes_dim[i],
+ pos[:, i],
+ theta=self.theta,
+ repeat_interleave_real=True,
+ use_real=True,
+ freqs_dtype=freqs_dtype,
+ )
+ cos_out.append(cos)
+ sin_out.append(sin)
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+ return freqs_cos, freqs_sin
+
+
+@maybe_allow_in_graph
+class BriaFiboSingleTransformerBlock(nn.Module):
+ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
+ super().__init__()
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
+
+ self.norm = AdaLayerNormZeroSingle(dim)
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
+ self.act_mlp = nn.GELU(approximate="tanh")
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
+
+ processor = BriaAttnProcessor()
+
+ self.attn = Attention(
+ query_dim=dim,
+ cross_attention_dim=None,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ bias=True,
+ processor=processor,
+ qk_norm="rms_norm",
+ eps=1e-6,
+ pre_only=True,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+ joint_attention_kwargs = joint_attention_kwargs or {}
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ gate = gate.unsqueeze(1)
+ hidden_states = gate * self.proj_out(hidden_states)
+ hidden_states = residual + hidden_states
+ if hidden_states.dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ return hidden_states
+
+
+class BriaFiboTextProjection(nn.Module):
+ def __init__(self, in_features, hidden_size):
+ super().__init__()
+ self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)
+
+ def forward(self, caption):
+ hidden_states = self.linear(caption)
+ return hidden_states
+
+
+@maybe_allow_in_graph
+# Based on from diffusers.models.transformers.transformer_flux.FluxTransformerBlock
+class BriaFiboTransformerBlock(nn.Module):
+ def __init__(
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+ ):
+ super().__init__()
+
+ self.norm1 = AdaLayerNormZero(dim)
+ self.norm1_context = AdaLayerNormZero(dim)
+
+ self.attn = BriaFiboAttention(
+ query_dim=dim,
+ added_kv_proj_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ context_pre_only=False,
+ bias=True,
+ processor=BriaFiboAttnProcessor(),
+ eps=eps,
+ )
+
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ encoder_hidden_states, emb=temb
+ )
+ joint_attention_kwargs = joint_attention_kwargs or {}
+
+ # Attention.
+ attention_outputs = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+
+ if len(attention_outputs) == 2:
+ attn_output, context_attn_output = attention_outputs
+ elif len(attention_outputs) == 3:
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
+
+ # Process attention outputs for the `hidden_states`.
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = hidden_states + attn_output
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = hidden_states + ff_output
+ if len(attention_outputs) == 3:
+ hidden_states = hidden_states + ip_attn_output
+
+ # Process attention outputs for the `encoder_hidden_states`.
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+ if encoder_hidden_states.dtype == torch.float16:
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+ return encoder_hidden_states, hidden_states
+
+
+class BriaFiboTimesteps(nn.Module):
+ def __init__(
+ self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
+ ):
+ super().__init__()
+ self.num_channels = num_channels
+ self.flip_sin_to_cos = flip_sin_to_cos
+ self.downscale_freq_shift = downscale_freq_shift
+ self.scale = scale
+ self.time_theta = time_theta
+
+ def forward(self, timesteps):
+ t_emb = get_timestep_embedding(
+ timesteps,
+ self.num_channels,
+ flip_sin_to_cos=self.flip_sin_to_cos,
+ downscale_freq_shift=self.downscale_freq_shift,
+ scale=self.scale,
+ max_period=self.time_theta,
+ )
+ return t_emb
+
+
+class BriaFiboTimestepProjEmbeddings(nn.Module):
+ def __init__(self, embedding_dim, time_theta):
+ super().__init__()
+
+ self.time_proj = BriaFiboTimesteps(
+ num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
+ )
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ def forward(self, timestep, dtype):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D)
+ return timesteps_emb
+
+
+class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+ """
+ Parameters:
+ patch_size (`int`): Patch size to turn the input data into small patches.
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
+ ...
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 1,
+ in_channels: int = 64,
+ num_layers: int = 19,
+ num_single_layers: int = 38,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 24,
+ joint_attention_dim: int = 4096,
+ pooled_projection_dim: int = None,
+ guidance_embeds: bool = False,
+ axes_dims_rope: List[int] = [16, 56, 56],
+ rope_theta=10000,
+ time_theta=10000,
+ text_encoder_dim: int = 2048,
+ ):
+ super().__init__()
+ self.out_channels = in_channels
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+
+ self.pos_embed = BriaFiboEmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
+
+ self.time_embed = BriaFiboTimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
+
+ if guidance_embeds:
+ self.guidance_embed = BriaFiboTimestepProjEmbeddings(embedding_dim=self.inner_dim)
+
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BriaFiboTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=self.config.num_attention_heads,
+ attention_head_dim=self.config.attention_head_dim,
+ )
+ for i in range(self.config.num_layers)
+ ]
+ )
+
+ self.single_transformer_blocks = nn.ModuleList(
+ [
+ BriaFiboSingleTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=self.config.num_attention_heads,
+ attention_head_dim=self.config.attention_head_dim,
+ )
+ for i in range(self.config.num_single_layers)
+ ]
+ )
+
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+ self.gradient_checkpointing = False
+
+ caption_projection = [
+ BriaFiboTextProjection(in_features=text_encoder_dim, hidden_size=self.inner_dim // 2)
+ for i in range(self.config.num_layers + self.config.num_single_layers)
+ ]
+ self.caption_projection = nn.ModuleList(caption_projection)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ text_encoder_layers: list = None,
+ pooled_projections: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ """
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ from the embeddings of input conditions.
+ timestep ( `torch.LongTensor`):
+ Used to indicate denoising step.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ if joint_attention_kwargs is not None:
+ joint_attention_kwargs = joint_attention_kwargs.copy()
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+ hidden_states = self.x_embedder(hidden_states)
+
+ timestep = timestep.to(hidden_states.dtype)
+ if guidance is not None:
+ guidance = guidance.to(hidden_states.dtype)
+ else:
+ guidance = None
+
+ temb = self.time_embed(timestep, dtype=hidden_states.dtype)
+
+ if guidance:
+ temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)
+
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ if len(txt_ids.shape) == 3:
+ txt_ids = txt_ids[0]
+
+ if len(img_ids.shape) == 3:
+ img_ids = img_ids[0]
+
+ ids = torch.cat((txt_ids, img_ids), dim=0)
+ image_rotary_emb = self.pos_embed(ids)
+
+ new_text_encoder_layers = []
+ for i, text_encoder_layer in enumerate(text_encoder_layers):
+ text_encoder_layer = self.caption_projection[i](text_encoder_layer)
+ new_text_encoder_layers.append(text_encoder_layer)
+ text_encoder_layers = new_text_encoder_layers
+
+ block_id = 0
+ for index_block, block in enumerate(self.transformer_blocks):
+ current_text_encoder_layer = text_encoder_layers[block_id]
+ encoder_hidden_states = torch.cat(
+ [encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1
+ )
+ block_id += 1
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ joint_attention_kwargs,
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+
+ for index_block, block in enumerate(self.single_transformer_blocks):
+ current_text_encoder_layer = text_encoder_layers[block_id]
+ encoder_hidden_states = torch.cat(
+ [encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1
+ )
+ block_id += 1
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ temb,
+ image_rotary_emb,
+ joint_attention_kwargs,
+ )
+
+ else:
+ hidden_states = block(
+ hidden_states=hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+
+ encoder_hidden_states = hidden_states[:, : encoder_hidden_states.shape[1], ...]
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+
+ hidden_states = self.norm_out(hidden_states, temb)
+ output = self.proj_out(hidden_states)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py
new file mode 100644
index 000000000000..2ef3643dafbd
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_chroma.py
@@ -0,0 +1,641 @@
+# Copyright 2025 Black Forest Labs, The HuggingFace Team and loadstone-rock . All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.import_utils import is_torch_npu_available
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import AttentionMixin, FeedForward
+from ..cache_utils import CacheMixin
+from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm
+from .transformer_flux import FluxAttention, FluxAttnProcessor
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class ChromaAdaLayerNormZeroPruned(nn.Module):
+ r"""
+ Norm layer adaptive layer norm zero (adaLN-Zero).
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ num_embeddings (`int`): The size of the embeddings dictionary.
+ """
+
+ def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
+ super().__init__()
+ if num_embeddings is not None:
+ self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
+ else:
+ self.emb = None
+
+ if norm_type == "layer_norm":
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+ elif norm_type == "fp32_layer_norm":
+ self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
+ else:
+ raise ValueError(
+ f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ timestep: Optional[torch.Tensor] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ hidden_dtype: Optional[torch.dtype] = None,
+ emb: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ if self.emb is not None:
+ emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.flatten(1, 2).chunk(6, dim=1)
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+class ChromaAdaLayerNormZeroSinglePruned(nn.Module):
+ r"""
+ Norm layer adaptive layer norm zero (adaLN-Zero).
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ num_embeddings (`int`): The size of the embeddings dictionary.
+ """
+
+ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
+ super().__init__()
+
+ if norm_type == "layer_norm":
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+ else:
+ raise ValueError(
+ f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ emb: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ shift_msa, scale_msa, gate_msa = emb.flatten(1, 2).chunk(3, dim=1)
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ return x, gate_msa
+
+
+class ChromaAdaLayerNormContinuousPruned(nn.Module):
+ r"""
+ Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
+
+ Args:
+ embedding_dim (`int`): Embedding dimension to use during projection.
+ conditioning_embedding_dim (`int`): Dimension of the input condition.
+ elementwise_affine (`bool`, defaults to `True`):
+ Boolean flag to denote if affine transformation should be applied.
+ eps (`float`, defaults to 1e-5): Epsilon factor.
+ bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use.
+ norm_type (`str`, defaults to `"layer_norm"`):
+ Normalization layer to use. Values supported: "layer_norm", "rms_norm".
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ conditioning_embedding_dim: int,
+ # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
+ # because the output is immediately scaled and shifted by the projected conditioning embeddings.
+ # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
+ # However, this is how it was implemented in the original code, and it's rather likely you should
+ # set `elementwise_affine` to False.
+ elementwise_affine=True,
+ eps=1e-5,
+ bias=True,
+ norm_type="layer_norm",
+ ):
+ super().__init__()
+ if norm_type == "layer_norm":
+ self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+ elif norm_type == "rms_norm":
+ self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+ else:
+ raise ValueError(f"unknown norm_type {norm_type}")
+
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+ # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
+ shift, scale = torch.chunk(emb.flatten(1, 2).to(x.dtype), 2, dim=1)
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+ return x
+
+
+class ChromaCombinedTimestepTextProjEmbeddings(nn.Module):
+ def __init__(self, num_channels: int, out_dim: int):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.guidance_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+
+ self.register_buffer(
+ "mod_proj",
+ get_timestep_embedding(
+ torch.arange(out_dim) * 1000, 2 * num_channels, flip_sin_to_cos=True, downscale_freq_shift=0
+ ),
+ persistent=False,
+ )
+
+ def forward(self, timestep: torch.Tensor) -> torch.Tensor:
+ mod_index_length = self.mod_proj.shape[0]
+ batch_size = timestep.shape[0]
+
+ timesteps_proj = self.time_proj(timestep).to(dtype=timestep.dtype)
+ guidance_proj = self.guidance_proj(torch.tensor([0] * batch_size)).to(
+ dtype=timestep.dtype, device=timestep.device
+ )
+
+ mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device).repeat(batch_size, 1, 1)
+ timestep_guidance = (
+ torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
+ )
+ input_vec = torch.cat([timestep_guidance, mod_proj], dim=-1)
+ return input_vec.to(timestep.dtype)
+
+
+class ChromaApproximator(nn.Module):
+ def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers: int = 5):
+ super().__init__()
+ self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
+ self.layers = nn.ModuleList(
+ [PixArtAlphaTextProjection(hidden_dim, hidden_dim, act_fn="silu") for _ in range(n_layers)]
+ )
+ self.norms = nn.ModuleList([nn.RMSNorm(hidden_dim) for _ in range(n_layers)])
+ self.out_proj = nn.Linear(hidden_dim, out_dim)
+
+ def forward(self, x):
+ x = self.in_proj(x)
+
+ for layer, norms in zip(self.layers, self.norms):
+ x = x + layer(norms(x))
+
+ return self.out_proj(x)
+
+
+@maybe_allow_in_graph
+class ChromaSingleTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_ratio: float = 4.0,
+ ):
+ super().__init__()
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
+ self.norm = ChromaAdaLayerNormZeroSinglePruned(dim)
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
+ self.act_mlp = nn.GELU(approximate="tanh")
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
+
+ if is_torch_npu_available():
+ from ..attention_processor import FluxAttnProcessor2_0_NPU
+
+ deprecation_message = (
+ "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
+ "should be set explicitly using the `set_attn_processor` method."
+ )
+ deprecate("npu_processor", "0.34.0", deprecation_message)
+ processor = FluxAttnProcessor2_0_NPU()
+ else:
+ processor = FluxAttnProcessor()
+
+ self.attn = FluxAttention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ bias=True,
+ processor=processor,
+ eps=1e-6,
+ pre_only=True,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+ joint_attention_kwargs = joint_attention_kwargs or {}
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ attention_mask=attention_mask,
+ **joint_attention_kwargs,
+ )
+
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ gate = gate.unsqueeze(1)
+ hidden_states = gate * self.proj_out(hidden_states)
+ hidden_states = residual + hidden_states
+ if hidden_states.dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ return hidden_states
+
+
+@maybe_allow_in_graph
+class ChromaTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ qk_norm: str = "rms_norm",
+ eps: float = 1e-6,
+ ):
+ super().__init__()
+ self.norm1 = ChromaAdaLayerNormZeroPruned(dim)
+ self.norm1_context = ChromaAdaLayerNormZeroPruned(dim)
+
+ self.attn = FluxAttention(
+ query_dim=dim,
+ added_kv_proj_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ context_pre_only=False,
+ bias=True,
+ processor=FluxAttnProcessor(),
+ eps=eps,
+ )
+
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ temb_img, temb_txt = temb[:, :6], temb[:, 6:]
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb_img)
+
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ encoder_hidden_states, emb=temb_txt
+ )
+ joint_attention_kwargs = joint_attention_kwargs or {}
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+
+ # Attention.
+ attention_outputs = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ attention_mask=attention_mask,
+ **joint_attention_kwargs,
+ )
+
+ if len(attention_outputs) == 2:
+ attn_output, context_attn_output = attention_outputs
+ elif len(attention_outputs) == 3:
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
+
+ # Process attention outputs for the `hidden_states`.
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = hidden_states + attn_output
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = hidden_states + ff_output
+ if len(attention_outputs) == 3:
+ hidden_states = hidden_states + ip_attn_output
+
+ # Process attention outputs for the `encoder_hidden_states`.
+
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+ if encoder_hidden_states.dtype == torch.float16:
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+ return encoder_hidden_states, hidden_states
+
+
+class ChromaTransformer2DModel(
+ ModelMixin,
+ ConfigMixin,
+ PeftAdapterMixin,
+ FromOriginalModelMixin,
+ FluxTransformer2DLoadersMixin,
+ CacheMixin,
+ AttentionMixin,
+):
+ """
+ The Transformer model introduced in Flux, modified for Chroma.
+
+ Reference: https://huggingface.co/lodestones/Chroma1-HD
+
+ Args:
+ patch_size (`int`, defaults to `1`):
+ Patch size to turn the input data into small patches.
+ in_channels (`int`, defaults to `64`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `None`):
+ The number of channels in the output. If not specified, it defaults to `in_channels`.
+ num_layers (`int`, defaults to `19`):
+ The number of layers of dual stream DiT blocks to use.
+ num_single_layers (`int`, defaults to `38`):
+ The number of layers of single stream DiT blocks to use.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of dimensions to use for each attention head.
+ num_attention_heads (`int`, defaults to `24`):
+ The number of attention heads to use.
+ joint_attention_dim (`int`, defaults to `4096`):
+ The number of dimensions to use for the joint attention (embedding/channel dimension of
+ `encoder_hidden_states`).
+ axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
+ The dimensions to use for the rotary positional embeddings.
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
+ _repeated_blocks = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 1,
+ in_channels: int = 64,
+ out_channels: Optional[int] = None,
+ num_layers: int = 19,
+ num_single_layers: int = 38,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 24,
+ joint_attention_dim: int = 4096,
+ axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
+ approximator_num_channels: int = 64,
+ approximator_hidden_dim: int = 5120,
+ approximator_layers: int = 5,
+ ):
+ super().__init__()
+ self.out_channels = out_channels or in_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
+
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
+
+ self.time_text_embed = ChromaCombinedTimestepTextProjEmbeddings(
+ num_channels=approximator_num_channels // 4,
+ out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2,
+ )
+ self.distilled_guidance_layer = ChromaApproximator(
+ in_dim=approximator_num_channels,
+ out_dim=self.inner_dim,
+ hidden_dim=approximator_hidden_dim,
+ n_layers=approximator_layers,
+ )
+
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ ChromaTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ self.single_transformer_blocks = nn.ModuleList(
+ [
+ ChromaSingleTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ )
+ for _ in range(num_single_layers)
+ ]
+ )
+
+ self.norm_out = ChromaAdaLayerNormContinuousPruned(
+ self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
+ )
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ attention_mask: torch.Tensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_block_samples=None,
+ controlnet_single_block_samples=None,
+ return_dict: bool = True,
+ controlnet_blocks_repeat: bool = False,
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ """
+ The [`FluxTransformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ timestep ( `torch.LongTensor`):
+ Used to indicate denoising step.
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+ A list of tensors that if specified are added to the residuals of transformer blocks.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ if joint_attention_kwargs is not None:
+ joint_attention_kwargs = joint_attention_kwargs.copy()
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ hidden_states = self.x_embedder(hidden_states)
+
+ timestep = timestep.to(hidden_states.dtype) * 1000
+
+ input_vec = self.time_text_embed(timestep)
+ pooled_temb = self.distilled_guidance_layer(input_vec)
+
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ if txt_ids.ndim == 3:
+ logger.warning(
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ txt_ids = txt_ids[0]
+ if img_ids.ndim == 3:
+ logger.warning(
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ img_ids = img_ids[0]
+
+ ids = torch.cat((txt_ids, img_ids), dim=0)
+ image_rotary_emb = self.pos_embed(ids)
+
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
+ for index_block, block in enumerate(self.transformer_blocks):
+ img_offset = 3 * len(self.single_transformer_blocks)
+ txt_offset = img_offset + 6 * len(self.transformer_blocks)
+ img_modulation = img_offset + 6 * index_block
+ text_modulation = txt_offset + 6 * index_block
+ temb = torch.cat(
+ (
+ pooled_temb[:, img_modulation : img_modulation + 6],
+ pooled_temb[:, text_modulation : text_modulation + 6],
+ ),
+ dim=1,
+ )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ attention_mask=attention_mask,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+
+ # controlnet residual
+ if controlnet_block_samples is not None:
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ # For Xlabs ControlNet.
+ if controlnet_blocks_repeat:
+ hidden_states = (
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
+ )
+ else:
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ for index_block, block in enumerate(self.single_transformer_blocks):
+ start_idx = 3 * index_block
+ temb = pooled_temb[:, start_idx : start_idx + 3]
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ temb,
+ image_rotary_emb,
+ )
+
+ else:
+ hidden_states = block(
+ hidden_states=hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ attention_mask=attention_mask,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+
+ # controlnet residual
+ if controlnet_single_block_samples is not None:
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ + controlnet_single_block_samples[index_block // interval_control]
+ )
+
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+
+ temb = pooled_temb[:, -2:]
+ hidden_states = self.norm_out(hidden_states, temb)
+ output = self.proj_out(hidden_states)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_chronoedit.py b/src/diffusers/models/transformers/transformer_chronoedit.py
new file mode 100644
index 000000000000..79828b6464f4
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_chronoedit.py
@@ -0,0 +1,739 @@
+# Copyright 2025 The ChronoEdit Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
+from ..cache_utils import CacheMixin
+from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import FP32LayerNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.models.transformers.transformer_wan._get_qkv_projections
+def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+ # encoder_hidden_states is only passed for cross-attention
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ if attn.fused_projections:
+ if attn.cross_attention_dim_head is None:
+ # In self-attention layers, we can fuse the entire QKV projection into a single linear
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+ else:
+ # In cross-attention layers, we can only fuse the KV projections into a single linear
+ query = attn.to_q(hidden_states)
+ key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+ else:
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+ return query, key, value
+
+
+# Copied from diffusers.models.transformers.transformer_wan._get_added_kv_projections
+def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
+ if attn.fused_projections:
+ key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+ else:
+ key_img = attn.add_k_proj(encoder_hidden_states_img)
+ value_img = attn.add_v_proj(encoder_hidden_states_img)
+ return key_img, value_img
+
+
+# modified from diffusers.models.transformers.transformer_wan.WanAttnProcessor
+class WanAttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
+ )
+
+ def __call__(
+ self,
+ attn: "WanAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ encoder_hidden_states_img = None
+ if attn.add_k_proj is not None:
+ # 512 is the context length of the text encoder, hardcoded for now
+ image_context_length = encoder_hidden_states.shape[1] - 512
+ encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
+ encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
+
+ query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+
+ if rotary_emb is not None:
+
+ def apply_rotary_emb(
+ hidden_states: torch.Tensor,
+ freqs_cos: torch.Tensor,
+ freqs_sin: torch.Tensor,
+ ):
+ x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+ cos = freqs_cos[..., 0::2]
+ sin = freqs_sin[..., 1::2]
+ out = torch.empty_like(hidden_states)
+ out[..., 0::2] = x1 * cos - x2 * sin
+ out[..., 1::2] = x1 * sin + x2 * cos
+ return out.type_as(hidden_states)
+
+ query = apply_rotary_emb(query, *rotary_emb)
+ key = apply_rotary_emb(key, *rotary_emb)
+
+ # I2V task
+ hidden_states_img = None
+ if encoder_hidden_states_img is not None:
+ key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
+ key_img = attn.norm_added_k(key_img)
+
+ key_img = key_img.unflatten(2, (attn.heads, -1))
+ value_img = value_img.unflatten(2, (attn.heads, -1))
+
+ hidden_states_img = dispatch_attention_fn(
+ query,
+ key_img,
+ value_img,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ # Reference: https://github.com/huggingface/diffusers/pull/12660
+ parallel_config=None,
+ )
+ hidden_states_img = hidden_states_img.flatten(2, 3)
+ hidden_states_img = hidden_states_img.type_as(query)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ # Reference: https://github.com/huggingface/diffusers/pull/12660
+ parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+
+ if hidden_states_img is not None:
+ hidden_states = hidden_states + hidden_states_img
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanAttnProcessor2_0
+class WanAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = (
+ "The WanAttnProcessor2_0 class is deprecated and will be removed in a future version. "
+ "Please use WanAttnProcessor instead. "
+ )
+ deprecate("WanAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+ return WanAttnProcessor(*args, **kwargs)
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanAttention
+class WanAttention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = WanAttnProcessor
+ _available_processors = [WanAttnProcessor]
+
+ def __init__(
+ self,
+ dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ eps: float = 1e-5,
+ dropout: float = 0.0,
+ added_kv_proj_dim: Optional[int] = None,
+ cross_attention_dim_head: Optional[int] = None,
+ processor=None,
+ is_cross_attention=None,
+ ):
+ super().__init__()
+
+ self.inner_dim = dim_head * heads
+ self.heads = heads
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.cross_attention_dim_head = cross_attention_dim_head
+ self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+ self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+ self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_out = torch.nn.ModuleList(
+ [
+ torch.nn.Linear(self.inner_dim, dim, bias=True),
+ torch.nn.Dropout(dropout),
+ ]
+ )
+ self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+ self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+
+ self.add_k_proj = self.add_v_proj = None
+ if added_kv_proj_dim is not None:
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+ self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+ self.is_cross_attention = cross_attention_dim_head is not None
+
+ self.set_processor(processor)
+
+ def fuse_projections(self):
+ if getattr(self, "fused_projections", False):
+ return
+
+ if self.cross_attention_dim_head is None:
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_qkv = nn.Linear(in_features, out_features, bias=True)
+ self.to_qkv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+ else:
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_kv = nn.Linear(in_features, out_features, bias=True)
+ self.to_kv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+
+ if self.added_kv_proj_dim is not None:
+ concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+ concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+ self.to_added_kv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+
+ self.fused_projections = True
+
+ @torch.no_grad()
+ def unfuse_projections(self):
+ if not getattr(self, "fused_projections", False):
+ return
+
+ if hasattr(self, "to_qkv"):
+ delattr(self, "to_qkv")
+ if hasattr(self, "to_kv"):
+ delattr(self, "to_kv")
+ if hasattr(self, "to_added_kv"):
+ delattr(self, "to_added_kv")
+
+ self.fused_projections = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanImageEmbedding
+class WanImageEmbedding(torch.nn.Module):
+ def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
+ super().__init__()
+
+ self.norm1 = FP32LayerNorm(in_features)
+ self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
+ self.norm2 = FP32LayerNorm(out_features)
+ if pos_embed_seq_len is not None:
+ self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
+ else:
+ self.pos_embed = None
+
+ def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
+ if self.pos_embed is not None:
+ batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
+ encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
+ encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
+
+ hidden_states = self.norm1(encoder_hidden_states_image)
+ hidden_states = self.ff(hidden_states)
+ hidden_states = self.norm2(hidden_states)
+ return hidden_states
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding
+class WanTimeTextImageEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ time_freq_dim: int,
+ time_proj_dim: int,
+ text_embed_dim: int,
+ image_embed_dim: Optional[int] = None,
+ pos_embed_seq_len: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
+ self.act_fn = nn.SiLU()
+ self.time_proj = nn.Linear(dim, time_proj_dim)
+ self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
+
+ self.image_embedder = None
+ if image_embed_dim is not None:
+ self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ timestep_seq_len: Optional[int] = None,
+ ):
+ timestep = self.timesteps_proj(timestep)
+ if timestep_seq_len is not None:
+ timestep = timestep.unflatten(0, (-1, timestep_seq_len))
+
+ time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
+ if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+ timestep = timestep.to(time_embedder_dtype)
+ temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+ timestep_proj = self.time_proj(self.act_fn(temb))
+
+ encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
+
+ return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
+
+
+class ChronoEditRotaryPosEmbed(nn.Module):
+ def __init__(
+ self,
+ attention_head_dim: int,
+ patch_size: Tuple[int, int, int],
+ max_seq_len: int,
+ theta: float = 10000.0,
+ temporal_skip_len: int = 8,
+ ):
+ super().__init__()
+
+ self.attention_head_dim = attention_head_dim
+ self.patch_size = patch_size
+ self.max_seq_len = max_seq_len
+ self.temporal_skip_len = temporal_skip_len
+
+ h_dim = w_dim = 2 * (attention_head_dim // 6)
+ t_dim = attention_head_dim - h_dim - w_dim
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+ freqs_cos = []
+ freqs_sin = []
+
+ for dim in [t_dim, h_dim, w_dim]:
+ freq_cos, freq_sin = get_1d_rotary_pos_embed(
+ dim,
+ max_seq_len,
+ theta,
+ use_real=True,
+ repeat_interleave_real=True,
+ freqs_dtype=freqs_dtype,
+ )
+ freqs_cos.append(freq_cos)
+ freqs_sin.append(freq_sin)
+
+ self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+ self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.patch_size
+ ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
+
+ split_sizes = [
+ self.attention_head_dim - 2 * (self.attention_head_dim // 3),
+ self.attention_head_dim // 3,
+ self.attention_head_dim // 3,
+ ]
+
+ freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+ freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+ if num_frames == 2:
+ freqs_cos_f = freqs_cos[0][: self.temporal_skip_len][[0, -1]].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ else:
+ freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ if num_frames == 2:
+ freqs_sin_f = freqs_sin[0][: self.temporal_skip_len][[0, -1]].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ else:
+ freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+ freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+
+ return freqs_cos, freqs_sin
+
+
+@maybe_allow_in_graph
+# Copied from diffusers.models.transformers.transformer_wan.WanTransformerBlock
+class WanTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ ffn_dim: int,
+ num_heads: int,
+ qk_norm: str = "rms_norm_across_heads",
+ cross_attn_norm: bool = False,
+ eps: float = 1e-6,
+ added_kv_proj_dim: Optional[int] = None,
+ ):
+ super().__init__()
+
+ # 1. Self-attention
+ self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+ self.attn1 = WanAttention(
+ dim=dim,
+ heads=num_heads,
+ dim_head=dim // num_heads,
+ eps=eps,
+ cross_attention_dim_head=None,
+ processor=WanAttnProcessor(),
+ )
+
+ # 2. Cross-attention
+ self.attn2 = WanAttention(
+ dim=dim,
+ heads=num_heads,
+ dim_head=dim // num_heads,
+ eps=eps,
+ added_kv_proj_dim=added_kv_proj_dim,
+ cross_attention_dim_head=dim // num_heads,
+ processor=WanAttnProcessor(),
+ )
+ self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
+
+ # 3. Feed-forward
+ self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
+ self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ rotary_emb: torch.Tensor,
+ ) -> torch.Tensor:
+ if temb.ndim == 4:
+ # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table.unsqueeze(0) + temb.float()
+ ).chunk(6, dim=2)
+ # batch_size, seq_len, 1, inner_dim
+ shift_msa = shift_msa.squeeze(2)
+ scale_msa = scale_msa.squeeze(2)
+ gate_msa = gate_msa.squeeze(2)
+ c_shift_msa = c_shift_msa.squeeze(2)
+ c_scale_msa = c_scale_msa.squeeze(2)
+ c_gate_msa = c_gate_msa.squeeze(2)
+ else:
+ # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table + temb.float()
+ ).chunk(6, dim=1)
+
+ # 1. Self-attention
+ norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
+ attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
+ hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
+
+ # 2. Cross-attention
+ norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
+ hidden_states = hidden_states + attn_output
+
+ # 3. Feed-forward
+ norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
+ hidden_states
+ )
+ ff_output = self.ffn(norm_hidden_states)
+ hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
+
+ return hidden_states
+
+
+# modified from diffusers.models.transformers.transformer_wan.WanTransformer3DModel
+class ChronoEditTransformer3DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
+ r"""
+ A Transformer model for video-like data used in the ChronoEdit model.
+
+ Args:
+ patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
+ num_attention_heads (`int`, defaults to `40`):
+ Fixed length for text embeddings.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, defaults to `16`):
+ The number of channels in the output.
+ text_dim (`int`, defaults to `512`):
+ Input dimension for text embeddings.
+ freq_dim (`int`, defaults to `256`):
+ Dimension for sinusoidal time embeddings.
+ ffn_dim (`int`, defaults to `13824`):
+ Intermediate dimension in feed-forward network.
+ num_layers (`int`, defaults to `40`):
+ The number of layers of transformer blocks to use.
+ window_size (`Tuple[int]`, defaults to `(-1, -1)`):
+ Window size for local attention (-1 indicates global attention).
+ cross_attn_norm (`bool`, defaults to `True`):
+ Enable cross-attention normalization.
+ qk_norm (`bool`, defaults to `True`):
+ Enable query/key normalization.
+ eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ add_img_emb (`bool`, defaults to `False`):
+ Whether to use img_emb.
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
+ _no_split_modules = ["WanTransformerBlock"]
+ _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
+ _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+ _repeated_blocks = ["WanTransformerBlock"]
+ _cp_plan = {
+ "rope": {
+ 0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
+ 1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
+ },
+ "blocks.0": {
+ "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ # Reference: https://github.com/huggingface/diffusers/pull/12660
+ # We need to disable the splitting of encoder_hidden_states because
+ # the image_encoder consistently generates 257 tokens for image_embed. This causes
+ # the shape of encoder_hidden_states—whose token count is always 769 (512 + 257)
+ # after concatenation—to be indivisible by the number of devices in the CP.
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: Tuple[int] = (1, 2, 2),
+ num_attention_heads: int = 40,
+ attention_head_dim: int = 128,
+ in_channels: int = 16,
+ out_channels: int = 16,
+ text_dim: int = 4096,
+ freq_dim: int = 256,
+ ffn_dim: int = 13824,
+ num_layers: int = 40,
+ cross_attn_norm: bool = True,
+ qk_norm: Optional[str] = "rms_norm_across_heads",
+ eps: float = 1e-6,
+ image_dim: Optional[int] = None,
+ added_kv_proj_dim: Optional[int] = None,
+ rope_max_seq_len: int = 1024,
+ pos_embed_seq_len: Optional[int] = None,
+ rope_temporal_skip_len: int = 8,
+ ) -> None:
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels or in_channels
+
+ # 1. Patch & position embedding
+ self.rope = ChronoEditRotaryPosEmbed(
+ attention_head_dim, patch_size, rope_max_seq_len, temporal_skip_len=rope_temporal_skip_len
+ )
+ self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+
+ # 2. Condition embeddings
+ # image_embedding_dim=1280 for I2V model
+ self.condition_embedder = WanTimeTextImageEmbedding(
+ dim=inner_dim,
+ time_freq_dim=freq_dim,
+ time_proj_dim=inner_dim * 6,
+ text_embed_dim=text_dim,
+ image_embed_dim=image_dim,
+ pos_embed_seq_len=pos_embed_seq_len,
+ )
+
+ # 3. Transformer blocks
+ self.blocks = nn.ModuleList(
+ [
+ WanTransformerBlock(
+ inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Output norm & projection
+ self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+ self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.config.patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+
+ rotary_emb = self.rope(hidden_states)
+
+ hidden_states = self.patch_embedding(hidden_states)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+ # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
+ if timestep.ndim == 2:
+ ts_seq_len = timestep.shape[1]
+ timestep = timestep.flatten() # batch_size * seq_len
+ else:
+ ts_seq_len = None
+
+ temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+ timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
+ )
+ if ts_seq_len is not None:
+ # batch_size, seq_len, 6, inner_dim
+ timestep_proj = timestep_proj.unflatten(2, (6, -1))
+ else:
+ # batch_size, 6, inner_dim
+ timestep_proj = timestep_proj.unflatten(1, (6, -1))
+
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+
+ # 4. Transformer blocks
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for block in self.blocks:
+ hidden_states = self._gradient_checkpointing_func(
+ block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
+ )
+ else:
+ for block in self.blocks:
+ hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+
+ # 5. Output norm, projection & unpatchify
+ if temb.ndim == 3:
+ # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
+ shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
+ shift = shift.squeeze(2)
+ scale = scale.squeeze(2)
+ else:
+ # batch_size, inner_dim
+ shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
+
+ # Move the shift and scale tensors to the same device as hidden_states.
+ # When using multi-GPU inference via accelerate these will be on the
+ # first device rather than the last device, which hidden_states ends up
+ # on.
+ shift = shift.to(hidden_states.device)
+ scale = scale.to(hidden_states.device)
+
+ hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+ hidden_states = self.proj_out(hidden_states)
+
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+ )
+ hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+ output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index da7133791f37..e48290fb39d4 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
+# Copyright 2025 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,24 +13,19 @@
# limitations under the License.
-from typing import Dict, Union
+from typing import Tuple, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
-from ...models.attention import FeedForward
-from ...models.attention_processor import (
- Attention,
- AttentionProcessor,
- CogVideoXAttnProcessor2_0,
-)
-from ...models.modeling_utils import ModelMixin
-from ...models.normalization import AdaLayerNormContinuous
from ...utils import logging
+from ..attention import AttentionMixin, FeedForward
+from ..attention_processor import Attention, CogVideoXAttnProcessor2_0
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
-from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous, CogView3PlusAdaLayerNormZeroTextImage
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -84,7 +79,7 @@ def forward(
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
emb: torch.Tensor,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
@@ -130,7 +125,7 @@ def forward(
return hidden_states, encoder_hidden_states
-class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
+class CogView3PlusTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin):
r"""
The Transformer model introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay
Diffusion](https://huggingface.co/papers/2403.05121).
@@ -229,66 +224,6 @@ def __init__(
self.gradient_checkpointing = False
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -298,7 +233,7 @@ def forward(
target_size: torch.Tensor,
crop_coords: torch.Tensor,
return_dict: bool = True,
- ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
"""
The [`CogView3PlusTransformer2DModel`] forward method.
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index 41c4cbbf97c7..64e9a538a7c2 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
+# Copyright 2025 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -21,13 +21,14 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous
+from ..normalization import LayerNorm, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -73,8 +74,9 @@ def __init__(self, embedding_dim: int, dim: int) -> None:
def forward(
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
- norm_hidden_states = self.norm(hidden_states)
- norm_encoder_hidden_states = self.norm_context(encoder_hidden_states)
+ dtype = hidden_states.dtype
+ norm_hidden_states = self.norm(hidden_states).to(dtype=dtype)
+ norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype)
emb = self.linear(temb)
(
@@ -111,8 +113,11 @@ def forward(
class CogView4AttnProcessor:
"""
- Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+ Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
+
+ The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size,
+ text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token.
"""
def __init__(self):
@@ -125,8 +130,10 @@ def __call__(
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ dtype = encoder_hidden_states.dtype
+
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
batch_size, image_seq_length, embed_dim = hidden_states.shape
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -142,9 +149,9 @@ def __call__(
# 2. QK normalization
if attn.norm_q is not None:
- query = attn.norm_q(query)
+ query = attn.norm_q(query).to(dtype=dtype)
if attn.norm_k is not None:
- key = attn.norm_k(key)
+ key = attn.norm_k(key).to(dtype=dtype)
# 3. Rotational positional embeddings applied to latent stream
if image_rotary_emb is not None:
@@ -159,13 +166,14 @@ def __call__(
# 4. Attention
if attention_mask is not None:
- text_attention_mask = attention_mask.float().to(query.device)
- actual_text_seq_length = text_attention_mask.size(1)
- new_attention_mask = torch.zeros((batch_size, text_seq_length + image_seq_length), device=query.device)
- new_attention_mask[:, :actual_text_seq_length] = text_attention_mask
- new_attention_mask = new_attention_mask.unsqueeze(2)
- attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2)
- attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype)
+ text_attn_mask = attention_mask
+ assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
+ text_attn_mask = text_attn_mask.float().to(query.device)
+ mix_attn_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device)
+ mix_attn_mask[:, :text_seq_length] = text_attn_mask
+ mix_attn_mask = mix_attn_mask.unsqueeze(2)
+ attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2)
+ attention_mask = (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
@@ -183,9 +191,277 @@ def __call__(
return hidden_states, encoder_hidden_states
+class CogView4TrainingAttnProcessor:
+ """
+ Training Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary
+ embedding on query and key vectors, but does not include spatial normalization.
+
+ This processor differs from CogView4AttnProcessor in several important ways:
+ 1. It supports attention masking with variable sequence lengths for multi-resolution training
+ 2. It unpacks and repacks sequences for efficient training with variable sequence lengths when batch_flag is
+ provided
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ latent_attn_mask: Optional[torch.Tensor] = None,
+ text_attn_mask: Optional[torch.Tensor] = None,
+ batch_flag: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[
+ Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
+ ] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ attn (`Attention`):
+ The attention module.
+ hidden_states (`torch.Tensor`):
+ The input hidden states.
+ encoder_hidden_states (`torch.Tensor`):
+ The encoder hidden states for cross-attention.
+ latent_attn_mask (`torch.Tensor`, *optional*):
+ Mask for latent tokens where 0 indicates pad token and 1 indicates non-pad token. If None, full
+ attention is used for all latent tokens. Note: the shape of latent_attn_mask is (batch_size,
+ num_latent_tokens).
+ text_attn_mask (`torch.Tensor`, *optional*):
+ Mask for text tokens where 0 indicates pad token and 1 indicates non-pad token. If None, full attention
+ is used for all text tokens.
+ batch_flag (`torch.Tensor`, *optional*):
+ Values from 0 to n-1 indicating which samples belong to the same batch. Samples with the same
+ batch_flag are packed together. Example: [0, 1, 1, 2, 2] means sample 0 forms batch0, samples 1-2 form
+ batch1, and samples 3-4 form batch2. If None, no packing is used.
+ image_rotary_emb (`Tuple[torch.Tensor, torch.Tensor]` or `list[Tuple[torch.Tensor, torch.Tensor]]`, *optional*):
+ The rotary embedding for the image part of the input.
+ Returns:
+ `Tuple[torch.Tensor, torch.Tensor]`: The processed hidden states for both image and text streams.
+ """
+
+ # Get dimensions and device info
+ batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
+ batch_size, image_seq_length, embed_dim = hidden_states.shape
+ dtype = encoder_hidden_states.dtype
+ device = encoder_hidden_states.device
+ latent_hidden_states = hidden_states
+ # Combine text and image streams for joint processing
+ mixed_hidden_states = torch.cat([encoder_hidden_states, latent_hidden_states], dim=1)
+
+ # 1. Construct attention mask and maybe packing input
+ # Create default masks if not provided
+ if text_attn_mask is None:
+ text_attn_mask = torch.ones((batch_size, text_seq_length), dtype=torch.int32, device=device)
+ if latent_attn_mask is None:
+ latent_attn_mask = torch.ones((batch_size, image_seq_length), dtype=torch.int32, device=device)
+
+ # Validate mask shapes and types
+ assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
+ assert text_attn_mask.dtype == torch.int32, "the dtype of text_attn_mask should be torch.int32"
+ assert latent_attn_mask.dim() == 2, "the shape of latent_attn_mask should be (batch_size, num_latent_tokens)"
+ assert latent_attn_mask.dtype == torch.int32, "the dtype of latent_attn_mask should be torch.int32"
+
+ # Create combined mask for text and image tokens
+ mixed_attn_mask = torch.ones(
+ (batch_size, text_seq_length + image_seq_length), dtype=torch.int32, device=device
+ )
+ mixed_attn_mask[:, :text_seq_length] = text_attn_mask
+ mixed_attn_mask[:, text_seq_length:] = latent_attn_mask
+
+ # Convert mask to attention matrix format (where 1 means attend, 0 means don't attend)
+ mixed_attn_mask_input = mixed_attn_mask.unsqueeze(2).to(dtype=dtype)
+ attn_mask_matrix = mixed_attn_mask_input @ mixed_attn_mask_input.transpose(1, 2)
+
+ # Handle batch packing if enabled
+ if batch_flag is not None:
+ assert batch_flag.dim() == 1
+ # Determine packed batch size based on batch_flag
+ packing_batch_size = torch.max(batch_flag).item() + 1
+
+ # Calculate actual sequence lengths for each sample based on masks
+ text_seq_length = torch.sum(text_attn_mask, dim=1)
+ latent_seq_length = torch.sum(latent_attn_mask, dim=1)
+ mixed_seq_length = text_seq_length + latent_seq_length
+
+ # Calculate packed sequence lengths for each packed batch
+ mixed_seq_length_packed = [
+ torch.sum(mixed_attn_mask[batch_flag == batch_idx]).item() for batch_idx in range(packing_batch_size)
+ ]
+
+ assert len(mixed_seq_length_packed) == packing_batch_size
+
+ # Pack sequences by removing padding tokens
+ mixed_attn_mask_flatten = mixed_attn_mask.flatten(0, 1)
+ mixed_hidden_states_flatten = mixed_hidden_states.flatten(0, 1)
+ mixed_hidden_states_unpad = mixed_hidden_states_flatten[mixed_attn_mask_flatten == 1]
+ assert torch.sum(mixed_seq_length) == mixed_hidden_states_unpad.shape[0]
+
+ # Split the unpadded sequence into packed batches
+ mixed_hidden_states_packed = torch.split(mixed_hidden_states_unpad, mixed_seq_length_packed)
+
+ # Re-pad to create packed batches with right-side padding
+ mixed_hidden_states_packed_padded = torch.nn.utils.rnn.pad_sequence(
+ mixed_hidden_states_packed,
+ batch_first=True,
+ padding_value=0.0,
+ padding_side="right",
+ )
+
+ # Create attention mask for packed batches
+ l = mixed_hidden_states_packed_padded.shape[1]
+ attn_mask_matrix = torch.zeros(
+ (packing_batch_size, l, l),
+ dtype=dtype,
+ device=device,
+ )
+
+ # Fill attention mask with block diagonal matrices
+ # This ensures that tokens can only attend to other tokens within the same original sample
+ for idx, mask in enumerate(attn_mask_matrix):
+ seq_lengths = mixed_seq_length[batch_flag == idx]
+ offset = 0
+ for length in seq_lengths:
+ # Create a block of 1s for each sample in the packed batch
+ mask[offset : offset + length, offset : offset + length] = 1
+ offset += length
+
+ attn_mask_matrix = attn_mask_matrix.to(dtype=torch.bool)
+ attn_mask_matrix = attn_mask_matrix.unsqueeze(1) # Add attention head dim
+ attention_mask = attn_mask_matrix
+
+ # Prepare hidden states for attention computation
+ if batch_flag is None:
+ # If no packing, just combine text and image tokens
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+ else:
+ # If packing, use the packed sequence
+ hidden_states = mixed_hidden_states_packed_padded
+
+ # 2. QKV projections - convert hidden states to query, key, value
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ # Reshape for multi-head attention: [batch, seq_len, heads*dim] -> [batch, heads, seq_len, dim]
+ query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ # 3. QK normalization - apply layer norm to queries and keys if configured
+ if attn.norm_q is not None:
+ query = attn.norm_q(query).to(dtype=dtype)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key).to(dtype=dtype)
+
+ # 4. Apply rotary positional embeddings to image tokens only
+ if image_rotary_emb is not None:
+ from ..embeddings import apply_rotary_emb
+
+ if batch_flag is None:
+ # Apply RoPE only to image tokens (after text tokens)
+ query[:, :, text_seq_length:, :] = apply_rotary_emb(
+ query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
+ )
+ key[:, :, text_seq_length:, :] = apply_rotary_emb(
+ key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
+ )
+ else:
+ # For packed batches, need to carefully apply RoPE to appropriate tokens
+ assert query.shape[0] == packing_batch_size
+ assert key.shape[0] == packing_batch_size
+ assert len(image_rotary_emb) == batch_size
+
+ rope_idx = 0
+ for idx in range(packing_batch_size):
+ offset = 0
+ # Get text and image sequence lengths for samples in this packed batch
+ text_seq_length_bi = text_seq_length[batch_flag == idx]
+ latent_seq_length_bi = latent_seq_length[batch_flag == idx]
+
+ # Apply RoPE to each image segment in the packed sequence
+ for tlen, llen in zip(text_seq_length_bi, latent_seq_length_bi):
+ mlen = tlen + llen
+ # Apply RoPE only to image tokens (after text tokens)
+ query[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb(
+ query[idx, :, offset + tlen : offset + mlen, :],
+ image_rotary_emb[rope_idx],
+ use_real_unbind_dim=-2,
+ )
+ key[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb(
+ key[idx, :, offset + tlen : offset + mlen, :],
+ image_rotary_emb[rope_idx],
+ use_real_unbind_dim=-2,
+ )
+ offset += mlen
+ rope_idx += 1
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ # Reshape back: [batch, heads, seq_len, dim] -> [batch, seq_len, heads*dim]
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+
+ # 5. Output projection - project attention output to model dimension
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ # Split the output back into text and image streams
+ if batch_flag is None:
+ # Simple split for non-packed case
+ encoder_hidden_states, hidden_states = hidden_states.split(
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+ )
+ else:
+ # For packed case: need to unpack, split text/image, then restore to original shapes
+ # First, unpad the sequence based on the packed sequence lengths
+ hidden_states_unpad = torch.nn.utils.rnn.unpad_sequence(
+ hidden_states,
+ lengths=torch.tensor(mixed_seq_length_packed),
+ batch_first=True,
+ )
+ # Concatenate all unpadded sequences
+ hidden_states_flatten = torch.cat(hidden_states_unpad, dim=0)
+ # Split by original sample sequence lengths
+ hidden_states_unpack = torch.split(hidden_states_flatten, mixed_seq_length.tolist())
+ assert len(hidden_states_unpack) == batch_size
+
+ # Further split each sample's sequence into text and image parts
+ hidden_states_unpack = [
+ torch.split(h, [tlen, llen])
+ for h, tlen, llen in zip(hidden_states_unpack, text_seq_length, latent_seq_length)
+ ]
+ # Separate text and image sequences
+ encoder_hidden_states_unpad = [h[0] for h in hidden_states_unpack]
+ hidden_states_unpad = [h[1] for h in hidden_states_unpack]
+
+ # Update the original tensors with the processed values, respecting the attention masks
+ for idx in range(batch_size):
+ # Place unpacked text tokens back in the encoder_hidden_states tensor
+ encoder_hidden_states[idx][text_attn_mask[idx] == 1] = encoder_hidden_states_unpad[idx]
+ # Place unpacked image tokens back in the latent_hidden_states tensor
+ latent_hidden_states[idx][latent_attn_mask[idx] == 1] = hidden_states_unpad[idx]
+
+ # Update the output hidden states
+ hidden_states = latent_hidden_states
+
+ return hidden_states, encoder_hidden_states
+
+
+@maybe_allow_in_graph
class CogView4TransformerBlock(nn.Module):
def __init__(
- self, dim: int = 2560, num_attention_heads: int = 64, attention_head_dim: int = 40, time_embed_dim: int = 512
+ self,
+ dim: int = 2560,
+ num_attention_heads: int = 64,
+ attention_head_dim: int = 40,
+ time_embed_dim: int = 512,
) -> None:
super().__init__()
@@ -213,10 +489,12 @@ def forward(
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- **kwargs,
- ) -> torch.Tensor:
+ image_rotary_emb: Optional[
+ Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
+ ] = None,
+ attention_mask: Optional[Dict[str, torch.Tensor]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Timestep conditioning
(
norm_hidden_states,
@@ -232,12 +510,14 @@ def forward(
) = self.norm1(hidden_states, encoder_hidden_states, temb)
# 2. Attention
+ if attention_kwargs is None:
+ attention_kwargs = {}
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
- **kwargs,
+ **attention_kwargs,
)
hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
@@ -304,6 +584,38 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
return (freqs.cos(), freqs.sin())
+class CogView4AdaLayerNormContinuous(nn.Module):
+ """
+ CogView4-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before the
+ Linear on conditioning embedding.
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ conditioning_embedding_dim: int,
+ elementwise_affine: bool = True,
+ eps: float = 1e-5,
+ bias: bool = True,
+ norm_type: str = "layer_norm",
+ ):
+ super().__init__()
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
+ if norm_type == "layer_norm":
+ self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+ elif norm_type == "rms_norm":
+ self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+ else:
+ raise ValueError(f"unknown norm_type {norm_type}")
+
+ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+ # *** NO SiLU here ***
+ emb = self.linear(conditioning_embedding.to(x.dtype))
+ scale, shift = torch.chunk(emb, 2, dim=1)
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+ return x
+
+
class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
r"""
Args:
@@ -386,7 +698,7 @@ def __init__(
)
# 4. Output projection
- self.norm_out = AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
+ self.norm_out = CogView4AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)
self.gradient_checkpointing = False
@@ -402,8 +714,10 @@ def forward(
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
attention_mask: Optional[torch.Tensor] = None,
- **kwargs,
- ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ image_rotary_emb: Optional[
+ Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
+ ] = None,
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -422,7 +736,8 @@ def forward(
batch_size, num_channels, height, width = hidden_states.shape
# 1. RoPE
- image_rotary_emb = self.rope(hidden_states)
+ if image_rotary_emb is None:
+ image_rotary_emb = self.rope(hidden_states)
# 2. Patch & Timestep embeddings
p = self.config.patch_size
@@ -438,11 +753,22 @@ def forward(
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
- block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ attention_mask,
+ attention_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
- hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ attention_mask,
+ attention_kwargs,
)
# 4. Output norm & projection
diff --git a/src/diffusers/models/transformers/transformer_cosmos.py b/src/diffusers/models/transformers/transformer_cosmos.py
new file mode 100644
index 000000000000..373b470ae37b
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_cosmos.py
@@ -0,0 +1,586 @@
+# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
+from ...utils import is_torchvision_available
+from ..attention import FeedForward
+from ..attention_processor import Attention
+from ..embeddings import Timesteps
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import RMSNorm
+
+
+if is_torchvision_available():
+ from torchvision import transforms
+
+
+class CosmosPatchEmbed(nn.Module):
+ def __init__(
+ self, in_channels: int, out_channels: int, patch_size: Tuple[int, int, int], bias: bool = True
+ ) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+
+ self.proj = nn.Linear(in_channels * patch_size[0] * patch_size[1] * patch_size[2], out_channels, bias=bias)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.patch_size
+ hidden_states = hidden_states.reshape(
+ batch_size, num_channels, num_frames // p_t, p_t, height // p_h, p_h, width // p_w, p_w
+ )
+ hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7)
+ hidden_states = self.proj(hidden_states)
+ return hidden_states
+
+
+class CosmosTimestepEmbedding(nn.Module):
+ def __init__(self, in_features: int, out_features: int) -> None:
+ super().__init__()
+ self.linear_1 = nn.Linear(in_features, out_features, bias=False)
+ self.activation = nn.SiLU()
+ self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False)
+
+ def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
+ emb = self.linear_1(timesteps)
+ emb = self.activation(emb)
+ emb = self.linear_2(emb)
+ return emb
+
+
+class CosmosEmbedding(nn.Module):
+ def __init__(self, embedding_dim: int, condition_dim: int) -> None:
+ super().__init__()
+
+ self.time_proj = Timesteps(embedding_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0)
+ self.t_embedder = CosmosTimestepEmbedding(embedding_dim, condition_dim)
+ self.norm = RMSNorm(embedding_dim, eps=1e-6, elementwise_affine=True)
+
+ def forward(self, hidden_states: torch.Tensor, timestep: torch.LongTensor) -> torch.Tensor:
+ timesteps_proj = self.time_proj(timestep).type_as(hidden_states)
+ temb = self.t_embedder(timesteps_proj)
+ embedded_timestep = self.norm(timesteps_proj)
+ return temb, embedded_timestep
+
+
+class CosmosAdaLayerNorm(nn.Module):
+ def __init__(self, in_features: int, hidden_features: int) -> None:
+ super().__init__()
+ self.embedding_dim = in_features
+
+ self.activation = nn.SiLU()
+ self.norm = nn.LayerNorm(in_features, elementwise_affine=False, eps=1e-6)
+ self.linear_1 = nn.Linear(in_features, hidden_features, bias=False)
+ self.linear_2 = nn.Linear(hidden_features, 2 * in_features, bias=False)
+
+ def forward(
+ self, hidden_states: torch.Tensor, embedded_timestep: torch.Tensor, temb: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ embedded_timestep = self.activation(embedded_timestep)
+ embedded_timestep = self.linear_1(embedded_timestep)
+ embedded_timestep = self.linear_2(embedded_timestep)
+
+ if temb is not None:
+ embedded_timestep = embedded_timestep + temb[..., : 2 * self.embedding_dim]
+
+ shift, scale = embedded_timestep.chunk(2, dim=-1)
+ hidden_states = self.norm(hidden_states)
+
+ if embedded_timestep.ndim == 2:
+ shift, scale = (x.unsqueeze(1) for x in (shift, scale))
+
+ hidden_states = hidden_states * (1 + scale) + shift
+ return hidden_states
+
+
+class CosmosAdaLayerNormZero(nn.Module):
+ def __init__(self, in_features: int, hidden_features: Optional[int] = None) -> None:
+ super().__init__()
+
+ self.norm = nn.LayerNorm(in_features, elementwise_affine=False, eps=1e-6)
+ self.activation = nn.SiLU()
+
+ if hidden_features is None:
+ self.linear_1 = nn.Identity()
+ else:
+ self.linear_1 = nn.Linear(in_features, hidden_features, bias=False)
+
+ self.linear_2 = nn.Linear(hidden_features, 3 * in_features, bias=False)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ embedded_timestep: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ embedded_timestep = self.activation(embedded_timestep)
+ embedded_timestep = self.linear_1(embedded_timestep)
+ embedded_timestep = self.linear_2(embedded_timestep)
+
+ if temb is not None:
+ embedded_timestep = embedded_timestep + temb
+
+ shift, scale, gate = embedded_timestep.chunk(3, dim=-1)
+ hidden_states = self.norm(hidden_states)
+
+ if embedded_timestep.ndim == 2:
+ shift, scale, gate = (x.unsqueeze(1) for x in (shift, scale, gate))
+
+ hidden_states = hidden_states * (1 + scale) + shift
+ return hidden_states, gate
+
+
+class CosmosAttnProcessor2_0:
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("CosmosAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ # 1. QKV projections
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ # 2. QK normalization
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ # 3. Apply RoPE
+ if image_rotary_emb is not None:
+ from ..embeddings import apply_rotary_emb
+
+ query = apply_rotary_emb(query, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
+ key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
+
+ # 4. Prepare for GQA
+ if torch.onnx.is_in_onnx_export():
+ query_idx = torch.tensor(query.size(3), device=query.device)
+ key_idx = torch.tensor(key.size(3), device=key.device)
+ value_idx = torch.tensor(value.size(3), device=value.device)
+
+ else:
+ query_idx = query.size(3)
+ key_idx = key.size(3)
+ value_idx = value.size(3)
+ key = key.repeat_interleave(query_idx // key_idx, dim=3)
+ value = value.repeat_interleave(query_idx // value_idx, dim=3)
+
+ # 5. Attention
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).type_as(query)
+
+ # 6. Output projection
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+
+class CosmosTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ cross_attention_dim: int,
+ mlp_ratio: float = 4.0,
+ adaln_lora_dim: int = 256,
+ qk_norm: str = "rms_norm",
+ out_bias: bool = False,
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+
+ self.norm1 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
+ self.attn1 = Attention(
+ query_dim=hidden_size,
+ cross_attention_dim=None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ qk_norm=qk_norm,
+ elementwise_affine=True,
+ out_bias=out_bias,
+ processor=CosmosAttnProcessor2_0(),
+ )
+
+ self.norm2 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
+ self.attn2 = Attention(
+ query_dim=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ qk_norm=qk_norm,
+ elementwise_affine=True,
+ out_bias=out_bias,
+ processor=CosmosAttnProcessor2_0(),
+ )
+
+ self.norm3 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
+ self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu", bias=out_bias)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ embedded_timestep: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ extra_pos_emb: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if extra_pos_emb is not None:
+ hidden_states = hidden_states + extra_pos_emb
+
+ # 1. Self Attention
+ norm_hidden_states, gate = self.norm1(hidden_states, embedded_timestep, temb)
+ attn_output = self.attn1(norm_hidden_states, image_rotary_emb=image_rotary_emb)
+ hidden_states = hidden_states + gate * attn_output
+
+ # 2. Cross Attention
+ norm_hidden_states, gate = self.norm2(hidden_states, embedded_timestep, temb)
+ attn_output = self.attn2(
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+ )
+ hidden_states = hidden_states + gate * attn_output
+
+ # 3. Feed Forward
+ norm_hidden_states, gate = self.norm3(hidden_states, embedded_timestep, temb)
+ ff_output = self.ff(norm_hidden_states)
+ hidden_states = hidden_states + gate * ff_output
+
+ return hidden_states
+
+
+class CosmosRotaryPosEmbed(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ max_size: Tuple[int, int, int] = (128, 240, 240),
+ patch_size: Tuple[int, int, int] = (1, 2, 2),
+ base_fps: int = 24,
+ rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
+ ) -> None:
+ super().__init__()
+
+ self.max_size = [size // patch for size, patch in zip(max_size, patch_size)]
+ self.patch_size = patch_size
+ self.base_fps = base_fps
+
+ self.dim_h = hidden_size // 6 * 2
+ self.dim_w = hidden_size // 6 * 2
+ self.dim_t = hidden_size - self.dim_h - self.dim_w
+
+ self.h_ntk_factor = rope_scale[1] ** (self.dim_h / (self.dim_h - 2))
+ self.w_ntk_factor = rope_scale[2] ** (self.dim_w / (self.dim_w - 2))
+ self.t_ntk_factor = rope_scale[0] ** (self.dim_t / (self.dim_t - 2))
+
+ def forward(self, hidden_states: torch.Tensor, fps: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ pe_size = [num_frames // self.patch_size[0], height // self.patch_size[1], width // self.patch_size[2]]
+ device = hidden_states.device
+
+ h_theta = 10000.0 * self.h_ntk_factor
+ w_theta = 10000.0 * self.w_ntk_factor
+ t_theta = 10000.0 * self.t_ntk_factor
+
+ seq = torch.arange(max(self.max_size), device=device, dtype=torch.float32)
+ dim_h_range = (
+ torch.arange(0, self.dim_h, 2, device=device, dtype=torch.float32)[: (self.dim_h // 2)] / self.dim_h
+ )
+ dim_w_range = (
+ torch.arange(0, self.dim_w, 2, device=device, dtype=torch.float32)[: (self.dim_w // 2)] / self.dim_w
+ )
+ dim_t_range = (
+ torch.arange(0, self.dim_t, 2, device=device, dtype=torch.float32)[: (self.dim_t // 2)] / self.dim_t
+ )
+ h_spatial_freqs = 1.0 / (h_theta**dim_h_range)
+ w_spatial_freqs = 1.0 / (w_theta**dim_w_range)
+ temporal_freqs = 1.0 / (t_theta**dim_t_range)
+
+ emb_h = torch.outer(seq[: pe_size[1]], h_spatial_freqs)[None, :, None, :].repeat(pe_size[0], 1, pe_size[2], 1)
+ emb_w = torch.outer(seq[: pe_size[2]], w_spatial_freqs)[None, None, :, :].repeat(pe_size[0], pe_size[1], 1, 1)
+
+ # Apply sequence scaling in temporal dimension
+ if fps is None:
+ # Images
+ emb_t = torch.outer(seq[: pe_size[0]], temporal_freqs)
+ else:
+ # Videos
+ emb_t = torch.outer(seq[: pe_size[0]] / fps * self.base_fps, temporal_freqs)
+
+ emb_t = emb_t[:, None, None, :].repeat(1, pe_size[1], pe_size[2], 1)
+ freqs = torch.cat([emb_t, emb_h, emb_w] * 2, dim=-1).flatten(0, 2).float()
+ cos = torch.cos(freqs)
+ sin = torch.sin(freqs)
+ return cos, sin
+
+
+class CosmosLearnablePositionalEmbed(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ max_size: Tuple[int, int, int],
+ patch_size: Tuple[int, int, int],
+ eps: float = 1e-6,
+ ) -> None:
+ super().__init__()
+
+ self.max_size = [size // patch for size, patch in zip(max_size, patch_size)]
+ self.patch_size = patch_size
+ self.eps = eps
+
+ self.pos_emb_t = nn.Parameter(torch.zeros(self.max_size[0], hidden_size))
+ self.pos_emb_h = nn.Parameter(torch.zeros(self.max_size[1], hidden_size))
+ self.pos_emb_w = nn.Parameter(torch.zeros(self.max_size[2], hidden_size))
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ pe_size = [num_frames // self.patch_size[0], height // self.patch_size[1], width // self.patch_size[2]]
+
+ emb_t = self.pos_emb_t[: pe_size[0]][None, :, None, None, :].repeat(batch_size, 1, pe_size[1], pe_size[2], 1)
+ emb_h = self.pos_emb_h[: pe_size[1]][None, None, :, None, :].repeat(batch_size, pe_size[0], 1, pe_size[2], 1)
+ emb_w = self.pos_emb_w[: pe_size[2]][None, None, None, :, :].repeat(batch_size, pe_size[0], pe_size[1], 1, 1)
+ emb = emb_t + emb_h + emb_w
+ emb = emb.flatten(1, 3)
+
+ norm = torch.linalg.vector_norm(emb, dim=-1, keepdim=True, dtype=torch.float32)
+ norm = torch.add(self.eps, norm, alpha=np.sqrt(norm.numel() / emb.numel()))
+ return (emb / norm).type_as(hidden_states)
+
+
+class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+ r"""
+ A Transformer model for video-like data used in [Cosmos](https://github.com/NVIDIA/Cosmos).
+
+ Args:
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, defaults to `16`):
+ The number of channels in the output.
+ num_attention_heads (`int`, defaults to `32`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of channels in each attention head.
+ num_layers (`int`, defaults to `28`):
+ The number of layers of transformer blocks to use.
+ mlp_ratio (`float`, defaults to `4.0`):
+ The ratio of the hidden layer size to the input size in the feedforward network.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ adaln_lora_dim (`int`, defaults to `256`):
+ The hidden dimension of the Adaptive LayerNorm LoRA layer.
+ max_size (`Tuple[int, int, int]`, defaults to `(128, 240, 240)`):
+ The maximum size of the input latent tensors in the temporal, height, and width dimensions.
+ patch_size (`Tuple[int, int, int]`, defaults to `(1, 2, 2)`):
+ The patch size to use for patchifying the input latent tensors in the temporal, height, and width
+ dimensions.
+ rope_scale (`Tuple[float, float, float]`, defaults to `(2.0, 1.0, 1.0)`):
+ The scaling factor to use for RoPE in the temporal, height, and width dimensions.
+ concat_padding_mask (`bool`, defaults to `True`):
+ Whether to concatenate the padding mask to the input latent tensors.
+ extra_pos_embed_type (`str`, *optional*, defaults to `learnable`):
+ The type of extra positional embeddings to use. Can be one of `None` or `learnable`.
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["patch_embed", "final_layer", "norm"]
+ _no_split_modules = ["CosmosTransformerBlock"]
+ _keep_in_fp32_modules = ["learnable_pos_embed"]
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 16,
+ out_channels: int = 16,
+ num_attention_heads: int = 32,
+ attention_head_dim: int = 128,
+ num_layers: int = 28,
+ mlp_ratio: float = 4.0,
+ text_embed_dim: int = 1024,
+ adaln_lora_dim: int = 256,
+ max_size: Tuple[int, int, int] = (128, 240, 240),
+ patch_size: Tuple[int, int, int] = (1, 2, 2),
+ rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
+ concat_padding_mask: bool = True,
+ extra_pos_embed_type: Optional[str] = "learnable",
+ ) -> None:
+ super().__init__()
+ hidden_size = num_attention_heads * attention_head_dim
+
+ # 1. Patch Embedding
+ patch_embed_in_channels = in_channels + 1 if concat_padding_mask else in_channels
+ self.patch_embed = CosmosPatchEmbed(patch_embed_in_channels, hidden_size, patch_size, bias=False)
+
+ # 2. Positional Embedding
+ self.rope = CosmosRotaryPosEmbed(
+ hidden_size=attention_head_dim, max_size=max_size, patch_size=patch_size, rope_scale=rope_scale
+ )
+
+ self.learnable_pos_embed = None
+ if extra_pos_embed_type == "learnable":
+ self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
+ hidden_size=hidden_size,
+ max_size=max_size,
+ patch_size=patch_size,
+ )
+
+ # 3. Time Embedding
+ self.time_embed = CosmosEmbedding(hidden_size, hidden_size)
+
+ # 4. Transformer Blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ CosmosTransformerBlock(
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ cross_attention_dim=text_embed_dim,
+ mlp_ratio=mlp_ratio,
+ adaln_lora_dim=adaln_lora_dim,
+ qk_norm="rms_norm",
+ out_bias=False,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 5. Output norm & projection
+ self.norm_out = CosmosAdaLayerNorm(hidden_size, adaln_lora_dim)
+ self.proj_out = nn.Linear(
+ hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=False
+ )
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ fps: Optional[int] = None,
+ condition_mask: Optional[torch.Tensor] = None,
+ padding_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+
+ # 1. Concatenate padding mask if needed & prepare attention mask
+ if condition_mask is not None:
+ hidden_states = torch.cat([hidden_states, condition_mask], dim=1)
+
+ if self.config.concat_padding_mask:
+ padding_mask = transforms.functional.resize(
+ padding_mask, list(hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
+ )
+ hidden_states = torch.cat(
+ [hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1
+ )
+
+ if attention_mask is not None:
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, S]
+
+ # 2. Generate positional embeddings
+ image_rotary_emb = self.rope(hidden_states, fps=fps)
+ extra_pos_emb = self.learnable_pos_embed(hidden_states) if self.config.extra_pos_embed_type else None
+
+ # 3. Patchify input
+ p_t, p_h, p_w = self.config.patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+ hidden_states = self.patch_embed(hidden_states)
+ hidden_states = hidden_states.flatten(1, 3) # [B, T, H, W, C] -> [B, THW, C]
+
+ # 4. Timestep embeddings
+ if timestep.ndim == 1:
+ temb, embedded_timestep = self.time_embed(hidden_states, timestep)
+ elif timestep.ndim == 5:
+ assert timestep.shape == (batch_size, 1, num_frames, 1, 1), (
+ f"Expected timestep to have shape [B, 1, T, 1, 1], but got {timestep.shape}"
+ )
+ timestep = timestep.flatten()
+ temb, embedded_timestep = self.time_embed(hidden_states, timestep)
+ # We can do this because num_frames == post_patch_num_frames, as p_t is 1
+ temb, embedded_timestep = (
+ x.view(batch_size, post_patch_num_frames, 1, 1, -1)
+ .expand(-1, -1, post_patch_height, post_patch_width, -1)
+ .flatten(1, 3)
+ for x in (temb, embedded_timestep)
+ ) # [BT, C] -> [B, T, 1, 1, C] -> [B, T, H, W, C] -> [B, THW, C]
+ else:
+ assert False
+
+ # 5. Transformer blocks
+ for block in self.transformer_blocks:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ embedded_timestep,
+ temb,
+ image_rotary_emb,
+ extra_pos_emb,
+ attention_mask,
+ )
+ else:
+ hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ embedded_timestep=embedded_timestep,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ extra_pos_emb=extra_pos_emb,
+ attention_mask=attention_mask,
+ )
+
+ # 6. Output norm & projection & unpatchify
+ hidden_states = self.norm_out(hidden_states, embedded_timestep, temb)
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.unflatten(2, (p_h, p_w, p_t, -1))
+ hidden_states = hidden_states.unflatten(1, (post_patch_num_frames, post_patch_height, post_patch_width))
+ # NOTE: The permutation order here is not the inverse operation of what happens when patching as usually expected.
+ # It might be a source of confusion to the reader, but this is correct
+ hidden_states = hidden_states.permute(0, 7, 1, 6, 2, 4, 3, 5)
+ hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ if not return_dict:
+ return (hidden_states,)
+
+ return Transformer2DModelOutput(sample=hidden_states)
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index b64920c374f4..f5955775b991 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,55 +12,345 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from typing import Any, Dict, Optional, Tuple, Union
+import inspect
+from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
+import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import (
- FluxTransformer2DLoadersMixin,
- FromOriginalModelMixin,
- PeftAdapterMixin,
-)
-from ...models.attention import FeedForward
-from ...models.attention_processor import (
- Attention,
- AttentionProcessor,
- FluxAttnProcessor2_0,
- FluxAttnProcessor2_0_NPU,
- FusedFluxAttnProcessor2_0,
-)
-from ...models.modeling_utils import ModelMixin
-from ...models.normalization import (
- AdaLayerNormContinuous,
- AdaLayerNormZero,
- AdaLayerNormZeroSingle,
-)
-from ...utils import (
- USE_PEFT_BACKEND,
- deprecate,
- logging,
- scale_lora_layers,
- unscale_lora_layers,
-)
-from ...utils.import_utils import is_torch_npu_available
+from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
+from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import (
CombinedTimestepGuidanceTextProjEmbeddings,
CombinedTimestepTextProjEmbeddings,
- FluxPosEmbed,
- TimestepEmbedding,
- Timesteps,
+ apply_rotary_emb,
+ get_1d_rotary_pos_embed,
)
from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ encoder_query = encoder_key = encoder_value = None
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+ return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+
+ encoder_query = encoder_key = encoder_value = (None,)
+ if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
+ encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
+
+ return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+ if attn.fused_projections:
+ return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
+ return _get_projections(attn, hidden_states, encoder_hidden_states)
+
+
+class FluxAttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
+
+ def __call__(
+ self,
+ attn: "FluxAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+ attn, hidden_states, encoder_hidden_states
+ )
+
+ query = query.unflatten(-1, (attn.heads, -1))
+ key = key.unflatten(-1, (attn.heads, -1))
+ value = value.unflatten(-1, (attn.heads, -1))
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ if attn.added_kv_proj_dim is not None:
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+ encoder_query = attn.norm_added_q(encoder_query)
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([encoder_query, query], dim=1)
+ key = torch.cat([encoder_key, key], dim=1)
+ value = torch.cat([encoder_value, value], dim=1)
+
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+ )
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class FluxIPAdapterAttnProcessor(torch.nn.Module):
+ """Flux Attention processor for IP-Adapter."""
+
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(
+ self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
+ ):
+ super().__init__()
+
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+
+ if not isinstance(num_tokens, (tuple, list)):
+ num_tokens = [num_tokens]
+
+ if not isinstance(scale, list):
+ scale = [scale] * len(num_tokens)
+ if len(scale) != len(num_tokens):
+ raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
+ self.scale = scale
+
+ self.to_k_ip = nn.ModuleList(
+ [
+ nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+ for _ in range(len(num_tokens))
+ ]
+ )
+ self.to_v_ip = nn.ModuleList(
+ [
+ nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+ for _ in range(len(num_tokens))
+ ]
+ )
+
+ def __call__(
+ self,
+ attn: "FluxAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ip_hidden_states: Optional[List[torch.Tensor]] = None,
+ ip_adapter_masks: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size = hidden_states.shape[0]
+
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+ attn, hidden_states, encoder_hidden_states
+ )
+
+ query = query.unflatten(-1, (attn.heads, -1))
+ key = key.unflatten(-1, (attn.heads, -1))
+ value = value.unflatten(-1, (attn.heads, -1))
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+ ip_query = query
+
+ if encoder_hidden_states is not None:
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+ encoder_query = attn.norm_added_q(encoder_query)
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([encoder_query, query], dim=1)
+ key = torch.cat([encoder_key, key], dim=1)
+ value = torch.cat([encoder_value, value], dim=1)
+
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+ )
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ # IP-adapter
+ ip_attn_output = torch.zeros_like(hidden_states)
+
+ for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
+ ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
+ ):
+ ip_key = to_k_ip(current_ip_hidden_states)
+ ip_value = to_v_ip(current_ip_hidden_states)
+
+ ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim)
+
+ current_ip_hidden_states = dispatch_attention_fn(
+ ip_query,
+ ip_key,
+ ip_value,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
+ current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
+ ip_attn_output += scale * current_ip_hidden_states
+
+ return hidden_states, encoder_hidden_states, ip_attn_output
+ else:
+ return hidden_states
+
+
+class FluxAttention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = FluxAttnProcessor
+ _available_processors = [
+ FluxAttnProcessor,
+ FluxIPAdapterAttnProcessor,
+ ]
+
+ def __init__(
+ self,
+ query_dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ added_proj_bias: Optional[bool] = True,
+ out_bias: bool = True,
+ eps: float = 1e-5,
+ out_dim: int = None,
+ context_pre_only: Optional[bool] = None,
+ pre_only: bool = False,
+ elementwise_affine: bool = True,
+ processor=None,
+ ):
+ super().__init__()
+
+ self.head_dim = dim_head
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.query_dim = query_dim
+ self.use_bias = bias
+ self.dropout = dropout
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.context_pre_only = context_pre_only
+ self.pre_only = pre_only
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.added_proj_bias = added_proj_bias
+
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+ if not self.pre_only:
+ self.to_out = torch.nn.ModuleList([])
+ self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(torch.nn.Dropout(dropout))
+
+ if added_kv_proj_dim is not None:
+ self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
+ self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
+ self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
+
+ if processor is None:
+ processor = self._default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
+ unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
def __init__(
@@ -78,25 +368,13 @@ def __init__(
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
- if is_torch_npu_available():
- deprecation_message = (
- "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
- "should be set explicitly using the `set_attn_processor` method."
- )
- deprecate("npu_processor", "0.34.0", deprecation_message)
- processor = FluxAttnProcessor2_0_NPU()
- else:
- processor = FluxAttnProcessor2_0()
-
- self.attn = Attention(
+ self.attn = FluxAttention(
query_dim=dim,
- cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
- processor=processor,
- qk_norm="rms_norm",
+ processor=FluxAttnProcessor(),
eps=1e-6,
pre_only=True,
)
@@ -104,11 +382,14 @@ def __init__(
def forward(
self,
hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
- attention_mask: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ text_seq_len = encoder_hidden_states.shape[1]
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
@@ -127,7 +408,8 @@ def forward(
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
- return hidden_states
+ encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
+ return encoder_hidden_states, hidden_states
@maybe_allow_in_graph
@@ -145,17 +427,15 @@ def __init__(
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
- self.attn = Attention(
+ self.attn = FluxAttention(
query_dim=dim,
- cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
- processor=FluxAttnProcessor2_0(),
- qk_norm=qk_norm,
+ processor=FluxAttnProcessor(),
eps=eps,
)
@@ -184,6 +464,7 @@ def forward(
self.norm1_context(encoder_hidden_states, emb=temb)
)
joint_attention_kwargs = joint_attention_kwargs or {}
+
# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
@@ -215,7 +496,6 @@ def forward(
hidden_states = hidden_states + ip_attn_output
# Process attention outputs for the `encoder_hidden_states`.
-
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
@@ -235,6 +515,37 @@ def forward(
return encoder_hidden_states, hidden_states
+class FluxPosEmbed(nn.Module):
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+ def __init__(self, theta: int, axes_dim: List[int]):
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ n_axes = ids.shape[-1]
+ cos_out = []
+ sin_out = []
+ pos = ids.float()
+ is_mps = ids.device.type == "mps"
+ is_npu = ids.device.type == "npu"
+ freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+ for i in range(n_axes):
+ cos, sin = get_1d_rotary_pos_embed(
+ self.axes_dim[i],
+ pos[:, i],
+ theta=self.theta,
+ repeat_interleave_real=True,
+ use_real=True,
+ freqs_dtype=freqs_dtype,
+ )
+ cos_out.append(cos)
+ sin_out.append(sin)
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+ return freqs_cos, freqs_sin
+
+
class FluxTransformer2DModel(
ModelMixin,
ConfigMixin,
@@ -242,6 +553,7 @@ class FluxTransformer2DModel(
FromOriginalModelMixin,
FluxTransformer2DLoadersMixin,
CacheMixin,
+ AttentionMixin,
):
"""
The Transformer model introduced in Flux.
@@ -277,6 +589,16 @@ class FluxTransformer2DModel(
_supports_gradient_checkpointing = True
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+ _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
+ _cp_plan = {
+ "": {
+ "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
+ "txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
+ },
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
@register_to_config
def __init__(
@@ -291,8 +613,7 @@ def __init__(
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
- additional_timestep_embeds: bool = False,
- axes_dims_rope: Tuple[int] = (16, 56, 56),
+ axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
):
super().__init__()
self.out_channels = out_channels or in_channels
@@ -352,114 +673,6 @@ def __init__(
self.gradient_checkpointing = False
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(
- name: str,
- module: torch.nn.Module,
- processors: Dict[str, AttentionProcessor],
- ):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(
- self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
- ):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
- def fuse_qkv_projections(self):
- """
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
- are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
-
- This API is 🧪 experimental.
-
-
- """
- self.original_attn_processors = None
-
- for _, attn_processor in self.attn_processors.items():
- if "Added" in str(attn_processor.__class__.__name__):
- raise ValueError(
- "`fuse_qkv_projections()` is not supported for models having added KV projections."
- )
-
- self.original_attn_processors = self.attn_processors
-
- for module in self.modules():
- if isinstance(module, Attention):
- module.fuse_projections(fuse=True)
-
- self.set_attn_processor(FusedFluxAttnProcessor2_0())
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
- def unfuse_qkv_projections(self):
- """Disables the fused QKV projection if enabled.
-
-
-
- This API is 🧪 experimental.
-
-
-
- """
- if self.original_attn_processors is not None:
- self.set_attn_processor(self.original_attn_processors)
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -526,8 +739,6 @@ def forward(
timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
- else:
- guidance = None
temb = (
self.time_text_embed(timestep, pooled_projections)
@@ -560,7 +771,11 @@ def forward(
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
- image_rotary_emb = self.pos_embed(ids)
+ if is_torch_npu_available():
+ freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
+ image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
+ else:
+ image_rotary_emb = self.pos_embed(ids)
if (
joint_attention_kwargs is not None
@@ -574,15 +789,13 @@ def forward(
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
- encoder_hidden_states, hidden_states = (
- self._gradient_checkpointing_func(
- block,
- hidden_states,
- encoder_hidden_states,
- temb,
- image_rotary_emb,
- attention_mask=attention_mask,
- )
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ joint_attention_kwargs,
)
else:
@@ -610,25 +823,23 @@ def forward(
]
)
else:
- hidden_states = (
- hidden_states
- + controlnet_block_samples[index_block // interval_control]
- )
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
- hidden_states = self._gradient_checkpointing_func(
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
+ encoder_hidden_states,
temb,
image_rotary_emb,
- attention_mask=attention_mask,
+ joint_attention_kwargs,
)
else:
- hidden_states = block(
+ encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
@@ -641,12 +852,7 @@ def forward(
controlnet_single_block_samples
)
interval_control = int(np.ceil(interval_control))
- hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
- hidden_states[:, encoder_hidden_states.shape[1] :, ...]
- + controlnet_single_block_samples[index_block // interval_control]
- )
-
- hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py
new file mode 100644
index 000000000000..c10bf3ed4f7b
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_flux2.py
@@ -0,0 +1,908 @@
+# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
+from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
+from ..attention import AttentionMixin, AttentionModuleMixin
+from ..attention_dispatch import dispatch_attention_fn
+from ..cache_utils import CacheMixin
+from ..embeddings import (
+ TimestepEmbedding,
+ Timesteps,
+ apply_rotary_emb,
+ get_1d_rotary_pos_embed,
+)
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ encoder_query = encoder_key = encoder_value = None
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+ return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_fused_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+
+ encoder_query = encoder_key = encoder_value = (None,)
+ if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
+ encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
+
+ return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_qkv_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
+ if attn.fused_projections:
+ return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
+ return _get_projections(attn, hidden_states, encoder_hidden_states)
+
+
+class Flux2SwiGLU(nn.Module):
+ """
+ Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection
+ layer fused into the first linear layer of the FF sub-block. Thus, this module has no trainable parameters.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.gate_fn = nn.SiLU()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x1, x2 = x.chunk(2, dim=-1)
+ x = self.gate_fn(x1) * x2
+ return x
+
+
+class Flux2FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ dim_out: Optional[int] = None,
+ mult: float = 3.0,
+ inner_dim: Optional[int] = None,
+ bias: bool = False,
+ ):
+ super().__init__()
+ if inner_dim is None:
+ inner_dim = int(dim * mult)
+ dim_out = dim_out or dim
+
+ # Flux2SwiGLU will reduce the dimension by half
+ self.linear_in = nn.Linear(dim, inner_dim * 2, bias=bias)
+ self.act_fn = Flux2SwiGLU()
+ self.linear_out = nn.Linear(inner_dim, dim_out, bias=bias)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.linear_in(x)
+ x = self.act_fn(x)
+ x = self.linear_out(x)
+ return x
+
+
+class Flux2AttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
+
+ def __call__(
+ self,
+ attn: "Flux2Attention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+ attn, hidden_states, encoder_hidden_states
+ )
+
+ query = query.unflatten(-1, (attn.heads, -1))
+ key = key.unflatten(-1, (attn.heads, -1))
+ value = value.unflatten(-1, (attn.heads, -1))
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ if attn.added_kv_proj_dim is not None:
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+ encoder_query = attn.norm_added_q(encoder_query)
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([encoder_query, query], dim=1)
+ key = torch.cat([encoder_key, key], dim=1)
+ value = torch.cat([encoder_value, value], dim=1)
+
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+ )
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if encoder_hidden_states is not None:
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class Flux2Attention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = Flux2AttnProcessor
+ _available_processors = [Flux2AttnProcessor]
+
+ def __init__(
+ self,
+ query_dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ added_proj_bias: Optional[bool] = True,
+ out_bias: bool = True,
+ eps: float = 1e-5,
+ out_dim: int = None,
+ elementwise_affine: bool = True,
+ processor=None,
+ ):
+ super().__init__()
+
+ self.head_dim = dim_head
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.query_dim = query_dim
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+
+ self.use_bias = bias
+ self.dropout = dropout
+
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.added_proj_bias = added_proj_bias
+
+ self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+ # QK Norm
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+
+ self.to_out = torch.nn.ModuleList([])
+ self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(torch.nn.Dropout(dropout))
+
+ if added_kv_proj_dim is not None:
+ self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
+ self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
+ self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
+
+ if processor is None:
+ processor = self._default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
+class Flux2ParallelSelfAttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
+
+ def __call__(
+ self,
+ attn: "Flux2ParallelSelfAttention",
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ # Parallel in (QKV + MLP in) projection
+ hidden_states = attn.to_qkv_mlp_proj(hidden_states)
+ qkv, mlp_hidden_states = torch.split(
+ hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
+ )
+
+ # Handle the attention logic
+ query, key, value = qkv.chunk(3, dim=-1)
+
+ query = query.unflatten(-1, (attn.heads, -1))
+ key = key.unflatten(-1, (attn.heads, -1))
+ value = value.unflatten(-1, (attn.heads, -1))
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # Handle the feedforward (FF) logic
+ mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
+
+ # Concatenate and parallel output projection
+ hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
+ hidden_states = attn.to_out(hidden_states)
+
+ return hidden_states
+
+
+class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin):
+ """
+ Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.
+
+ This implements a parallel transformer block, where the attention QKV projections are fused to the feedforward (FF)
+ input projections, and the attention output projections are fused to the FF output projections. See the [ViT-22B
+ paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block.
+ """
+
+ _default_processor_cls = Flux2ParallelSelfAttnProcessor
+ _available_processors = [Flux2ParallelSelfAttnProcessor]
+ # Does not support QKV fusion as the QKV projections are always fused
+ _supports_qkv_fusion = False
+
+ def __init__(
+ self,
+ query_dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ out_bias: bool = True,
+ eps: float = 1e-5,
+ out_dim: int = None,
+ elementwise_affine: bool = True,
+ mlp_ratio: float = 4.0,
+ mlp_mult_factor: int = 2,
+ processor=None,
+ ):
+ super().__init__()
+
+ self.head_dim = dim_head
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.query_dim = query_dim
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+
+ self.use_bias = bias
+ self.dropout = dropout
+
+ self.mlp_ratio = mlp_ratio
+ self.mlp_hidden_dim = int(query_dim * self.mlp_ratio)
+ self.mlp_mult_factor = mlp_mult_factor
+
+ # Fused QKV projections + MLP input projection
+ self.to_qkv_mlp_proj = torch.nn.Linear(
+ self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias
+ )
+ self.mlp_act_fn = Flux2SwiGLU()
+
+ # QK Norm
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+
+ # Fused attention output projection + MLP output projection
+ self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias)
+
+ if processor is None:
+ processor = self._default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+ return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
+class Flux2SingleTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_ratio: float = 3.0,
+ eps: float = 1e-6,
+ bias: bool = False,
+ ):
+ super().__init__()
+
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+
+ # Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this
+ # is often called a "parallel" transformer block. See the [ViT-22B paper](https://arxiv.org/abs/2302.05442)
+ # for a visual depiction of this type of transformer block.
+ self.attn = Flux2ParallelSelfAttention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ bias=bias,
+ out_bias=bias,
+ eps=eps,
+ mlp_ratio=mlp_ratio,
+ mlp_mult_factor=2,
+ processor=Flux2ParallelSelfAttnProcessor(),
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor],
+ temb_mod_params: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ split_hidden_states: bool = False,
+ text_seq_len: Optional[int] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already
+ # concatenated
+ if encoder_hidden_states is not None:
+ text_seq_len = encoder_hidden_states.shape[1]
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ mod_shift, mod_scale, mod_gate = temb_mod_params
+
+ norm_hidden_states = self.norm(hidden_states)
+ norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift
+
+ joint_attention_kwargs = joint_attention_kwargs or {}
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+
+ hidden_states = hidden_states + mod_gate * attn_output
+ if hidden_states.dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ if split_hidden_states:
+ encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
+ return encoder_hidden_states, hidden_states
+ else:
+ return hidden_states
+
+
+class Flux2TransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_ratio: float = 3.0,
+ eps: float = 1e-6,
+ bias: bool = False,
+ ):
+ super().__init__()
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
+
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+ self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+
+ self.attn = Flux2Attention(
+ query_dim=dim,
+ added_kv_proj_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ bias=bias,
+ added_proj_bias=bias,
+ out_bias=bias,
+ eps=eps,
+ processor=Flux2AttnProcessor(),
+ )
+
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+ self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
+
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+ self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb_mod_params_img: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
+ temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ joint_attention_kwargs = joint_attention_kwargs or {}
+
+ # Modulation parameters shape: [1, 1, self.dim]
+ (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
+ (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt
+
+ # Img stream
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa
+
+ # Conditioning txt stream
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
+ norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa
+
+ # Attention on concatenated img + txt stream
+ attention_outputs = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+
+ attn_output, context_attn_output = attention_outputs
+
+ # Process attention outputs for the image stream (`hidden_states`).
+ attn_output = gate_msa * attn_output
+ hidden_states = hidden_states + attn_output
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+ ff_output = self.ff(norm_hidden_states)
+ hidden_states = hidden_states + gate_mlp * ff_output
+
+ # Process attention outputs for the text stream (`encoder_hidden_states`).
+ context_attn_output = c_gate_msa * context_attn_output
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
+
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
+ if encoder_hidden_states.dtype == torch.float16:
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+ return encoder_hidden_states, hidden_states
+
+
+class Flux2PosEmbed(nn.Module):
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+ def __init__(self, theta: int, axes_dim: List[int]):
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ # Expected ids shape: [S, len(self.axes_dim)]
+ cos_out = []
+ sin_out = []
+ pos = ids.float()
+ is_mps = ids.device.type == "mps"
+ is_npu = ids.device.type == "npu"
+ freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+ # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1]
+ for i in range(len(self.axes_dim)):
+ cos, sin = get_1d_rotary_pos_embed(
+ self.axes_dim[i],
+ pos[..., i],
+ theta=self.theta,
+ repeat_interleave_real=True,
+ use_real=True,
+ freqs_dtype=freqs_dtype,
+ )
+ cos_out.append(cos)
+ sin_out.append(sin)
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+ return freqs_cos, freqs_sin
+
+
+class Flux2TimestepGuidanceEmbeddings(nn.Module):
+ def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(
+ in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
+ )
+
+ self.guidance_embedder = TimestepEmbedding(
+ in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
+ )
+
+ def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D)
+
+ guidance_proj = self.time_proj(guidance)
+ guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
+
+ time_guidance_emb = timesteps_emb + guidance_emb
+
+ return time_guidance_emb
+
+
+class Flux2Modulation(nn.Module):
+ def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):
+ super().__init__()
+ self.mod_param_sets = mod_param_sets
+
+ self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)
+ self.act_fn = nn.SiLU()
+
+ def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
+ mod = self.act_fn(temb)
+ mod = self.linear(mod)
+
+ if mod.ndim == 2:
+ mod = mod.unsqueeze(1)
+ mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
+ # Return tuple of 3-tuples of modulation params shift/scale/gate
+ return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
+
+
+class Flux2Transformer2DModel(
+ ModelMixin,
+ ConfigMixin,
+ PeftAdapterMixin,
+ FromOriginalModelMixin,
+ FluxTransformer2DLoadersMixin,
+ CacheMixin,
+ AttentionMixin,
+):
+ """
+ The Transformer model introduced in Flux 2.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Args:
+ patch_size (`int`, defaults to `1`):
+ Patch size to turn the input data into small patches.
+ in_channels (`int`, defaults to `128`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `None`):
+ The number of channels in the output. If not specified, it defaults to `in_channels`.
+ num_layers (`int`, defaults to `8`):
+ The number of layers of dual stream DiT blocks to use.
+ num_single_layers (`int`, defaults to `48`):
+ The number of layers of single stream DiT blocks to use.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of dimensions to use for each attention head.
+ num_attention_heads (`int`, defaults to `48`):
+ The number of attention heads to use.
+ joint_attention_dim (`int`, defaults to `15360`):
+ The number of dimensions to use for the joint attention (embedding/channel dimension of
+ `encoder_hidden_states`).
+ pooled_projection_dim (`int`, defaults to `768`):
+ The number of dimensions to use for the pooled projection.
+ guidance_embeds (`bool`, defaults to `True`):
+ Whether to use guidance embeddings for guidance-distilled variant of the model.
+ axes_dims_rope (`Tuple[int]`, defaults to `(32, 32, 32, 32)`):
+ The dimensions to use for the rotary positional embeddings.
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+ _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
+ _cp_plan = {
+ "": {
+ "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "img_ids": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "txt_ids": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 1,
+ in_channels: int = 128,
+ out_channels: Optional[int] = None,
+ num_layers: int = 8,
+ num_single_layers: int = 48,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 48,
+ joint_attention_dim: int = 15360,
+ timestep_guidance_channels: int = 256,
+ mlp_ratio: float = 3.0,
+ axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
+ rope_theta: int = 2000,
+ eps: float = 1e-6,
+ ):
+ super().__init__()
+ self.out_channels = out_channels or in_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
+
+ # 1. Sinusoidal positional embedding for RoPE on image and text tokens
+ self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)
+
+ # 2. Combined timestep + guidance embedding
+ self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
+ in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
+ )
+
+ # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
+ # Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks
+ self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
+ self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
+ # Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream
+ self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)
+
+ # 4. Input projections
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)
+
+ # 5. Double Stream Transformer Blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ Flux2TransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ mlp_ratio=mlp_ratio,
+ eps=eps,
+ bias=False,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 6. Single Stream Transformer Blocks
+ self.single_transformer_blocks = nn.ModuleList(
+ [
+ Flux2SingleTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ mlp_ratio=mlp_ratio,
+ eps=eps,
+ bias=False,
+ )
+ for _ in range(num_single_layers)
+ ]
+ )
+
+ # 7. Output layers
+ self.norm_out = AdaLayerNormContinuous(
+ self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False
+ )
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ """
+ The [`FluxTransformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ timestep ( `torch.LongTensor`):
+ Used to indicate denoising step.
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
+ A list of tensors that if specified are added to the residuals of transformer blocks.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ # 0. Handle input arguments
+ if joint_attention_kwargs is not None:
+ joint_attention_kwargs = joint_attention_kwargs.copy()
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ num_txt_tokens = encoder_hidden_states.shape[1]
+
+ # 1. Calculate timestep embedding and modulation parameters
+ timestep = timestep.to(hidden_states.dtype) * 1000
+ guidance = guidance.to(hidden_states.dtype) * 1000
+
+ temb = self.time_guidance_embed(timestep, guidance)
+
+ double_stream_mod_img = self.double_stream_modulation_img(temb)
+ double_stream_mod_txt = self.double_stream_modulation_txt(temb)
+ single_stream_mod = self.single_stream_modulation(temb)[0]
+
+ # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
+ hidden_states = self.x_embedder(hidden_states)
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ # 3. Calculate RoPE embeddings from image and text tokens
+ # NOTE: the below logic means that we can't support batched inference with images of different resolutions or
+ # text prompts of differents lengths. Is this a use case we want to support?
+ if img_ids.ndim == 3:
+ img_ids = img_ids[0]
+ if txt_ids.ndim == 3:
+ txt_ids = txt_ids[0]
+
+ if is_torch_npu_available():
+ freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
+ image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
+ freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu())
+ text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
+ else:
+ image_rotary_emb = self.pos_embed(img_ids)
+ text_rotary_emb = self.pos_embed(txt_ids)
+ concat_rotary_emb = (
+ torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
+ torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
+ )
+
+ # 4. Double Stream Transformer Blocks
+ for index_block, block in enumerate(self.transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ double_stream_mod_img,
+ double_stream_mod_txt,
+ concat_rotary_emb,
+ joint_attention_kwargs,
+ )
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb_mod_params_img=double_stream_mod_img,
+ temb_mod_params_txt=double_stream_mod_txt,
+ image_rotary_emb=concat_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+ # Concatenate text and image streams for single-block inference
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ # 5. Single Stream Transformer Blocks
+ for index_block, block in enumerate(self.single_transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ None,
+ single_stream_mod,
+ concat_rotary_emb,
+ joint_attention_kwargs,
+ )
+ else:
+ hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=None,
+ temb_mod_params=single_stream_mod,
+ image_rotary_emb=concat_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+ # Remove text tokens from concatenated stream
+ hidden_states = hidden_states[:, num_txt_tokens:, ...]
+
+ # 6. Output layers
+ hidden_states = self.norm_out(hidden_states, temb)
+ output = self.proj_out(hidden_states)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py
new file mode 100644
index 000000000000..4a5aee29abc4
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_hidream_image.py
@@ -0,0 +1,942 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...models.modeling_outputs import Transformer2DModelOutput
+from ...models.modeling_utils import ModelMixin
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import Attention
+from ..embeddings import TimestepEmbedding, Timesteps
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class HiDreamImageFeedForwardSwiGLU(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ multiple_of: int = 256,
+ ffn_dim_multiplier: Optional[float] = None,
+ ):
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ # custom dim factor multiplier
+ if ffn_dim_multiplier is not None:
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
+
+
+class HiDreamImagePooledEmbed(nn.Module):
+ def __init__(self, text_emb_dim, hidden_size):
+ super().__init__()
+ self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size)
+
+ def forward(self, pooled_embed: torch.Tensor) -> torch.Tensor:
+ return self.pooled_embedder(pooled_embed)
+
+
+class HiDreamImageTimestepEmbed(nn.Module):
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
+
+ def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None) -> torch.Tensor:
+ t_emb = self.time_proj(timesteps).to(dtype=wdtype)
+ t_emb = self.timestep_embedder(t_emb)
+ return t_emb
+
+
+class HiDreamImageOutEmbed(nn.Module):
+ def __init__(self, hidden_size, patch_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+ def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
+ shift, scale = self.adaLN_modulation(temb).chunk(2, dim=1)
+ hidden_states = self.norm_final(hidden_states) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+ hidden_states = self.linear(hidden_states)
+ return hidden_states
+
+
+class HiDreamImagePatchEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size=2,
+ in_channels=4,
+ out_channels=1024,
+ ):
+ super().__init__()
+ self.patch_size = patch_size
+ self.out_channels = out_channels
+ self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
+
+ def forward(self, latent) -> torch.Tensor:
+ latent = self.proj(latent)
+ return latent
+
+
+def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
+ assert dim % 2 == 0, "The dimension must be even."
+
+ is_mps = pos.device.type == "mps"
+ is_npu = pos.device.type == "npu"
+
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+
+ scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
+ omega = 1.0 / (theta**scale)
+
+ batch_size, seq_length = pos.shape
+ out = torch.einsum("...n,d->...nd", pos, omega)
+ cos_out = torch.cos(out)
+ sin_out = torch.sin(out)
+
+ stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
+ out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
+ return out.float()
+
+
+class HiDreamImageEmbedND(nn.Module):
+ def __init__(self, theta: int, axes_dim: List[int]):
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ n_axes = ids.shape[-1]
+ emb = torch.cat(
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+ dim=-3,
+ )
+ return emb.unsqueeze(2)
+
+
+def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
+
+
+@maybe_allow_in_graph
+class HiDreamAttention(Attention):
+ def __init__(
+ self,
+ query_dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ scale_qk: bool = True,
+ eps: float = 1e-5,
+ processor=None,
+ out_dim: int = None,
+ single: bool = False,
+ ):
+ super(Attention, self).__init__()
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.query_dim = query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+ self.out_dim = out_dim if out_dim is not None else query_dim
+
+ self.scale_qk = scale_qk
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+ self.sliceable_head_dim = heads
+ self.single = single
+
+ self.to_q = nn.Linear(query_dim, self.inner_dim)
+ self.to_k = nn.Linear(self.inner_dim, self.inner_dim)
+ self.to_v = nn.Linear(self.inner_dim, self.inner_dim)
+ self.to_out = nn.Linear(self.inner_dim, self.out_dim)
+ self.q_rms_norm = nn.RMSNorm(self.inner_dim, eps)
+ self.k_rms_norm = nn.RMSNorm(self.inner_dim, eps)
+
+ if not single:
+ self.to_q_t = nn.Linear(query_dim, self.inner_dim)
+ self.to_k_t = nn.Linear(self.inner_dim, self.inner_dim)
+ self.to_v_t = nn.Linear(self.inner_dim, self.inner_dim)
+ self.to_out_t = nn.Linear(self.inner_dim, self.out_dim)
+ self.q_rms_norm_t = nn.RMSNorm(self.inner_dim, eps)
+ self.k_rms_norm_t = nn.RMSNorm(self.inner_dim, eps)
+
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ norm_hidden_states: torch.Tensor,
+ hidden_states_masks: torch.Tensor = None,
+ norm_encoder_hidden_states: torch.Tensor = None,
+ image_rotary_emb: torch.Tensor = None,
+ ) -> torch.Tensor:
+ return self.processor(
+ self,
+ hidden_states=norm_hidden_states,
+ hidden_states_masks=hidden_states_masks,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+
+class HiDreamAttnProcessor:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+ def __call__(
+ self,
+ attn: HiDreamAttention,
+ hidden_states: torch.Tensor,
+ hidden_states_masks: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ image_rotary_emb: torch.Tensor = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ dtype = hidden_states.dtype
+ batch_size = hidden_states.shape[0]
+
+ query_i = attn.q_rms_norm(attn.to_q(hidden_states)).to(dtype=dtype)
+ key_i = attn.k_rms_norm(attn.to_k(hidden_states)).to(dtype=dtype)
+ value_i = attn.to_v(hidden_states)
+
+ inner_dim = key_i.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
+ key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
+ value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
+ if hidden_states_masks is not None:
+ key_i = key_i * hidden_states_masks.view(batch_size, -1, 1, 1)
+
+ if not attn.single:
+ query_t = attn.q_rms_norm_t(attn.to_q_t(encoder_hidden_states)).to(dtype=dtype)
+ key_t = attn.k_rms_norm_t(attn.to_k_t(encoder_hidden_states)).to(dtype=dtype)
+ value_t = attn.to_v_t(encoder_hidden_states)
+
+ query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
+ key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
+ value_t = value_t.view(batch_size, -1, attn.heads, head_dim)
+
+ num_image_tokens = query_i.shape[1]
+ num_text_tokens = query_t.shape[1]
+ query = torch.cat([query_i, query_t], dim=1)
+ key = torch.cat([key_i, key_t], dim=1)
+ value = torch.cat([value_i, value_t], dim=1)
+ else:
+ query = query_i
+ key = key_i
+ value = value_i
+
+ if query.shape[-1] == image_rotary_emb.shape[-3] * 2:
+ query, key = apply_rope(query, key, image_rotary_emb)
+
+ else:
+ query_1, query_2 = query.chunk(2, dim=-1)
+ key_1, key_2 = key.chunk(2, dim=-1)
+ query_1, key_1 = apply_rope(query_1, key_1, image_rotary_emb)
+ query = torch.cat([query_1, query_2], dim=-1)
+ key = torch.cat([key_1, key_2], dim=-1)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if not attn.single:
+ hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
+ hidden_states_i = attn.to_out(hidden_states_i)
+ hidden_states_t = attn.to_out_t(hidden_states_t)
+ return hidden_states_i, hidden_states_t
+ else:
+ hidden_states = attn.to_out(hidden_states)
+ return hidden_states
+
+
+# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
+class MoEGate(nn.Module):
+ def __init__(
+ self,
+ embed_dim,
+ num_routed_experts=4,
+ num_activated_experts=2,
+ aux_loss_alpha=0.01,
+ _force_inference_output=False,
+ ):
+ super().__init__()
+ self.top_k = num_activated_experts
+ self.n_routed_experts = num_routed_experts
+
+ self.scoring_func = "softmax"
+ self.alpha = aux_loss_alpha
+ self.seq_aux = False
+
+ # topk selection algorithm
+ self.norm_topk_prob = False
+ self.gating_dim = embed_dim
+ self.weight = nn.Parameter(torch.randn(self.n_routed_experts, self.gating_dim) / embed_dim**0.5)
+
+ self._force_inference_output = _force_inference_output
+
+ def forward(self, hidden_states):
+ bsz, seq_len, h = hidden_states.shape
+ ### compute gating score
+ hidden_states = hidden_states.view(-1, h)
+ logits = F.linear(hidden_states, self.weight, None)
+ if self.scoring_func == "softmax":
+ scores = logits.softmax(dim=-1)
+ else:
+ raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}")
+
+ ### select top-k experts
+ topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
+
+ ### norm gate to sum 1
+ if self.top_k > 1 and self.norm_topk_prob:
+ denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
+ topk_weight = topk_weight / denominator
+
+ ### expert-level computation auxiliary loss
+ if self.training and self.alpha > 0.0 and not self._force_inference_output:
+ scores_for_aux = scores
+ aux_topk = self.top_k
+ # always compute aux loss based on the naive greedy topk method
+ topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
+ if self.seq_aux:
+ scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
+ ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
+ ce.scatter_add_(
+ 1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)
+ ).div_(seq_len * aux_topk / self.n_routed_experts)
+ aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
+ else:
+ mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
+ ce = mask_ce.float().mean(0)
+
+ Pi = scores_for_aux.mean(0)
+ fi = ce * self.n_routed_experts
+ aux_loss = (Pi * fi).sum() * self.alpha
+ else:
+ aux_loss = None
+ return topk_idx, topk_weight, aux_loss
+
+
+# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
+class MOEFeedForwardSwiGLU(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ num_routed_experts: int,
+ num_activated_experts: int,
+ _force_inference_output: bool = False,
+ ):
+ super().__init__()
+ self.shared_experts = HiDreamImageFeedForwardSwiGLU(dim, hidden_dim // 2)
+ self.experts = nn.ModuleList(
+ [HiDreamImageFeedForwardSwiGLU(dim, hidden_dim) for i in range(num_routed_experts)]
+ )
+ self._force_inference_output = _force_inference_output
+ self.gate = MoEGate(
+ embed_dim=dim,
+ num_routed_experts=num_routed_experts,
+ num_activated_experts=num_activated_experts,
+ _force_inference_output=_force_inference_output,
+ )
+ self.num_activated_experts = num_activated_experts
+
+ def forward(self, x):
+ wtype = x.dtype
+ identity = x
+ orig_shape = x.shape
+ topk_idx, topk_weight, aux_loss = self.gate(x)
+ x = x.view(-1, x.shape[-1])
+ flat_topk_idx = topk_idx.view(-1)
+ if self.training and not self._force_inference_output:
+ x = x.repeat_interleave(self.num_activated_experts, dim=0)
+ y = torch.empty_like(x, dtype=wtype)
+ for i, expert in enumerate(self.experts):
+ y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
+ y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
+ y = y.view(*orig_shape).to(dtype=wtype)
+ # y = AddAuxiliaryLoss.apply(y, aux_loss)
+ else:
+ y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
+ y = y + self.shared_experts(identity)
+ return y
+
+ @torch.no_grad()
+ def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
+ expert_cache = torch.zeros_like(x)
+ idxs = flat_expert_indices.argsort()
+ tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
+ token_idxs = idxs // self.num_activated_experts
+ for i, end_idx in enumerate(tokens_per_expert):
+ start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
+ if start_idx == end_idx:
+ continue
+ expert = self.experts[i]
+ exp_token_idx = token_idxs[start_idx:end_idx]
+ expert_tokens = x[exp_token_idx]
+ expert_out = expert(expert_tokens)
+ expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
+
+ # for fp16 and other dtype
+ expert_cache = expert_cache.to(expert_out.dtype)
+ expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce="sum")
+ return expert_cache
+
+
+class TextProjection(nn.Module):
+ def __init__(self, in_features, hidden_size):
+ super().__init__()
+ self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)
+
+ def forward(self, caption):
+ hidden_states = self.linear(caption)
+ return hidden_states
+
+
+@maybe_allow_in_graph
+class HiDreamImageSingleTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ num_routed_experts: int = 4,
+ num_activated_experts: int = 2,
+ _force_inference_output: bool = False,
+ ):
+ super().__init__()
+ self.num_attention_heads = num_attention_heads
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))
+
+ # 1. Attention
+ self.norm1_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
+ self.attn1 = HiDreamAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ processor=HiDreamAttnProcessor(),
+ single=True,
+ )
+
+ # 3. Feed-forward
+ self.norm3_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
+ if num_routed_experts > 0:
+ self.ff_i = MOEFeedForwardSwiGLU(
+ dim=dim,
+ hidden_dim=4 * dim,
+ num_routed_experts=num_routed_experts,
+ num_activated_experts=num_activated_experts,
+ _force_inference_output=_force_inference_output,
+ )
+ else:
+ self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ hidden_states_masks: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ image_rotary_emb: torch.Tensor = None,
+ ) -> torch.Tensor:
+ wtype = hidden_states.dtype
+ shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = self.adaLN_modulation(temb)[
+ :, None
+ ].chunk(6, dim=-1)
+
+ # 1. MM-Attention
+ norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
+ attn_output_i = self.attn1(
+ norm_hidden_states,
+ hidden_states_masks,
+ image_rotary_emb=image_rotary_emb,
+ )
+ hidden_states = gate_msa_i * attn_output_i + hidden_states
+
+ # 2. Feed-forward
+ norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
+ ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states.to(dtype=wtype))
+ hidden_states = ff_output_i + hidden_states
+ return hidden_states
+
+
+@maybe_allow_in_graph
+class HiDreamImageTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ num_routed_experts: int = 4,
+ num_activated_experts: int = 2,
+ _force_inference_output: bool = False,
+ ):
+ super().__init__()
+ self.num_attention_heads = num_attention_heads
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 12 * dim, bias=True))
+
+ # 1. Attention
+ self.norm1_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
+ self.norm1_t = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
+ self.attn1 = HiDreamAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ processor=HiDreamAttnProcessor(),
+ single=False,
+ )
+
+ # 3. Feed-forward
+ self.norm3_i = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
+ if num_routed_experts > 0:
+ self.ff_i = MOEFeedForwardSwiGLU(
+ dim=dim,
+ hidden_dim=4 * dim,
+ num_routed_experts=num_routed_experts,
+ num_activated_experts=num_activated_experts,
+ _force_inference_output=_force_inference_output,
+ )
+ else:
+ self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
+ self.norm3_t = nn.LayerNorm(dim, eps=1e-06, elementwise_affine=False)
+ self.ff_t = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ hidden_states_masks: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ image_rotary_emb: torch.Tensor = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ wtype = hidden_states.dtype
+ (
+ shift_msa_i,
+ scale_msa_i,
+ gate_msa_i,
+ shift_mlp_i,
+ scale_mlp_i,
+ gate_mlp_i,
+ shift_msa_t,
+ scale_msa_t,
+ gate_msa_t,
+ shift_mlp_t,
+ scale_mlp_t,
+ gate_mlp_t,
+ ) = self.adaLN_modulation(temb)[:, None].chunk(12, dim=-1)
+
+ # 1. MM-Attention
+ norm_hidden_states = self.norm1_i(hidden_states).to(dtype=wtype)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa_i) + shift_msa_i
+ norm_encoder_hidden_states = self.norm1_t(encoder_hidden_states).to(dtype=wtype)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_msa_t) + shift_msa_t
+
+ attn_output_i, attn_output_t = self.attn1(
+ norm_hidden_states,
+ hidden_states_masks,
+ norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = gate_msa_i * attn_output_i + hidden_states
+ encoder_hidden_states = gate_msa_t * attn_output_t + encoder_hidden_states
+
+ # 2. Feed-forward
+ norm_hidden_states = self.norm3_i(hidden_states).to(dtype=wtype)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp_i) + shift_mlp_i
+ norm_encoder_hidden_states = self.norm3_t(encoder_hidden_states).to(dtype=wtype)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + scale_mlp_t) + shift_mlp_t
+
+ ff_output_i = gate_mlp_i * self.ff_i(norm_hidden_states)
+ ff_output_t = gate_mlp_t * self.ff_t(norm_encoder_hidden_states)
+ hidden_states = ff_output_i + hidden_states
+ encoder_hidden_states = ff_output_t + encoder_hidden_states
+ return hidden_states, encoder_hidden_states
+
+
+class HiDreamBlock(nn.Module):
+ def __init__(self, block: Union[HiDreamImageTransformerBlock, HiDreamImageSingleTransformerBlock]):
+ super().__init__()
+ self.block = block
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ hidden_states_masks: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ image_rotary_emb: torch.Tensor = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ return self.block(
+ hidden_states=hidden_states,
+ hidden_states_masks=hidden_states_masks,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+
+class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["HiDreamImageTransformerBlock", "HiDreamImageSingleTransformerBlock"]
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: Optional[int] = None,
+ in_channels: int = 64,
+ out_channels: Optional[int] = None,
+ num_layers: int = 16,
+ num_single_layers: int = 32,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 20,
+ caption_channels: List[int] = None,
+ text_emb_dim: int = 2048,
+ num_routed_experts: int = 4,
+ num_activated_experts: int = 2,
+ axes_dims_rope: Tuple[int, int] = (32, 32),
+ max_resolution: Tuple[int, int] = (128, 128),
+ llama_layers: List[int] = None,
+ force_inference_output: bool = False,
+ ):
+ super().__init__()
+ self.out_channels = out_channels or in_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
+
+ self.t_embedder = HiDreamImageTimestepEmbed(self.inner_dim)
+ self.p_embedder = HiDreamImagePooledEmbed(text_emb_dim, self.inner_dim)
+ self.x_embedder = HiDreamImagePatchEmbed(
+ patch_size=patch_size,
+ in_channels=in_channels,
+ out_channels=self.inner_dim,
+ )
+ self.pe_embedder = HiDreamImageEmbedND(theta=10000, axes_dim=axes_dims_rope)
+
+ self.double_stream_blocks = nn.ModuleList(
+ [
+ HiDreamBlock(
+ HiDreamImageTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ num_routed_experts=num_routed_experts,
+ num_activated_experts=num_activated_experts,
+ _force_inference_output=force_inference_output,
+ )
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ self.single_stream_blocks = nn.ModuleList(
+ [
+ HiDreamBlock(
+ HiDreamImageSingleTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ num_routed_experts=num_routed_experts,
+ num_activated_experts=num_activated_experts,
+ _force_inference_output=force_inference_output,
+ )
+ )
+ for _ in range(num_single_layers)
+ ]
+ )
+
+ self.final_layer = HiDreamImageOutEmbed(self.inner_dim, patch_size, self.out_channels)
+
+ caption_channels = [caption_channels[1]] * (num_layers + num_single_layers) + [caption_channels[0]]
+ caption_projection = []
+ for caption_channel in caption_channels:
+ caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim))
+ self.caption_projection = nn.ModuleList(caption_projection)
+ self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)
+
+ self.gradient_checkpointing = False
+
+ def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
+ if is_training and not self.config.force_inference_output:
+ B, S, F = x.shape
+ C = F // (self.config.patch_size * self.config.patch_size)
+ x = (
+ x.reshape(B, S, self.config.patch_size, self.config.patch_size, C)
+ .permute(0, 4, 1, 2, 3)
+ .reshape(B, C, S, self.config.patch_size * self.config.patch_size)
+ )
+ else:
+ x_arr = []
+ p1 = self.config.patch_size
+ p2 = self.config.patch_size
+ for i, img_size in enumerate(img_sizes):
+ pH, pW = img_size
+ t = x[i, : pH * pW].reshape(1, pH, pW, -1)
+ F_token = t.shape[-1]
+ C = F_token // (p1 * p2)
+ t = t.reshape(1, pH, pW, p1, p2, C)
+ t = t.permute(0, 5, 1, 3, 2, 4)
+ t = t.reshape(1, C, pH * p1, pW * p2)
+ x_arr.append(t)
+ x = torch.cat(x_arr, dim=0)
+ return x
+
+ def patchify(self, hidden_states):
+ batch_size, channels, height, width = hidden_states.shape
+ patch_size = self.config.patch_size
+ patch_height, patch_width = height // patch_size, width // patch_size
+ device = hidden_states.device
+ dtype = hidden_states.dtype
+
+ # create img_sizes
+ img_sizes = torch.tensor([patch_height, patch_width], dtype=torch.int64, device=device).reshape(-1)
+ img_sizes = img_sizes.unsqueeze(0).repeat(batch_size, 1)
+
+ # create hidden_states_masks
+ if hidden_states.shape[-2] != hidden_states.shape[-1]:
+ hidden_states_masks = torch.zeros((batch_size, self.max_seq), dtype=dtype, device=device)
+ hidden_states_masks[:, : patch_height * patch_width] = 1.0
+ else:
+ hidden_states_masks = None
+
+ # create img_ids
+ img_ids = torch.zeros(patch_height, patch_width, 3, device=device)
+ row_indices = torch.arange(patch_height, device=device)[:, None]
+ col_indices = torch.arange(patch_width, device=device)[None, :]
+ img_ids[..., 1] = img_ids[..., 1] + row_indices
+ img_ids[..., 2] = img_ids[..., 2] + col_indices
+ img_ids = img_ids.reshape(patch_height * patch_width, -1)
+
+ if hidden_states.shape[-2] != hidden_states.shape[-1]:
+ # Handle non-square latents
+ img_ids_pad = torch.zeros(self.max_seq, 3, device=device)
+ img_ids_pad[: patch_height * patch_width, :] = img_ids
+ img_ids = img_ids_pad.unsqueeze(0).repeat(batch_size, 1, 1)
+ else:
+ img_ids = img_ids.unsqueeze(0).repeat(batch_size, 1, 1)
+
+ # patchify hidden_states
+ if hidden_states.shape[-2] != hidden_states.shape[-1]:
+ # Handle non-square latents
+ out = torch.zeros(
+ (batch_size, channels, self.max_seq, patch_size * patch_size),
+ dtype=dtype,
+ device=device,
+ )
+ hidden_states = hidden_states.reshape(
+ batch_size, channels, patch_height, patch_size, patch_width, patch_size
+ )
+ hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
+ hidden_states = hidden_states.reshape(
+ batch_size, channels, patch_height * patch_width, patch_size * patch_size
+ )
+ out[:, :, 0 : patch_height * patch_width] = hidden_states
+ hidden_states = out
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
+ batch_size, self.max_seq, patch_size * patch_size * channels
+ )
+
+ else:
+ # Handle square latents
+ hidden_states = hidden_states.reshape(
+ batch_size, channels, patch_height, patch_size, patch_width, patch_size
+ )
+ hidden_states = hidden_states.permute(0, 2, 4, 3, 5, 1)
+ hidden_states = hidden_states.reshape(
+ batch_size, patch_height * patch_width, patch_size * patch_size * channels
+ )
+
+ return hidden_states, hidden_states_masks, img_sizes, img_ids
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timesteps: torch.LongTensor = None,
+ encoder_hidden_states_t5: torch.Tensor = None,
+ encoder_hidden_states_llama3: torch.Tensor = None,
+ pooled_embeds: torch.Tensor = None,
+ img_ids: Optional[torch.Tensor] = None,
+ img_sizes: Optional[List[Tuple[int, int]]] = None,
+ hidden_states_masks: Optional[torch.Tensor] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
+ encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
+
+ if encoder_hidden_states is not None:
+ deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
+ deprecate("encoder_hidden_states", "0.35.0", deprecation_message)
+ encoder_hidden_states_t5 = encoder_hidden_states[0]
+ encoder_hidden_states_llama3 = encoder_hidden_states[1]
+
+ if img_ids is not None and img_sizes is not None and hidden_states_masks is None:
+ deprecation_message = (
+ "Passing `img_ids` and `img_sizes` with unpachified `hidden_states` is deprecated and will be ignored."
+ )
+ deprecate("img_ids", "0.35.0", deprecation_message)
+
+ if hidden_states_masks is not None and (img_ids is None or img_sizes is None):
+ raise ValueError("if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed.")
+ elif hidden_states_masks is not None and hidden_states.ndim != 3:
+ raise ValueError(
+ "if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensors with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
+ )
+
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ # spatial forward
+ batch_size = hidden_states.shape[0]
+ hidden_states_type = hidden_states.dtype
+
+ # Patchify the input
+ if hidden_states_masks is None:
+ hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(hidden_states)
+
+ # Embed the hidden states
+ hidden_states = self.x_embedder(hidden_states)
+
+ # 0. time
+ timesteps = self.t_embedder(timesteps, hidden_states_type)
+ p_embedder = self.p_embedder(pooled_embeds)
+ temb = timesteps + p_embedder
+
+ encoder_hidden_states = [encoder_hidden_states_llama3[k] for k in self.config.llama_layers]
+
+ if self.caption_projection is not None:
+ new_encoder_hidden_states = []
+ for i, enc_hidden_state in enumerate(encoder_hidden_states):
+ enc_hidden_state = self.caption_projection[i](enc_hidden_state)
+ enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
+ new_encoder_hidden_states.append(enc_hidden_state)
+ encoder_hidden_states = new_encoder_hidden_states
+ encoder_hidden_states_t5 = self.caption_projection[-1](encoder_hidden_states_t5)
+ encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, -1, hidden_states.shape[-1])
+ encoder_hidden_states.append(encoder_hidden_states_t5)
+
+ txt_ids = torch.zeros(
+ batch_size,
+ encoder_hidden_states[-1].shape[1]
+ + encoder_hidden_states[-2].shape[1]
+ + encoder_hidden_states[0].shape[1],
+ 3,
+ device=img_ids.device,
+ dtype=img_ids.dtype,
+ )
+ ids = torch.cat((img_ids, txt_ids), dim=1)
+ image_rotary_emb = self.pe_embedder(ids)
+
+ # 2. Blocks
+ block_id = 0
+ initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
+ initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
+ for bid, block in enumerate(self.double_stream_blocks):
+ cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
+ cur_encoder_hidden_states = torch.cat(
+ [initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1
+ )
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states, initial_encoder_hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ hidden_states_masks,
+ cur_encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ )
+ else:
+ hidden_states, initial_encoder_hidden_states = block(
+ hidden_states=hidden_states,
+ hidden_states_masks=hidden_states_masks,
+ encoder_hidden_states=cur_encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+ initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
+ block_id += 1
+
+ image_tokens_seq_len = hidden_states.shape[1]
+ hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
+ hidden_states_seq_len = hidden_states.shape[1]
+ if hidden_states_masks is not None:
+ encoder_attention_mask_ones = torch.ones(
+ (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
+ device=hidden_states_masks.device,
+ dtype=hidden_states_masks.dtype,
+ )
+ hidden_states_masks = torch.cat([hidden_states_masks, encoder_attention_mask_ones], dim=1)
+
+ for bid, block in enumerate(self.single_stream_blocks):
+ cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
+ hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ hidden_states_masks,
+ None,
+ temb,
+ image_rotary_emb,
+ )
+ else:
+ hidden_states = block(
+ hidden_states=hidden_states,
+ hidden_states_masks=hidden_states_masks,
+ encoder_hidden_states=None,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+ hidden_states = hidden_states[:, :hidden_states_seq_len]
+ block_id += 1
+
+ hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
+ output = self.final_layer(hidden_states, temb)
+ output = self.unpatchify(output, img_sizes, self.training)
+ if hidden_states_masks is not None:
+ hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py
index 36f914f0b5c1..fb0ce1a30ff9 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The Hunyuan Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -23,8 +23,9 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ..attention import FeedForward
-from ..attention_processor import Attention, AttentionProcessor
+from ..attention import AttentionMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
+from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import (
CombinedTimestepTextProjEmbeddings,
@@ -42,6 +43,9 @@
class HunyuanVideoAttnProcessor2_0:
+ _attention_backend = None
+ _parallel_config = None
+
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
@@ -64,9 +68,9 @@ def __call__(
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
- query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
# 2. QK normalization
if attn.norm_q is not None:
@@ -81,21 +85,29 @@ def __call__(
if attn.add_q_proj is None and encoder_hidden_states is not None:
query = torch.cat(
[
- apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
- query[:, :, -encoder_hidden_states.shape[1] :],
+ apply_rotary_emb(
+ query[:, : -encoder_hidden_states.shape[1]],
+ image_rotary_emb,
+ sequence_dim=1,
+ ),
+ query[:, -encoder_hidden_states.shape[1] :],
],
- dim=2,
+ dim=1,
)
key = torch.cat(
[
- apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
- key[:, :, -encoder_hidden_states.shape[1] :],
+ apply_rotary_emb(
+ key[:, : -encoder_hidden_states.shape[1]],
+ image_rotary_emb,
+ sequence_dim=1,
+ ),
+ key[:, -encoder_hidden_states.shape[1] :],
],
- dim=2,
+ dim=1,
)
else:
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
# 4. Encoder condition QKV projection and normalization
if attn.add_q_proj is not None and encoder_hidden_states is not None:
@@ -103,24 +115,31 @@ def __call__(
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
- encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
if attn.norm_added_q is not None:
encoder_query = attn.norm_added_q(encoder_query)
if attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key)
- query = torch.cat([query, encoder_query], dim=2)
- key = torch.cat([key, encoder_key], dim=2)
- value = torch.cat([value, encoder_value], dim=2)
+ query = torch.cat([query, encoder_query], dim=1)
+ key = torch.cat([key, encoder_key], dim=1)
+ value = torch.cat([value, encoder_value], dim=1)
# 5. Attention
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
)
- hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+ hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
# 6. Output projection
@@ -529,7 +548,7 @@ def forward(
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args,
**kwargs,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -684,7 +703,7 @@ def forward(
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
token_replace_emb: torch.Tensor = None,
num_tokens: int = None,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -819,7 +838,9 @@ def forward(
return hidden_states, encoder_hidden_states
-class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+class HunyuanVideoTransformer3DModel(
+ ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin
+):
r"""
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
@@ -870,6 +891,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
"HunyuanVideoPatchEmbed",
"HunyuanVideoTokenRefiner",
]
+ _repeated_blocks = [
+ "HunyuanVideoTransformerBlock",
+ "HunyuanVideoSingleTransformerBlock",
+ "HunyuanVideoPatchEmbed",
+ "HunyuanVideoTokenRefiner",
+ ]
@register_to_config
def __init__(
@@ -889,7 +916,7 @@ def __init__(
text_embed_dim: int = 4096,
pooled_projection_dim: int = 768,
rope_theta: float = 256.0,
- rope_axes_dim: Tuple[int] = (16, 56, 56),
+ rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
image_condition_type: Optional[str] = None,
) -> None:
super().__init__()
@@ -962,66 +989,6 @@ def __init__(
self.gradient_checkpointing = False
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -1032,7 +999,7 @@ def forward(
guidance: torch.Tensor = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
- ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -1068,17 +1035,15 @@ def forward(
latent_sequence_length = hidden_states.shape[1]
condition_sequence_length = encoder_hidden_states.shape[1]
sequence_length = latent_sequence_length + condition_sequence_length
- attention_mask = torch.zeros(
+ attention_mask = torch.ones(
batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
) # [B, N]
-
effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
-
- for i in range(batch_size):
- attention_mask[i, : effective_sequence_length[i]] = True
- # [B, 1, 1, N], for broadcasting across attention heads
- attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
+ indices = torch.arange(sequence_length, device=hidden_states.device).unsqueeze(0) # [1, N]
+ mask_indices = indices >= effective_sequence_length.unsqueeze(1) # [B, N]
+ attention_mask = attention_mask.masked_fill(mask_indices, False)
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, N]
# 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video15.py b/src/diffusers/models/transformers/transformer_hunyuan_video15.py
new file mode 100644
index 000000000000..293ba996ea98
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video15.py
@@ -0,0 +1,793 @@
+# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diffusers.loaders import FromOriginalModelMixin
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import AttentionMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
+from ..attention_processor import Attention
+from ..cache_utils import CacheMixin
+from ..embeddings import (
+ CombinedTimestepTextProjEmbeddings,
+ TimestepEmbedding,
+ Timesteps,
+ get_1d_rotary_pos_embed,
+)
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class HunyuanVideo15AttnProcessor2_0:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "HunyuanVideo15AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ # 1. QKV projections
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+
+ # 2. QK normalization
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ # 3. Rotational positional embeddings applied to latent stream
+ if image_rotary_emb is not None:
+ from ..embeddings import apply_rotary_emb
+
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+ # 4. Encoder condition QKV projection and normalization
+ if encoder_hidden_states is not None:
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
+
+ if attn.norm_added_q is not None:
+ encoder_query = attn.norm_added_q(encoder_query)
+ if attn.norm_added_k is not None:
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([query, encoder_query], dim=1)
+ key = torch.cat([key, encoder_key], dim=1)
+ value = torch.cat([value, encoder_value], dim=1)
+
+ batch_size, seq_len, heads, dim = query.shape
+ attention_mask = F.pad(attention_mask, (seq_len - attention_mask.shape[1], 0), value=True)
+ attention_mask = attention_mask.bool()
+ self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
+ attention_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
+
+ # 5. Attention
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # 6. Output projection
+ if encoder_hidden_states is not None:
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : -encoder_hidden_states.shape[1]],
+ hidden_states[:, -encoder_hidden_states.shape[1] :],
+ )
+
+ if getattr(attn, "to_out", None) is not None:
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if getattr(attn, "to_add_out", None) is not None:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+
+
+class HunyuanVideo15PatchEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size: Union[int, Tuple[int, int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ ) -> None:
+ super().__init__()
+
+ patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.proj(hidden_states)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC
+ return hidden_states
+
+
+class HunyuanVideo15AdaNorm(nn.Module):
+ def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
+ super().__init__()
+
+ out_features = out_features or 2 * in_features
+ self.linear = nn.Linear(in_features, out_features)
+ self.nonlinearity = nn.SiLU()
+
+ def forward(
+ self, temb: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ temb = self.linear(self.nonlinearity(temb))
+ gate_msa, gate_mlp = temb.chunk(2, dim=1)
+ gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
+ return gate_msa, gate_mlp
+
+
+class HunyuanVideo15TimeEmbedding(nn.Module):
+ r"""
+ Time embedding for HunyuanVideo 1.5.
+
+ Supports standard timestep embedding and optional reference timestep embedding for MeanFlow-based super-resolution
+ models.
+
+ Args:
+ embedding_dim (`int`):
+ The dimension of the output embedding.
+ """
+
+ def __init__(self, embedding_dim: int, use_meanflow: bool = False):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.use_meanflow = use_meanflow
+ self.time_proj_r = None
+ self.timestep_embedder_r = None
+ if use_meanflow:
+ self.time_proj_r = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder_r = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ timestep_r: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=timestep.dtype))
+
+ if timestep_r is not None:
+ timesteps_proj_r = self.time_proj_r(timestep_r)
+ timesteps_emb_r = self.timestep_embedder_r(timesteps_proj_r.to(dtype=timestep.dtype))
+ timesteps_emb = timesteps_emb + timesteps_emb_r
+
+ return timesteps_emb
+
+
+class HunyuanVideo15IndividualTokenRefinerBlock(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_width_ratio: str = 4.0,
+ mlp_drop_rate: float = 0.0,
+ attention_bias: bool = True,
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
+ self.attn = Attention(
+ query_dim=hidden_size,
+ cross_attention_dim=None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ bias=attention_bias,
+ )
+
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
+ self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
+
+ self.norm_out = HunyuanVideo15AdaNorm(hidden_size, 2 * hidden_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=attention_mask,
+ )
+
+ gate_msa, gate_mlp = self.norm_out(temb)
+ hidden_states = hidden_states + attn_output * gate_msa
+
+ ff_output = self.ff(self.norm2(hidden_states))
+ hidden_states = hidden_states + ff_output * gate_mlp
+
+ return hidden_states
+
+
+class HunyuanVideo15IndividualTokenRefiner(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ num_layers: int,
+ mlp_width_ratio: float = 4.0,
+ mlp_drop_rate: float = 0.0,
+ attention_bias: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.refiner_blocks = nn.ModuleList(
+ [
+ HunyuanVideo15IndividualTokenRefinerBlock(
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ mlp_width_ratio=mlp_width_ratio,
+ mlp_drop_rate=mlp_drop_rate,
+ attention_bias=attention_bias,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> None:
+ self_attn_mask = None
+ if attention_mask is not None:
+ batch_size = attention_mask.shape[0]
+ seq_len = attention_mask.shape[1]
+ attention_mask = attention_mask.to(hidden_states.device).bool()
+ self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
+ self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
+
+ for block in self.refiner_blocks:
+ hidden_states = block(hidden_states, temb, self_attn_mask)
+
+ return hidden_states
+
+
+class HunyuanVideo15TokenRefiner(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ num_layers: int,
+ mlp_ratio: float = 4.0,
+ mlp_drop_rate: float = 0.0,
+ attention_bias: bool = True,
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
+ embedding_dim=hidden_size, pooled_projection_dim=in_channels
+ )
+ self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
+ self.token_refiner = HunyuanVideo15IndividualTokenRefiner(
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ num_layers=num_layers,
+ mlp_width_ratio=mlp_ratio,
+ mlp_drop_rate=mlp_drop_rate,
+ attention_bias=attention_bias,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ) -> torch.Tensor:
+ if attention_mask is None:
+ pooled_projections = hidden_states.mean(dim=1)
+ else:
+ original_dtype = hidden_states.dtype
+ mask_float = attention_mask.float().unsqueeze(-1)
+ pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
+ pooled_projections = pooled_projections.to(original_dtype)
+
+ temb = self.time_text_embed(timestep, pooled_projections)
+ hidden_states = self.proj_in(hidden_states)
+ hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
+
+ return hidden_states
+
+
+class HunyuanVideo15RotaryPosEmbed(nn.Module):
+ def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
+ super().__init__()
+
+ self.patch_size = patch_size
+ self.patch_size_t = patch_size_t
+ self.rope_dim = rope_dim
+ self.theta = theta
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
+
+ axes_grids = []
+ for i in range(len(rope_sizes)):
+ # Note: The following line diverges from original behaviour. We create the grid on the device, whereas
+ # original implementation creates it on CPU and then moves it to device. This results in numerical
+ # differences in layerwise debugging outputs, but visually it is the same.
+ grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
+ axes_grids.append(grid)
+ grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T]
+ grid = torch.stack(grid, dim=0) # [3, W, H, T]
+
+ freqs = []
+ for i in range(3):
+ freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
+ freqs.append(freq)
+
+ freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
+ freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
+ return freqs_cos, freqs_sin
+
+
+class HunyuanVideo15ByT5TextProjection(nn.Module):
+ def __init__(self, in_features: int, hidden_size: int, out_features: int):
+ super().__init__()
+ self.norm = nn.LayerNorm(in_features)
+ self.linear_1 = nn.Linear(in_features, hidden_size)
+ self.linear_2 = nn.Linear(hidden_size, hidden_size)
+ self.linear_3 = nn.Linear(hidden_size, out_features)
+ self.act_fn = nn.GELU()
+
+ def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.norm(encoder_hidden_states)
+ hidden_states = self.linear_1(hidden_states)
+ hidden_states = self.act_fn(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ hidden_states = self.act_fn(hidden_states)
+ hidden_states = self.linear_3(hidden_states)
+ return hidden_states
+
+
+class HunyuanVideo15ImageProjection(nn.Module):
+ def __init__(self, in_channels: int, hidden_size: int):
+ super().__init__()
+ self.norm_in = nn.LayerNorm(in_channels)
+ self.linear_1 = nn.Linear(in_channels, in_channels)
+ self.act_fn = nn.GELU()
+ self.linear_2 = nn.Linear(in_channels, hidden_size)
+ self.norm_out = nn.LayerNorm(hidden_size)
+
+ def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.norm_in(image_embeds)
+ hidden_states = self.linear_1(hidden_states)
+ hidden_states = self.act_fn(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ hidden_states = self.norm_out(hidden_states)
+ return hidden_states
+
+
+class HunyuanVideo15TransformerBlock(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_ratio: float,
+ qk_norm: str = "rms_norm",
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+
+ self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
+ self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
+
+ self.attn = Attention(
+ query_dim=hidden_size,
+ cross_attention_dim=None,
+ added_kv_proj_dim=hidden_size,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=hidden_size,
+ context_pre_only=False,
+ bias=True,
+ processor=HunyuanVideo15AttnProcessor2_0(),
+ qk_norm=qk_norm,
+ eps=1e-6,
+ )
+
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+ self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ *args,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # 1. Input normalization
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ encoder_hidden_states, emb=temb
+ )
+
+ # 2. Joint attention
+ attn_output, context_attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=freqs_cis,
+ )
+
+ # 3. Modulation and residual connection
+ hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
+ encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+ # 4. Feed-forward
+ ff_output = self.ff(norm_hidden_states)
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+ return hidden_states, encoder_hidden_states
+
+
+class HunyuanVideo15Transformer3DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
+ r"""
+ A Transformer model for video-like data used in [HunyuanVideo1.5](https://huggingface.co/tencent/HunyuanVideo1.5).
+
+ Args:
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, defaults to `16`):
+ The number of channels in the output.
+ num_attention_heads (`int`, defaults to `24`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of channels in each head.
+ num_layers (`int`, defaults to `20`):
+ The number of layers of dual-stream blocks to use.
+ num_refiner_layers (`int`, defaults to `2`):
+ The number of layers of refiner blocks to use.
+ mlp_ratio (`float`, defaults to `4.0`):
+ The ratio of the hidden layer size to the input size in the feedforward network.
+ patch_size (`int`, defaults to `2`):
+ The size of the spatial patches to use in the patch embedding layer.
+ patch_size_t (`int`, defaults to `1`):
+ The size of the tmeporal patches to use in the patch embedding layer.
+ qk_norm (`str`, defaults to `rms_norm`):
+ The normalization to use for the query and key projections in the attention layers.
+ guidance_embeds (`bool`, defaults to `True`):
+ Whether to use guidance embeddings in the model.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ pooled_projection_dim (`int`, defaults to `768`):
+ The dimension of the pooled projection of the text embeddings.
+ rope_theta (`float`, defaults to `256.0`):
+ The value of theta to use in the RoPE layer.
+ rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
+ The dimensions of the axes to use in the RoPE layer.
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
+ _no_split_modules = [
+ "HunyuanVideo15TransformerBlock",
+ "HunyuanVideo15PatchEmbed",
+ "HunyuanVideo15TokenRefiner",
+ ]
+ _repeated_blocks = [
+ "HunyuanVideo15TransformerBlock",
+ "HunyuanVideo15PatchEmbed",
+ "HunyuanVideo15TokenRefiner",
+ ]
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 65,
+ out_channels: int = 32,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 128,
+ num_layers: int = 54,
+ num_refiner_layers: int = 2,
+ mlp_ratio: float = 4.0,
+ patch_size: int = 1,
+ patch_size_t: int = 1,
+ qk_norm: str = "rms_norm",
+ text_embed_dim: int = 3584,
+ text_embed_2_dim: int = 1472,
+ image_embed_dim: int = 1152,
+ rope_theta: float = 256.0,
+ rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
+ # YiYi Notes: config based on target_size_config https://github.com/yiyixuxu/hy15/blob/main/hyvideo/pipelines/hunyuan_video_pipeline.py#L205
+ target_size: int = 640, # did not name sample_size since it is in pixel spaces
+ task_type: str = "i2v",
+ use_meanflow: bool = False,
+ ) -> None:
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels or in_channels
+
+ # 1. Latent and condition embedders
+ self.x_embedder = HunyuanVideo15PatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
+ self.image_embedder = HunyuanVideo15ImageProjection(image_embed_dim, inner_dim)
+
+ self.context_embedder = HunyuanVideo15TokenRefiner(
+ text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
+ )
+ self.context_embedder_2 = HunyuanVideo15ByT5TextProjection(text_embed_2_dim, 2048, inner_dim)
+
+ self.time_embed = HunyuanVideo15TimeEmbedding(inner_dim, use_meanflow=use_meanflow)
+
+ self.cond_type_embed = nn.Embedding(3, inner_dim)
+
+ # 2. RoPE
+ self.rope = HunyuanVideo15RotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
+
+ # 3. Dual stream transformer blocks
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ HunyuanVideo15TransformerBlock(
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 5. Output projection
+ self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_attention_mask: torch.Tensor,
+ timestep_r: Optional[torch.LongTensor] = None,
+ encoder_hidden_states_2: Optional[torch.Tensor] = None,
+ encoder_attention_mask_2: Optional[torch.Tensor] = None,
+ image_embeds: Optional[torch.Tensor] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.config.patch_size_t, self.config.patch_size, self.config.patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+
+ # 1. RoPE
+ image_rotary_emb = self.rope(hidden_states)
+
+ # 2. Conditional embeddings
+ temb = self.time_embed(timestep, timestep_r=timestep_r)
+
+ hidden_states = self.x_embedder(hidden_states)
+
+ # qwen text embedding
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
+
+ encoder_hidden_states_cond_emb = self.cond_type_embed(
+ torch.zeros_like(encoder_hidden_states[:, :, 0], dtype=torch.long)
+ )
+ encoder_hidden_states = encoder_hidden_states + encoder_hidden_states_cond_emb
+
+ # byt5 text embedding
+ encoder_hidden_states_2 = self.context_embedder_2(encoder_hidden_states_2)
+
+ encoder_hidden_states_2_cond_emb = self.cond_type_embed(
+ torch.ones_like(encoder_hidden_states_2[:, :, 0], dtype=torch.long)
+ )
+ encoder_hidden_states_2 = encoder_hidden_states_2 + encoder_hidden_states_2_cond_emb
+
+ # image embed
+ encoder_hidden_states_3 = self.image_embedder(image_embeds)
+ is_t2v = torch.all(image_embeds == 0)
+ if is_t2v:
+ encoder_hidden_states_3 = encoder_hidden_states_3 * 0.0
+ encoder_attention_mask_3 = torch.zeros(
+ (batch_size, encoder_hidden_states_3.shape[1]),
+ dtype=encoder_attention_mask.dtype,
+ device=encoder_attention_mask.device,
+ )
+ else:
+ encoder_attention_mask_3 = torch.ones(
+ (batch_size, encoder_hidden_states_3.shape[1]),
+ dtype=encoder_attention_mask.dtype,
+ device=encoder_attention_mask.device,
+ )
+ encoder_hidden_states_3_cond_emb = self.cond_type_embed(
+ 2
+ * torch.ones_like(
+ encoder_hidden_states_3[:, :, 0],
+ dtype=torch.long,
+ )
+ )
+ encoder_hidden_states_3 = encoder_hidden_states_3 + encoder_hidden_states_3_cond_emb
+
+ # reorder and combine text tokens: combine valid tokens first, then padding
+ encoder_attention_mask = encoder_attention_mask.bool()
+ encoder_attention_mask_2 = encoder_attention_mask_2.bool()
+ encoder_attention_mask_3 = encoder_attention_mask_3.bool()
+ new_encoder_hidden_states = []
+ new_encoder_attention_mask = []
+
+ for text, text_mask, text_2, text_mask_2, image, image_mask in zip(
+ encoder_hidden_states,
+ encoder_attention_mask,
+ encoder_hidden_states_2,
+ encoder_attention_mask_2,
+ encoder_hidden_states_3,
+ encoder_attention_mask_3,
+ ):
+ # Concatenate: [valid_image, valid_byt5, valid_mllm, invalid_image, invalid_byt5, invalid_mllm]
+ new_encoder_hidden_states.append(
+ torch.cat(
+ [
+ image[image_mask], # valid image
+ text_2[text_mask_2], # valid byt5
+ text[text_mask], # valid mllm
+ image[~image_mask], # invalid image
+ torch.zeros_like(text_2[~text_mask_2]), # invalid byt5 (zeroed)
+ torch.zeros_like(text[~text_mask]), # invalid mllm (zeroed)
+ ],
+ dim=0,
+ )
+ )
+
+ # Apply same reordering to attention masks
+ new_encoder_attention_mask.append(
+ torch.cat(
+ [
+ image_mask[image_mask],
+ text_mask_2[text_mask_2],
+ text_mask[text_mask],
+ image_mask[~image_mask],
+ text_mask_2[~text_mask_2],
+ text_mask[~text_mask],
+ ],
+ dim=0,
+ )
+ )
+
+ encoder_hidden_states = torch.stack(new_encoder_hidden_states)
+ encoder_attention_mask = torch.stack(new_encoder_attention_mask)
+
+ # 4. Transformer blocks
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for block in self.transformer_blocks:
+ hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ encoder_attention_mask,
+ image_rotary_emb,
+ )
+
+ else:
+ for block in self.transformer_blocks:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ encoder_attention_mask,
+ image_rotary_emb,
+ )
+
+ # 5. Output projection
+ hidden_states = self.norm_out(hidden_states, temb)
+ hidden_states = self.proj_out(hidden_states)
+
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p_h, p_w
+ )
+ hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
+ hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (hidden_states,)
+
+ return Transformer2DModelOutput(sample=hidden_states)
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
new file mode 100644
index 000000000000..601ba0f0b472
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
@@ -0,0 +1,416 @@
+# Copyright 2025 The Framepack Team, The Hunyuan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers
+from ..cache_utils import CacheMixin
+from ..embeddings import get_1d_rotary_pos_embed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous
+from .transformer_hunyuan_video import (
+ HunyuanVideoConditionEmbedding,
+ HunyuanVideoPatchEmbed,
+ HunyuanVideoSingleTransformerBlock,
+ HunyuanVideoTokenRefiner,
+ HunyuanVideoTransformerBlock,
+)
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+
+class HunyuanVideoFramepackRotaryPosEmbed(nn.Module):
+ def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
+ super().__init__()
+
+ self.patch_size = patch_size
+ self.patch_size_t = patch_size_t
+ self.rope_dim = rope_dim
+ self.theta = theta
+
+ def forward(self, frame_indices: torch.Tensor, height: int, width: int, device: torch.device):
+ height = height // self.patch_size
+ width = width // self.patch_size
+ grid = torch.meshgrid(
+ frame_indices.to(device=device, dtype=torch.float32),
+ torch.arange(0, height, device=device, dtype=torch.float32),
+ torch.arange(0, width, device=device, dtype=torch.float32),
+ indexing="ij",
+ ) # 3 * [W, H, T]
+ grid = torch.stack(grid, dim=0) # [3, W, H, T]
+
+ freqs = []
+ for i in range(3):
+ freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
+ freqs.append(freq)
+
+ freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
+ freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
+
+ return freqs_cos, freqs_sin
+
+
+class FramepackClipVisionProjection(nn.Module):
+ def __init__(self, in_channels: int, out_channels: int):
+ super().__init__()
+ self.up = nn.Linear(in_channels, out_channels * 3)
+ self.down = nn.Linear(out_channels * 3, out_channels)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.up(hidden_states)
+ hidden_states = F.silu(hidden_states)
+ hidden_states = self.down(hidden_states)
+ return hidden_states
+
+
+class HunyuanVideoHistoryPatchEmbed(nn.Module):
+ def __init__(self, in_channels: int, inner_dim: int):
+ super().__init__()
+ self.proj = nn.Conv3d(in_channels, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
+ self.proj_2x = nn.Conv3d(in_channels, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
+ self.proj_4x = nn.Conv3d(in_channels, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
+
+ def forward(
+ self,
+ latents_clean: Optional[torch.Tensor] = None,
+ latents_clean_2x: Optional[torch.Tensor] = None,
+ latents_clean_4x: Optional[torch.Tensor] = None,
+ ):
+ if latents_clean is not None:
+ latents_clean = self.proj(latents_clean)
+ latents_clean = latents_clean.flatten(2).transpose(1, 2)
+ if latents_clean_2x is not None:
+ latents_clean_2x = _pad_for_3d_conv(latents_clean_2x, (2, 4, 4))
+ latents_clean_2x = self.proj_2x(latents_clean_2x)
+ latents_clean_2x = latents_clean_2x.flatten(2).transpose(1, 2)
+ if latents_clean_4x is not None:
+ latents_clean_4x = _pad_for_3d_conv(latents_clean_4x, (4, 8, 8))
+ latents_clean_4x = self.proj_4x(latents_clean_4x)
+ latents_clean_4x = latents_clean_4x.flatten(2).transpose(1, 2)
+ return latents_clean, latents_clean_2x, latents_clean_4x
+
+
+class HunyuanVideoFramepackTransformer3DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin
+):
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
+ _no_split_modules = [
+ "HunyuanVideoTransformerBlock",
+ "HunyuanVideoSingleTransformerBlock",
+ "HunyuanVideoHistoryPatchEmbed",
+ "HunyuanVideoTokenRefiner",
+ ]
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 16,
+ out_channels: int = 16,
+ num_attention_heads: int = 24,
+ attention_head_dim: int = 128,
+ num_layers: int = 20,
+ num_single_layers: int = 40,
+ num_refiner_layers: int = 2,
+ mlp_ratio: float = 4.0,
+ patch_size: int = 2,
+ patch_size_t: int = 1,
+ qk_norm: str = "rms_norm",
+ guidance_embeds: bool = True,
+ text_embed_dim: int = 4096,
+ pooled_projection_dim: int = 768,
+ rope_theta: float = 256.0,
+ rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
+ image_condition_type: Optional[str] = None,
+ has_image_proj: int = False,
+ image_proj_dim: int = 1152,
+ has_clean_x_embedder: int = False,
+ ) -> None:
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels or in_channels
+
+ # 1. Latent and condition embedders
+ self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
+
+ # Framepack history projection embedder
+ self.clean_x_embedder = None
+ if has_clean_x_embedder:
+ self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim)
+
+ self.context_embedder = HunyuanVideoTokenRefiner(
+ text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
+ )
+
+ # Framepack image-conditioning embedder
+ self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None
+
+ self.time_text_embed = HunyuanVideoConditionEmbedding(
+ inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
+ )
+
+ # 2. RoPE
+ self.rope = HunyuanVideoFramepackRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
+
+ # 3. Dual stream transformer blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ HunyuanVideoTransformerBlock(
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Single stream transformer blocks
+ self.single_transformer_blocks = nn.ModuleList(
+ [
+ HunyuanVideoSingleTransformerBlock(
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+ )
+ for _ in range(num_single_layers)
+ ]
+ )
+
+ # 5. Output projection
+ self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_attention_mask: torch.Tensor,
+ pooled_projections: torch.Tensor,
+ image_embeds: torch.Tensor,
+ indices_latents: torch.Tensor,
+ guidance: Optional[torch.Tensor] = None,
+ latents_clean: Optional[torch.Tensor] = None,
+ indices_latents_clean: Optional[torch.Tensor] = None,
+ latents_history_2x: Optional[torch.Tensor] = None,
+ indices_latents_history_2x: Optional[torch.Tensor] = None,
+ latents_history_4x: Optional[torch.Tensor] = None,
+ indices_latents_history_4x: Optional[torch.Tensor] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p, p_t = self.config.patch_size, self.config.patch_size_t
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p
+ post_patch_width = width // p
+ original_context_length = post_patch_num_frames * post_patch_height * post_patch_width
+
+ if indices_latents is None:
+ indices_latents = torch.arange(0, num_frames).unsqueeze(0).expand(batch_size, -1)
+
+ hidden_states = self.x_embedder(hidden_states)
+ image_rotary_emb = self.rope(
+ frame_indices=indices_latents, height=height, width=width, device=hidden_states.device
+ )
+
+ latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder(
+ latents_clean, latents_history_2x, latents_history_4x
+ )
+
+ if latents_clean is not None and indices_latents_clean is not None:
+ image_rotary_emb_clean = self.rope(
+ frame_indices=indices_latents_clean, height=height, width=width, device=hidden_states.device
+ )
+ if latents_history_2x is not None and indices_latents_history_2x is not None:
+ image_rotary_emb_history_2x = self.rope(
+ frame_indices=indices_latents_history_2x, height=height, width=width, device=hidden_states.device
+ )
+ if latents_history_4x is not None and indices_latents_history_4x is not None:
+ image_rotary_emb_history_4x = self.rope(
+ frame_indices=indices_latents_history_4x, height=height, width=width, device=hidden_states.device
+ )
+
+ hidden_states, image_rotary_emb = self._pack_history_states(
+ hidden_states,
+ latents_clean,
+ latents_history_2x,
+ latents_history_4x,
+ image_rotary_emb,
+ image_rotary_emb_clean,
+ image_rotary_emb_history_2x,
+ image_rotary_emb_history_4x,
+ post_patch_height,
+ post_patch_width,
+ )
+
+ temb, _ = self.time_text_embed(timestep, pooled_projections, guidance)
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
+
+ encoder_hidden_states_image = self.image_projection(image_embeds)
+ attention_mask_image = encoder_attention_mask.new_ones((batch_size, encoder_hidden_states_image.shape[1]))
+
+ # must cat before (not after) encoder_hidden_states, due to attn masking
+ encoder_hidden_states = torch.cat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+ encoder_attention_mask = torch.cat([attention_mask_image, encoder_attention_mask], dim=1)
+
+ latent_sequence_length = hidden_states.shape[1]
+ condition_sequence_length = encoder_hidden_states.shape[1]
+ sequence_length = latent_sequence_length + condition_sequence_length
+ attention_mask = torch.zeros(
+ batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
+ ) # [B, N]
+ effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
+ effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
+
+ if batch_size == 1:
+ encoder_hidden_states = encoder_hidden_states[:, : effective_condition_sequence_length[0]]
+ attention_mask = None
+ else:
+ for i in range(batch_size):
+ attention_mask[i, : effective_sequence_length[i]] = True
+ # [B, 1, 1, N], for broadcasting across attention heads
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for block in self.transformer_blocks:
+ hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+ block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+ )
+
+ for block in self.single_transformer_blocks:
+ hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+ block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+ )
+
+ else:
+ for block in self.transformer_blocks:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+ )
+
+ for block in self.single_transformer_blocks:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
+ )
+
+ hidden_states = hidden_states[:, -original_context_length:]
+ hidden_states = self.norm_out(hidden_states, temb)
+ hidden_states = self.proj_out(hidden_states)
+
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
+ )
+ hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
+ hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (hidden_states,)
+ return Transformer2DModelOutput(sample=hidden_states)
+
+ def _pack_history_states(
+ self,
+ hidden_states: torch.Tensor,
+ latents_clean: Optional[torch.Tensor] = None,
+ latents_history_2x: Optional[torch.Tensor] = None,
+ latents_history_4x: Optional[torch.Tensor] = None,
+ image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] = None,
+ image_rotary_emb_clean: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ image_rotary_emb_history_2x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ image_rotary_emb_history_4x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ height: int = None,
+ width: int = None,
+ ):
+ image_rotary_emb = list(image_rotary_emb) # convert tuple to list for in-place modification
+
+ if latents_clean is not None and image_rotary_emb_clean is not None:
+ hidden_states = torch.cat([latents_clean, hidden_states], dim=1)
+ image_rotary_emb[0] = torch.cat([image_rotary_emb_clean[0], image_rotary_emb[0]], dim=0)
+ image_rotary_emb[1] = torch.cat([image_rotary_emb_clean[1], image_rotary_emb[1]], dim=0)
+
+ if latents_history_2x is not None and image_rotary_emb_history_2x is not None:
+ hidden_states = torch.cat([latents_history_2x, hidden_states], dim=1)
+ image_rotary_emb_history_2x = self._pad_rotary_emb(image_rotary_emb_history_2x, height, width, (2, 2, 2))
+ image_rotary_emb[0] = torch.cat([image_rotary_emb_history_2x[0], image_rotary_emb[0]], dim=0)
+ image_rotary_emb[1] = torch.cat([image_rotary_emb_history_2x[1], image_rotary_emb[1]], dim=0)
+
+ if latents_history_4x is not None and image_rotary_emb_history_4x is not None:
+ hidden_states = torch.cat([latents_history_4x, hidden_states], dim=1)
+ image_rotary_emb_history_4x = self._pad_rotary_emb(image_rotary_emb_history_4x, height, width, (4, 4, 4))
+ image_rotary_emb[0] = torch.cat([image_rotary_emb_history_4x[0], image_rotary_emb[0]], dim=0)
+ image_rotary_emb[1] = torch.cat([image_rotary_emb_history_4x[1], image_rotary_emb[1]], dim=0)
+
+ return hidden_states, tuple(image_rotary_emb)
+
+ def _pad_rotary_emb(
+ self,
+ image_rotary_emb: Tuple[torch.Tensor],
+ height: int,
+ width: int,
+ kernel_size: Tuple[int, int, int],
+ ):
+ # freqs_cos, freqs_sin have shape [W * H * T, D / 2], where D is attention head dim
+ freqs_cos, freqs_sin = image_rotary_emb
+ freqs_cos = freqs_cos.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width))
+ freqs_sin = freqs_sin.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width))
+ freqs_cos = _pad_for_3d_conv(freqs_cos, kernel_size)
+ freqs_sin = _pad_for_3d_conv(freqs_sin, kernel_size)
+ freqs_cos = _center_down_sample_3d(freqs_cos, kernel_size)
+ freqs_sin = _center_down_sample_3d(freqs_sin, kernel_size)
+ freqs_cos = freqs_cos.flatten(2).permute(0, 2, 1).squeeze(0)
+ freqs_sin = freqs_sin.flatten(2).permute(0, 2, 1).squeeze(0)
+ return freqs_cos, freqs_sin
+
+
+def _pad_for_3d_conv(x, kernel_size):
+ if isinstance(x, (tuple, list)):
+ return tuple(_pad_for_3d_conv(i, kernel_size) for i in x)
+ b, c, t, h, w = x.shape
+ pt, ph, pw = kernel_size
+ pad_t = (pt - (t % pt)) % pt
+ pad_h = (ph - (h % ph)) % ph
+ pad_w = (pw - (w % pw)) % pw
+ return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate")
+
+
+def _center_down_sample_3d(x, kernel_size):
+ if isinstance(x, (tuple, list)):
+ return tuple(_center_down_sample_3d(i, kernel_size) for i in x)
+ return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
diff --git a/src/diffusers/models/transformers/transformer_hunyuanimage.py b/src/diffusers/models/transformers/transformer_hunyuanimage.py
new file mode 100644
index 000000000000..d626e322ad6f
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_hunyuanimage.py
@@ -0,0 +1,910 @@
+# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diffusers.loaders import FromOriginalModelMixin
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import AttentionMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
+from ..attention_processor import Attention
+from ..cache_utils import CacheMixin
+from ..embeddings import (
+ CombinedTimestepTextProjEmbeddings,
+ TimestepEmbedding,
+ Timesteps,
+ get_1d_rotary_pos_embed,
+)
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class HunyuanImageAttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "HunyuanImageAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if attn.add_q_proj is None and encoder_hidden_states is not None:
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+ # 1. QKV projections
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ query = query.unflatten(2, (attn.heads, -1)) # batch_size, seq_len, heads, head_dim
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+
+ # 2. QK normalization
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # 3. Rotational positional embeddings applied to latent stream
+ if image_rotary_emb is not None:
+ from ..embeddings import apply_rotary_emb
+
+ if attn.add_q_proj is None and encoder_hidden_states is not None:
+ query = torch.cat(
+ [
+ apply_rotary_emb(
+ query[:, : -encoder_hidden_states.shape[1]], image_rotary_emb, sequence_dim=1
+ ),
+ query[:, -encoder_hidden_states.shape[1] :],
+ ],
+ dim=1,
+ )
+ key = torch.cat(
+ [
+ apply_rotary_emb(key[:, : -encoder_hidden_states.shape[1]], image_rotary_emb, sequence_dim=1),
+ key[:, -encoder_hidden_states.shape[1] :],
+ ],
+ dim=1,
+ )
+ else:
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+ # 4. Encoder condition QKV projection and normalization
+ if attn.add_q_proj is not None and encoder_hidden_states is not None:
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
+
+ if attn.norm_added_q is not None:
+ encoder_query = attn.norm_added_q(encoder_query)
+ if attn.norm_added_k is not None:
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([query, encoder_query], dim=1)
+ key = torch.cat([key, encoder_key], dim=1)
+ value = torch.cat([value, encoder_value], dim=1)
+
+ # 5. Attention
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # 6. Output projection
+ if encoder_hidden_states is not None:
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : -encoder_hidden_states.shape[1]],
+ hidden_states[:, -encoder_hidden_states.shape[1] :],
+ )
+
+ if getattr(attn, "to_out", None) is not None:
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if getattr(attn, "to_add_out", None) is not None:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+
+
+class HunyuanImagePatchEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size: Union[Tuple[int, int], Tuple[int, int, int]] = (16, 16),
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ ) -> None:
+ super().__init__()
+
+ self.patch_size = patch_size
+
+ if len(patch_size) == 2:
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ elif len(patch_size) == 3:
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ else:
+ raise ValueError(f"patch_size must be a tuple of length 2 or 3, got {len(patch_size)}")
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.proj(hidden_states)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
+ return hidden_states
+
+
+class HunyuanImageByT5TextProjection(nn.Module):
+ def __init__(self, in_features: int, hidden_size: int, out_features: int):
+ super().__init__()
+ self.norm = nn.LayerNorm(in_features)
+ self.linear_1 = nn.Linear(in_features, hidden_size)
+ self.linear_2 = nn.Linear(hidden_size, hidden_size)
+ self.linear_3 = nn.Linear(hidden_size, out_features)
+ self.act_fn = nn.GELU()
+
+ def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.norm(encoder_hidden_states)
+ hidden_states = self.linear_1(hidden_states)
+ hidden_states = self.act_fn(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ hidden_states = self.act_fn(hidden_states)
+ hidden_states = self.linear_3(hidden_states)
+ return hidden_states
+
+
+class HunyuanImageAdaNorm(nn.Module):
+ def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
+ super().__init__()
+
+ out_features = out_features or 2 * in_features
+ self.linear = nn.Linear(in_features, out_features)
+ self.nonlinearity = nn.SiLU()
+
+ def forward(
+ self, temb: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ temb = self.linear(self.nonlinearity(temb))
+ gate_msa, gate_mlp = temb.chunk(2, dim=1)
+ gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
+ return gate_msa, gate_mlp
+
+
+class HunyuanImageCombinedTimeGuidanceEmbedding(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ guidance_embeds: bool = False,
+ use_meanflow: bool = False,
+ ):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.use_meanflow = use_meanflow
+
+ self.time_proj_r = None
+ self.timestep_embedder_r = None
+ if use_meanflow:
+ self.time_proj_r = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder_r = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.guidance_embedder = None
+ if guidance_embeds:
+ self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ timestep_r: Optional[torch.Tensor] = None,
+ guidance: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=timestep.dtype))
+
+ if timestep_r is not None:
+ timesteps_proj_r = self.time_proj_r(timestep_r)
+ timesteps_emb_r = self.timestep_embedder_r(timesteps_proj_r.to(dtype=timestep.dtype))
+ timesteps_emb = (timesteps_emb + timesteps_emb_r) / 2
+
+ if self.guidance_embedder is not None:
+ guidance_proj = self.time_proj(guidance)
+ guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=timestep.dtype))
+ conditioning = timesteps_emb + guidance_emb
+ else:
+ conditioning = timesteps_emb
+
+ return conditioning
+
+
+# IndividualTokenRefinerBlock
+@maybe_allow_in_graph
+class HunyuanImageIndividualTokenRefinerBlock(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int, # 28
+ attention_head_dim: int, # 128
+ mlp_width_ratio: str = 4.0,
+ mlp_drop_rate: float = 0.0,
+ attention_bias: bool = True,
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
+ self.attn = Attention(
+ query_dim=hidden_size,
+ cross_attention_dim=None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ bias=attention_bias,
+ )
+
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
+ self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
+
+ self.norm_out = HunyuanImageAdaNorm(hidden_size, 2 * hidden_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=attention_mask,
+ )
+
+ gate_msa, gate_mlp = self.norm_out(temb)
+ hidden_states = hidden_states + attn_output * gate_msa
+
+ ff_output = self.ff(self.norm2(hidden_states))
+ hidden_states = hidden_states + ff_output * gate_mlp
+
+ return hidden_states
+
+
+class HunyuanImageIndividualTokenRefiner(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ num_layers: int,
+ mlp_width_ratio: float = 4.0,
+ mlp_drop_rate: float = 0.0,
+ attention_bias: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.refiner_blocks = nn.ModuleList(
+ [
+ HunyuanImageIndividualTokenRefinerBlock(
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ mlp_width_ratio=mlp_width_ratio,
+ mlp_drop_rate=mlp_drop_rate,
+ attention_bias=attention_bias,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> None:
+ self_attn_mask = None
+ if attention_mask is not None:
+ batch_size = attention_mask.shape[0]
+ seq_len = attention_mask.shape[1]
+ attention_mask = attention_mask.to(hidden_states.device)
+ self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
+ self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
+ self_attn_mask[:, :, :, 0] = True
+
+ for block in self.refiner_blocks:
+ hidden_states = block(hidden_states, temb, self_attn_mask)
+
+ return hidden_states
+
+
+# txt_in
+class HunyuanImageTokenRefiner(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ num_layers: int,
+ mlp_ratio: float = 4.0,
+ mlp_drop_rate: float = 0.0,
+ attention_bias: bool = True,
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
+ embedding_dim=hidden_size, pooled_projection_dim=in_channels
+ )
+ self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
+ self.token_refiner = HunyuanImageIndividualTokenRefiner(
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ num_layers=num_layers,
+ mlp_width_ratio=mlp_ratio,
+ mlp_drop_rate=mlp_drop_rate,
+ attention_bias=attention_bias,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ) -> torch.Tensor:
+ if attention_mask is None:
+ pooled_hidden_states = hidden_states.mean(dim=1)
+ else:
+ original_dtype = hidden_states.dtype
+ mask_float = attention_mask.float().unsqueeze(-1)
+ pooled_hidden_states = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
+ pooled_hidden_states = pooled_hidden_states.to(original_dtype)
+
+ temb = self.time_text_embed(timestep, pooled_hidden_states)
+ hidden_states = self.proj_in(hidden_states)
+ hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
+
+ return hidden_states
+
+
+class HunyuanImageRotaryPosEmbed(nn.Module):
+ def __init__(
+ self, patch_size: Union[Tuple, List[int]], rope_dim: Union[Tuple, List[int]], theta: float = 256.0
+ ) -> None:
+ super().__init__()
+
+ if not isinstance(patch_size, (tuple, list)) or len(patch_size) not in [2, 3]:
+ raise ValueError(f"patch_size must be a tuple or list of length 2 or 3, got {patch_size}")
+
+ if not isinstance(rope_dim, (tuple, list)) or len(rope_dim) not in [2, 3]:
+ raise ValueError(f"rope_dim must be a tuple or list of length 2 or 3, got {rope_dim}")
+
+ if not len(patch_size) == len(rope_dim):
+ raise ValueError(f"patch_size and rope_dim must have the same length, got {patch_size} and {rope_dim}")
+
+ self.patch_size = patch_size
+ self.rope_dim = rope_dim
+ self.theta = theta
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if hidden_states.ndim == 5:
+ _, _, frame, height, width = hidden_states.shape
+ patch_size_frame, patch_size_height, patch_size_width = self.patch_size
+ rope_sizes = [frame // patch_size_frame, height // patch_size_height, width // patch_size_width]
+ elif hidden_states.ndim == 4:
+ _, _, height, width = hidden_states.shape
+ patch_size_height, patch_size_width = self.patch_size
+ rope_sizes = [height // patch_size_height, width // patch_size_width]
+ else:
+ raise ValueError(f"hidden_states must be a 4D or 5D tensor, got {hidden_states.shape}")
+
+ axes_grids = []
+ for i in range(len(rope_sizes)):
+ grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
+ axes_grids.append(grid)
+ grid = torch.meshgrid(*axes_grids, indexing="ij") # dim x [H, W]
+ grid = torch.stack(grid, dim=0) # [2, H, W]
+
+ freqs = []
+ for i in range(len(rope_sizes)):
+ freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
+ freqs.append(freq)
+
+ freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
+ freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
+ return freqs_cos, freqs_sin
+
+
+@maybe_allow_in_graph
+class HunyuanImageSingleTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_ratio: float = 4.0,
+ qk_norm: str = "rms_norm",
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+ mlp_dim = int(hidden_size * mlp_ratio)
+
+ self.attn = Attention(
+ query_dim=hidden_size,
+ cross_attention_dim=None,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=hidden_size,
+ bias=True,
+ processor=HunyuanImageAttnProcessor(),
+ qk_norm=qk_norm,
+ eps=1e-6,
+ pre_only=True,
+ )
+
+ self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
+ self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
+ self.act_mlp = nn.GELU(approximate="tanh")
+ self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.shape[1]
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+ residual = hidden_states
+
+ # 1. Input normalization
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+
+ norm_hidden_states, norm_encoder_hidden_states = (
+ norm_hidden_states[:, :-text_seq_length, :],
+ norm_hidden_states[:, -text_seq_length:, :],
+ )
+
+ # 2. Attention
+ attn_output, context_attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+ attn_output = torch.cat([attn_output, context_attn_output], dim=1)
+
+ # 3. Modulation and residual connection
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
+ hidden_states = hidden_states + residual
+
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, :-text_seq_length, :],
+ hidden_states[:, -text_seq_length:, :],
+ )
+ return hidden_states, encoder_hidden_states
+
+
+@maybe_allow_in_graph
+class HunyuanImageTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_ratio: float,
+ qk_norm: str = "rms_norm",
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+
+ self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
+ self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
+
+ self.attn = Attention(
+ query_dim=hidden_size,
+ cross_attention_dim=None,
+ added_kv_proj_dim=hidden_size,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=hidden_size,
+ context_pre_only=False,
+ bias=True,
+ processor=HunyuanImageAttnProcessor(),
+ qk_norm=qk_norm,
+ eps=1e-6,
+ )
+
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+ self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ *args,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # 1. Input normalization
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ encoder_hidden_states, emb=temb
+ )
+
+ # 2. Joint attention
+ attn_output, context_attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # 3. Modulation and residual connection
+ hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
+ encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+ # 4. Feed-forward
+ ff_output = self.ff(norm_hidden_states)
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+ return hidden_states, encoder_hidden_states
+
+
+class HunyuanImageTransformer2DModel(
+ ModelMixin, ConfigMixin, AttentionMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin
+):
+ r"""
+ The Transformer model used in [HunyuanImage-2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1).
+
+ Args:
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, defaults to `16`):
+ The number of channels in the output.
+ num_attention_heads (`int`, defaults to `24`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of channels in each head.
+ num_layers (`int`, defaults to `20`):
+ The number of layers of dual-stream blocks to use.
+ num_single_layers (`int`, defaults to `40`):
+ The number of layers of single-stream blocks to use.
+ num_refiner_layers (`int`, defaults to `2`):
+ The number of layers of refiner blocks to use.
+ mlp_ratio (`float`, defaults to `4.0`):
+ The ratio of the hidden layer size to the input size in the feedforward network.
+ patch_size (`int`, defaults to `2`):
+ The size of the spatial patches to use in the patch embedding layer.
+ patch_size_t (`int`, defaults to `1`):
+ The size of the tmeporal patches to use in the patch embedding layer.
+ qk_norm (`str`, defaults to `rms_norm`):
+ The normalization to use for the query and key projections in the attention layers.
+ guidance_embeds (`bool`, defaults to `True`):
+ Whether to use guidance embeddings in the model.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ pooled_projection_dim (`int`, defaults to `768`):
+ The dimension of the pooled projection of the text embeddings.
+ rope_theta (`float`, defaults to `256.0`):
+ The value of theta to use in the RoPE layer.
+ rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
+ The dimensions of the axes to use in the RoPE layer.
+ image_condition_type (`str`, *optional*, defaults to `None`):
+ The type of image conditioning to use. If `None`, no image conditioning is used. If `latent_concat`, the
+ image is concatenated to the latent stream. If `token_replace`, the image is used to replace first-frame
+ tokens in the latent stream and apply conditioning.
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
+ _no_split_modules = [
+ "HunyuanImageTransformerBlock",
+ "HunyuanImageSingleTransformerBlock",
+ "HunyuanImagePatchEmbed",
+ "HunyuanImageTokenRefiner",
+ ]
+ _repeated_blocks = ["HunyuanImageTransformerBlock", "HunyuanImageSingleTransformerBlock"]
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 64,
+ out_channels: int = 64,
+ num_attention_heads: int = 28,
+ attention_head_dim: int = 128,
+ num_layers: int = 20,
+ num_single_layers: int = 40,
+ num_refiner_layers: int = 2,
+ mlp_ratio: float = 4.0,
+ patch_size: Tuple[int, int] = (1, 1),
+ qk_norm: str = "rms_norm",
+ guidance_embeds: bool = False,
+ text_embed_dim: int = 3584,
+ text_embed_2_dim: Optional[int] = None,
+ rope_theta: float = 256.0,
+ rope_axes_dim: Tuple[int, ...] = (64, 64),
+ use_meanflow: bool = False,
+ ) -> None:
+ super().__init__()
+
+ if not (isinstance(patch_size, (tuple, list)) and len(patch_size) in [2, 3]):
+ raise ValueError(f"patch_size must be a tuple of length 2 or 3, got {patch_size}")
+
+ inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels or in_channels
+
+ # 1. Latent and condition embedders
+ self.x_embedder = HunyuanImagePatchEmbed(patch_size, in_channels, inner_dim)
+ self.context_embedder = HunyuanImageTokenRefiner(
+ text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
+ )
+
+ if text_embed_2_dim is not None:
+ self.context_embedder_2 = HunyuanImageByT5TextProjection(text_embed_2_dim, 2048, inner_dim)
+ else:
+ self.context_embedder_2 = None
+
+ self.time_guidance_embed = HunyuanImageCombinedTimeGuidanceEmbedding(inner_dim, guidance_embeds, use_meanflow)
+
+ # 2. RoPE
+ self.rope = HunyuanImageRotaryPosEmbed(patch_size, rope_axes_dim, rope_theta)
+
+ # 3. Dual stream transformer blocks
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ HunyuanImageTransformerBlock(
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Single stream transformer blocks
+ self.single_transformer_blocks = nn.ModuleList(
+ [
+ HunyuanImageSingleTransformerBlock(
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+ )
+ for _ in range(num_single_layers)
+ ]
+ )
+
+ # 5. Output projection
+ self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out = nn.Linear(inner_dim, math.prod(patch_size) * out_channels)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_attention_mask: torch.Tensor,
+ timestep_r: Optional[torch.LongTensor] = None,
+ encoder_hidden_states_2: Optional[torch.Tensor] = None,
+ encoder_attention_mask_2: Optional[torch.Tensor] = None,
+ guidance: Optional[torch.Tensor] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ if hidden_states.ndim == 4:
+ batch_size, channels, height, width = hidden_states.shape
+ sizes = (height, width)
+ elif hidden_states.ndim == 5:
+ batch_size, channels, frame, height, width = hidden_states.shape
+ sizes = (frame, height, width)
+ else:
+ raise ValueError(f"hidden_states must be a 4D or 5D tensor, got {hidden_states.shape}")
+
+ post_patch_sizes = tuple(d // p for d, p in zip(sizes, self.config.patch_size))
+
+ # 1. RoPE
+ image_rotary_emb = self.rope(hidden_states)
+
+ # 2. Conditional embeddings
+ encoder_attention_mask = encoder_attention_mask.bool()
+ temb = self.time_guidance_embed(timestep, guidance=guidance, timestep_r=timestep_r)
+ hidden_states = self.x_embedder(hidden_states)
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
+
+ if self.context_embedder_2 is not None and encoder_hidden_states_2 is not None:
+ encoder_hidden_states_2 = self.context_embedder_2(encoder_hidden_states_2)
+
+ encoder_attention_mask_2 = encoder_attention_mask_2.bool()
+
+ # reorder and combine text tokens: combine valid tokens first, then padding
+ new_encoder_hidden_states = []
+ new_encoder_attention_mask = []
+
+ for text, text_mask, text_2, text_mask_2 in zip(
+ encoder_hidden_states, encoder_attention_mask, encoder_hidden_states_2, encoder_attention_mask_2
+ ):
+ # Concatenate: [valid_mllm, valid_byt5, invalid_mllm, invalid_byt5]
+ new_encoder_hidden_states.append(
+ torch.cat(
+ [
+ text_2[text_mask_2], # valid byt5
+ text[text_mask], # valid mllm
+ text_2[~text_mask_2], # invalid byt5
+ text[~text_mask], # invalid mllm
+ ],
+ dim=0,
+ )
+ )
+
+ # Apply same reordering to attention masks
+ new_encoder_attention_mask.append(
+ torch.cat(
+ [
+ text_mask_2[text_mask_2],
+ text_mask[text_mask],
+ text_mask_2[~text_mask_2],
+ text_mask[~text_mask],
+ ],
+ dim=0,
+ )
+ )
+
+ encoder_hidden_states = torch.stack(new_encoder_hidden_states)
+ encoder_attention_mask = torch.stack(new_encoder_attention_mask)
+
+ attention_mask = torch.nn.functional.pad(encoder_attention_mask, (hidden_states.shape[1], 0), value=True)
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+ # 3. Transformer blocks
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for block in self.transformer_blocks:
+ hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ for block in self.single_transformer_blocks:
+ hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ else:
+ for block in self.transformer_blocks:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ for block in self.single_transformer_blocks:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # 4. Output projection
+ hidden_states = self.norm_out(hidden_states, temb)
+ hidden_states = self.proj_out(hidden_states)
+
+ # 5. unpatchify
+ # reshape: [batch_size, *post_patch_dims, channels, *patch_size]
+ out_channels = self.config.out_channels
+ reshape_dims = [batch_size] + list(post_patch_sizes) + [out_channels] + list(self.config.patch_size)
+ hidden_states = hidden_states.reshape(*reshape_dims)
+
+ # create permutation pattern: batch, channels, then interleave post_patch and patch dims
+ # For 4D: [0, 3, 1, 4, 2, 5] -> batch, channels, post_patch_height, patch_size_height, post_patch_width, patch_size_width
+ # For 5D: [0, 4, 1, 5, 2, 6, 3, 7] -> batch, channels, post_patch_frame, patch_size_frame, post_patch_height, patch_size_height, post_patch_width, patch_size_width
+ ndim = len(post_patch_sizes)
+ permute_pattern = [0, ndim + 1] # batch, channels
+ for i in range(ndim):
+ permute_pattern.extend([i + 1, ndim + 2 + i]) # post_patch_sizes[i], patch_sizes[i]
+ hidden_states = hidden_states.permute(*permute_pattern)
+
+ # flatten patch dimensions: flatten each (post_patch_size, patch_size) pair
+ # batch_size, channels, post_patch_sizes[0] * patch_sizes[0], post_patch_sizes[1] * patch_sizes[1], ...
+ final_dims = [batch_size, out_channels] + [
+ post_patch * patch for post_patch, patch in zip(post_patch_sizes, self.config.patch_size)
+ ]
+ hidden_states = hidden_states.reshape(*final_dims)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (hidden_states,)
+
+ return Transformer2DModelOutput(sample=hidden_states)
diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py
new file mode 100644
index 000000000000..316e79da4fd6
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_kandinsky.py
@@ -0,0 +1,669 @@
+# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import (
+ logging,
+)
+from ..attention import AttentionMixin, AttentionModuleMixin
+from ..attention_dispatch import _CAN_USE_FLEX_ATTN, dispatch_attention_fn
+from ..cache_utils import CacheMixin
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+
+
+logger = logging.get_logger(__name__)
+
+
+def get_freqs(dim, max_period=10000.0):
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)
+ return freqs
+
+
+def fractal_flatten(x, rope, shape, block_mask=False):
+ if block_mask:
+ pixel_size = 8
+ x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=1)
+ rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=1)
+ x = x.flatten(1, 2)
+ rope = rope.flatten(1, 2)
+ else:
+ x = x.flatten(1, 3)
+ rope = rope.flatten(1, 3)
+ return x, rope
+
+
+def fractal_unflatten(x, shape, block_mask=False):
+ if block_mask:
+ pixel_size = 8
+ x = x.reshape(x.shape[0], -1, pixel_size**2, *x.shape[2:])
+ x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=1)
+ else:
+ x = x.reshape(*shape, *x.shape[2:])
+ return x
+
+
+def local_patching(x, shape, group_size, dim=0):
+ batch_size, duration, height, width = shape
+ g1, g2, g3 = group_size
+ x = x.reshape(
+ *x.shape[:dim],
+ duration // g1,
+ g1,
+ height // g2,
+ g2,
+ width // g3,
+ g3,
+ *x.shape[dim + 3 :],
+ )
+ x = x.permute(
+ *range(len(x.shape[:dim])),
+ dim,
+ dim + 2,
+ dim + 4,
+ dim + 1,
+ dim + 3,
+ dim + 5,
+ *range(dim + 6, len(x.shape)),
+ )
+ x = x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3)
+ return x
+
+
+def local_merge(x, shape, group_size, dim=0):
+ batch_size, duration, height, width = shape
+ g1, g2, g3 = group_size
+ x = x.reshape(
+ *x.shape[:dim],
+ duration // g1,
+ height // g2,
+ width // g3,
+ g1,
+ g2,
+ g3,
+ *x.shape[dim + 2 :],
+ )
+ x = x.permute(
+ *range(len(x.shape[:dim])),
+ dim,
+ dim + 3,
+ dim + 1,
+ dim + 4,
+ dim + 2,
+ dim + 5,
+ *range(dim + 6, len(x.shape)),
+ )
+ x = x.flatten(dim, dim + 1).flatten(dim + 1, dim + 2).flatten(dim + 2, dim + 3)
+ return x
+
+
+def nablaT_v2(
+ q: Tensor,
+ k: Tensor,
+ sta: Tensor,
+ thr: float = 0.9,
+):
+ if _CAN_USE_FLEX_ATTN:
+ from torch.nn.attention.flex_attention import BlockMask
+ else:
+ raise ValueError("Nabla attention is not supported with this version of PyTorch")
+
+ q = q.transpose(1, 2).contiguous()
+ k = k.transpose(1, 2).contiguous()
+
+ # Map estimation
+ B, h, S, D = q.shape
+ s1 = S // 64
+ qa = q.reshape(B, h, s1, 64, D).mean(-2)
+ ka = k.reshape(B, h, s1, 64, D).mean(-2).transpose(-2, -1)
+ map = qa @ ka
+
+ map = torch.softmax(map / math.sqrt(D), dim=-1)
+ # Map binarization
+ vals, inds = map.sort(-1)
+ cvals = vals.cumsum_(-1)
+ mask = (cvals >= 1 - thr).int()
+ mask = mask.gather(-1, inds.argsort(-1))
+
+ mask = torch.logical_or(mask, sta)
+
+ # BlockMask creation
+ kv_nb = mask.sum(-1).to(torch.int32)
+ kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32)
+ return BlockMask.from_kv_blocks(torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None)
+
+
+class Kandinsky5TimeEmbeddings(nn.Module):
+ def __init__(self, model_dim, time_dim, max_period=10000.0):
+ super().__init__()
+ assert model_dim % 2 == 0
+ self.model_dim = model_dim
+ self.max_period = max_period
+ self.freqs = get_freqs(self.model_dim // 2, self.max_period)
+ self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
+ self.activation = nn.SiLU()
+ self.out_layer = nn.Linear(time_dim, time_dim, bias=True)
+
+ @torch.autocast(device_type="cuda", dtype=torch.float32)
+ def forward(self, time):
+ args = torch.outer(time, self.freqs.to(device=time.device))
+ time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
+ return time_embed
+
+
+class Kandinsky5TextEmbeddings(nn.Module):
+ def __init__(self, text_dim, model_dim):
+ super().__init__()
+ self.in_layer = nn.Linear(text_dim, model_dim, bias=True)
+ self.norm = nn.LayerNorm(model_dim, elementwise_affine=True)
+
+ def forward(self, text_embed):
+ text_embed = self.in_layer(text_embed)
+ return self.norm(text_embed).type_as(text_embed)
+
+
+class Kandinsky5VisualEmbeddings(nn.Module):
+ def __init__(self, visual_dim, model_dim, patch_size):
+ super().__init__()
+ self.patch_size = patch_size
+ self.in_layer = nn.Linear(math.prod(patch_size) * visual_dim, model_dim)
+
+ def forward(self, x):
+ batch_size, duration, height, width, dim = x.shape
+ x = (
+ x.view(
+ batch_size,
+ duration // self.patch_size[0],
+ self.patch_size[0],
+ height // self.patch_size[1],
+ self.patch_size[1],
+ width // self.patch_size[2],
+ self.patch_size[2],
+ dim,
+ )
+ .permute(0, 1, 3, 5, 2, 4, 6, 7)
+ .flatten(4, 7)
+ )
+ return self.in_layer(x)
+
+
+class Kandinsky5RoPE1D(nn.Module):
+ def __init__(self, dim, max_pos=1024, max_period=10000.0):
+ super().__init__()
+ self.max_period = max_period
+ self.dim = dim
+ self.max_pos = max_pos
+ freq = get_freqs(dim // 2, max_period)
+ pos = torch.arange(max_pos, dtype=freq.dtype)
+ self.register_buffer("args", torch.outer(pos, freq), persistent=False)
+
+ def forward(self, pos):
+ args = self.args[pos]
+ cosine = torch.cos(args)
+ sine = torch.sin(args)
+ rope = torch.stack([cosine, -sine, sine, cosine], dim=-1)
+ rope = rope.view(*rope.shape[:-1], 2, 2)
+ return rope.unsqueeze(-4)
+
+
+class Kandinsky5RoPE3D(nn.Module):
+ def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.0):
+ super().__init__()
+ self.axes_dims = axes_dims
+ self.max_pos = max_pos
+ self.max_period = max_period
+
+ for i, (axes_dim, ax_max_pos) in enumerate(zip(axes_dims, max_pos)):
+ freq = get_freqs(axes_dim // 2, max_period)
+ pos = torch.arange(ax_max_pos, dtype=freq.dtype)
+ self.register_buffer(f"args_{i}", torch.outer(pos, freq), persistent=False)
+
+ def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)):
+ batch_size, duration, height, width = shape
+ args_t = self.args_0[pos[0]] / scale_factor[0]
+ args_h = self.args_1[pos[1]] / scale_factor[1]
+ args_w = self.args_2[pos[2]] / scale_factor[2]
+
+ args = torch.cat(
+ [
+ args_t.view(1, duration, 1, 1, -1).repeat(batch_size, 1, height, width, 1),
+ args_h.view(1, 1, height, 1, -1).repeat(batch_size, duration, 1, width, 1),
+ args_w.view(1, 1, 1, width, -1).repeat(batch_size, duration, height, 1, 1),
+ ],
+ dim=-1,
+ )
+ cosine = torch.cos(args)
+ sine = torch.sin(args)
+ rope = torch.stack([cosine, -sine, sine, cosine], dim=-1)
+ rope = rope.view(*rope.shape[:-1], 2, 2)
+ return rope.unsqueeze(-4)
+
+
+class Kandinsky5Modulation(nn.Module):
+ def __init__(self, time_dim, model_dim, num_params):
+ super().__init__()
+ self.activation = nn.SiLU()
+ self.out_layer = nn.Linear(time_dim, num_params * model_dim)
+ self.out_layer.weight.data.zero_()
+ self.out_layer.bias.data.zero_()
+
+ @torch.autocast(device_type="cuda", dtype=torch.float32)
+ def forward(self, x):
+ return self.out_layer(self.activation(x))
+
+
+class Kandinsky5AttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
+
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, rotary_emb=None, sparse_params=None):
+ # query, key, value = self.get_qkv(x)
+ query = attn.to_query(hidden_states)
+
+ if encoder_hidden_states is not None:
+ key = attn.to_key(encoder_hidden_states)
+ value = attn.to_value(encoder_hidden_states)
+
+ shape, cond_shape = query.shape[:-1], key.shape[:-1]
+ query = query.reshape(*shape, attn.num_heads, -1)
+ key = key.reshape(*cond_shape, attn.num_heads, -1)
+ value = value.reshape(*cond_shape, attn.num_heads, -1)
+
+ else:
+ key = attn.to_key(hidden_states)
+ value = attn.to_value(hidden_states)
+
+ shape = query.shape[:-1]
+ query = query.reshape(*shape, attn.num_heads, -1)
+ key = key.reshape(*shape, attn.num_heads, -1)
+ value = value.reshape(*shape, attn.num_heads, -1)
+
+ # query, key = self.norm_qk(query, key)
+ query = attn.query_norm(query.float()).type_as(query)
+ key = attn.key_norm(key.float()).type_as(key)
+
+ def apply_rotary(x, rope):
+ x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
+ x_out = (rope * x_).sum(dim=-1)
+ return x_out.reshape(*x.shape).to(torch.bfloat16)
+
+ if rotary_emb is not None:
+ query = apply_rotary(query, rotary_emb).type_as(query)
+ key = apply_rotary(key, rotary_emb).type_as(key)
+
+ if sparse_params is not None:
+ attn_mask = nablaT_v2(
+ query,
+ key,
+ sparse_params["sta_mask"],
+ thr=sparse_params["P"],
+ )
+
+ else:
+ attn_mask = None
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attn_mask,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+
+ hidden_states = hidden_states.flatten(-2, -1)
+
+ attn_out = attn.out_layer(hidden_states)
+ return attn_out
+
+
+class Kandinsky5Attention(nn.Module, AttentionModuleMixin):
+ _default_processor_cls = Kandinsky5AttnProcessor
+ _available_processors = [
+ Kandinsky5AttnProcessor,
+ ]
+
+ def __init__(self, num_channels, head_dim, processor=None):
+ super().__init__()
+ assert num_channels % head_dim == 0
+ self.num_heads = num_channels // head_dim
+
+ self.to_query = nn.Linear(num_channels, num_channels, bias=True)
+ self.to_key = nn.Linear(num_channels, num_channels, bias=True)
+ self.to_value = nn.Linear(num_channels, num_channels, bias=True)
+ self.query_norm = nn.RMSNorm(head_dim)
+ self.key_norm = nn.RMSNorm(head_dim)
+
+ self.out_layer = nn.Linear(num_channels, num_channels, bias=True)
+ if processor is None:
+ processor = self._default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ sparse_params: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ quiet_attn_parameters = {}
+ unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"attention_processor_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ sparse_params=sparse_params,
+ rotary_emb=rotary_emb,
+ **kwargs,
+ )
+
+
+class Kandinsky5FeedForward(nn.Module):
+ def __init__(self, dim, ff_dim):
+ super().__init__()
+ self.in_layer = nn.Linear(dim, ff_dim, bias=False)
+ self.activation = nn.GELU()
+ self.out_layer = nn.Linear(ff_dim, dim, bias=False)
+
+ def forward(self, x):
+ return self.out_layer(self.activation(self.in_layer(x)))
+
+
+class Kandinsky5OutLayer(nn.Module):
+ def __init__(self, model_dim, time_dim, visual_dim, patch_size):
+ super().__init__()
+ self.patch_size = patch_size
+ self.modulation = Kandinsky5Modulation(time_dim, model_dim, 2)
+ self.norm = nn.LayerNorm(model_dim, elementwise_affine=False)
+ self.out_layer = nn.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True)
+
+ def forward(self, visual_embed, text_embed, time_embed):
+ shift, scale = torch.chunk(self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1)
+
+ visual_embed = (
+ self.norm(visual_embed.float()) * (scale.float()[:, None, None] + 1.0) + shift.float()[:, None, None]
+ ).type_as(visual_embed)
+
+ x = self.out_layer(visual_embed)
+
+ batch_size, duration, height, width, _ = x.shape
+ x = (
+ x.view(
+ batch_size,
+ duration,
+ height,
+ width,
+ -1,
+ self.patch_size[0],
+ self.patch_size[1],
+ self.patch_size[2],
+ )
+ .permute(0, 1, 5, 2, 6, 3, 7, 4)
+ .flatten(1, 2)
+ .flatten(2, 3)
+ .flatten(3, 4)
+ )
+ return x
+
+
+class Kandinsky5TransformerEncoderBlock(nn.Module):
+ def __init__(self, model_dim, time_dim, ff_dim, head_dim):
+ super().__init__()
+ self.text_modulation = Kandinsky5Modulation(time_dim, model_dim, 6)
+
+ self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
+ self.self_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor())
+
+ self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
+ self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim)
+
+ def forward(self, x, time_embed, rope):
+ self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1)
+ shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
+ out = (self.self_attention_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x)
+ out = self.self_attention(out, rotary_emb=rope)
+ x = (x.float() + gate.float() * out.float()).type_as(x)
+
+ shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
+ out = (self.feed_forward_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x)
+ out = self.feed_forward(out)
+ x = (x.float() + gate.float() * out.float()).type_as(x)
+
+ return x
+
+
+class Kandinsky5TransformerDecoderBlock(nn.Module):
+ def __init__(self, model_dim, time_dim, ff_dim, head_dim):
+ super().__init__()
+ self.visual_modulation = Kandinsky5Modulation(time_dim, model_dim, 9)
+
+ self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
+ self.self_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor())
+
+ self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
+ self.cross_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor())
+
+ self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
+ self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim)
+
+ def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params):
+ self_attn_params, cross_attn_params, ff_params = torch.chunk(
+ self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1
+ )
+
+ shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
+ visual_out = (self.self_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(
+ visual_embed
+ )
+ visual_out = self.self_attention(visual_out, rotary_emb=rope, sparse_params=sparse_params)
+ visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed)
+
+ shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1)
+ visual_out = (self.cross_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(
+ visual_embed
+ )
+ visual_out = self.cross_attention(visual_out, encoder_hidden_states=text_embed)
+ visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed)
+
+ shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
+ visual_out = (self.feed_forward_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(
+ visual_embed
+ )
+ visual_out = self.feed_forward(visual_out)
+ visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed)
+
+ return visual_embed
+
+
+class Kandinsky5Transformer3DModel(
+ ModelMixin,
+ ConfigMixin,
+ PeftAdapterMixin,
+ FromOriginalModelMixin,
+ CacheMixin,
+ AttentionMixin,
+):
+ """
+ A 3D Diffusion Transformer model for video-like data.
+ """
+
+ _repeated_blocks = [
+ "Kandinsky5TransformerEncoderBlock",
+ "Kandinsky5TransformerDecoderBlock",
+ ]
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_visual_dim=4,
+ in_text_dim=3584,
+ in_text_dim2=768,
+ time_dim=512,
+ out_visual_dim=4,
+ patch_size=(1, 2, 2),
+ model_dim=2048,
+ ff_dim=5120,
+ num_text_blocks=2,
+ num_visual_blocks=32,
+ axes_dims=(16, 24, 24),
+ visual_cond=False,
+ attention_type: str = "regular",
+ attention_causal: bool = None,
+ attention_local: bool = None,
+ attention_glob: bool = None,
+ attention_window: int = None,
+ attention_P: float = None,
+ attention_wT: int = None,
+ attention_wW: int = None,
+ attention_wH: int = None,
+ attention_add_sta: bool = None,
+ attention_method: str = None,
+ ):
+ super().__init__()
+
+ head_dim = sum(axes_dims)
+ self.in_visual_dim = in_visual_dim
+ self.model_dim = model_dim
+ self.patch_size = patch_size
+ self.visual_cond = visual_cond
+ self.attention_type = attention_type
+
+ visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim
+
+ # Initialize embeddings
+ self.time_embeddings = Kandinsky5TimeEmbeddings(model_dim, time_dim)
+ self.text_embeddings = Kandinsky5TextEmbeddings(in_text_dim, model_dim)
+ self.pooled_text_embeddings = Kandinsky5TextEmbeddings(in_text_dim2, time_dim)
+ self.visual_embeddings = Kandinsky5VisualEmbeddings(visual_embed_dim, model_dim, patch_size)
+
+ # Initialize positional embeddings
+ self.text_rope_embeddings = Kandinsky5RoPE1D(head_dim)
+ self.visual_rope_embeddings = Kandinsky5RoPE3D(axes_dims)
+
+ # Initialize transformer blocks
+ self.text_transformer_blocks = nn.ModuleList(
+ [Kandinsky5TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim) for _ in range(num_text_blocks)]
+ )
+
+ self.visual_transformer_blocks = nn.ModuleList(
+ [
+ Kandinsky5TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim)
+ for _ in range(num_visual_blocks)
+ ]
+ )
+
+ # Initialize output layer
+ self.out_layer = Kandinsky5OutLayer(model_dim, time_dim, out_visual_dim, patch_size)
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor, # x
+ encoder_hidden_states: torch.Tensor, # text_embed
+ timestep: torch.Tensor, # time
+ pooled_projections: torch.Tensor, # pooled_text_embed
+ visual_rope_pos: Tuple[int, int, int],
+ text_rope_pos: torch.LongTensor,
+ scale_factor: Tuple[float, float, float] = (1.0, 1.0, 1.0),
+ sparse_params: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[Transformer2DModelOutput, torch.FloatTensor]:
+ """
+ Forward pass of the Kandinsky5 3D Transformer.
+
+ Args:
+ hidden_states (`torch.FloatTensor`): Input visual states
+ encoder_hidden_states (`torch.FloatTensor`): Text embeddings
+ timestep (`torch.Tensor` or `float` or `int`): Current timestep
+ pooled_projections (`torch.FloatTensor`): Pooled text embeddings
+ visual_rope_pos (`Tuple[int, int, int]`): Position for visual RoPE
+ text_rope_pos (`torch.LongTensor`): Position for text RoPE
+ scale_factor (`Tuple[float, float, float]`, optional): Scale factor for RoPE
+ sparse_params (`Dict[str, Any]`, optional): Parameters for sparse attention
+ return_dict (`bool`, optional): Whether to return a dictionary
+
+ Returns:
+ [`~models.transformer_2d.Transformer2DModelOutput`] or `torch.FloatTensor`: The output of the transformer
+ """
+ x = hidden_states
+ text_embed = encoder_hidden_states
+ time = timestep
+ pooled_text_embed = pooled_projections
+
+ text_embed = self.text_embeddings(text_embed)
+ time_embed = self.time_embeddings(time)
+ time_embed = time_embed + self.pooled_text_embeddings(pooled_text_embed)
+ visual_embed = self.visual_embeddings(x)
+ text_rope = self.text_rope_embeddings(text_rope_pos)
+ text_rope = text_rope.unsqueeze(dim=0)
+
+ for text_transformer_block in self.text_transformer_blocks:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ text_embed = self._gradient_checkpointing_func(
+ text_transformer_block, text_embed, time_embed, text_rope
+ )
+ else:
+ text_embed = text_transformer_block(text_embed, time_embed, text_rope)
+
+ visual_shape = visual_embed.shape[:-1]
+ visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor)
+ to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False
+ visual_embed, visual_rope = fractal_flatten(visual_embed, visual_rope, visual_shape, block_mask=to_fractal)
+
+ for visual_transformer_block in self.visual_transformer_blocks:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ visual_embed = self._gradient_checkpointing_func(
+ visual_transformer_block,
+ visual_embed,
+ text_embed,
+ time_embed,
+ visual_rope,
+ sparse_params,
+ )
+ else:
+ visual_embed = visual_transformer_block(
+ visual_embed, text_embed, time_embed, visual_rope, sparse_params
+ )
+
+ visual_embed = fractal_unflatten(visual_embed, visual_shape, block_mask=to_fractal)
+ x = self.out_layer(visual_embed, text_embed, time_embed)
+
+ if not return_dict:
+ return x
+
+ return Transformer2DModelOutput(sample=x)
diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py
index 2ae2418098f6..685c73c07c75 100644
--- a/src/diffusers/models/transformers/transformer_ltx.py
+++ b/src/diffusers/models/transformers/transformer_ltx.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The Genmo team and The HuggingFace Team.
+# Copyright 2025 The Lightricks team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,19 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import math
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
-import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import Attention
+from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
@@ -37,20 +38,31 @@
class LTXVideoAttentionProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`LTXVideoAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `LTXVideoAttnProcessor`"
+ deprecate("LTXVideoAttentionProcessor2_0", "1.0.0", deprecation_message)
+
+ return LTXVideoAttnProcessor(*args, **kwargs)
+
+
+class LTXVideoAttnProcessor:
r"""
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
- used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
+ Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0). This is used in the LTX
+ model. It applies a normalization layer and rotary embedding on the query and key vector.
"""
+ _attention_backend = None
+ _parallel_config = None
+
def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- "LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ if is_torch_version("<", "2.0"):
+ raise ValueError(
+ "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
)
def __call__(
self,
- attn: Attention,
+ attn: "LTXAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
@@ -78,14 +90,21 @@ def __call__(
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
- query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
)
- hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+ hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.to_out[0](hidden_states)
@@ -93,6 +112,70 @@ def __call__(
return hidden_states
+class LTXAttention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = LTXVideoAttnProcessor
+ _available_processors = [LTXVideoAttnProcessor]
+
+ def __init__(
+ self,
+ query_dim: int,
+ heads: int = 8,
+ kv_heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = True,
+ cross_attention_dim: Optional[int] = None,
+ out_bias: bool = True,
+ qk_norm: str = "rms_norm_across_heads",
+ processor=None,
+ ):
+ super().__init__()
+ if qk_norm != "rms_norm_across_heads":
+ raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")
+
+ self.head_dim = dim_head
+ self.inner_dim = dim_head * heads
+ self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+ self.query_dim = query_dim
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.use_bias = bias
+ self.dropout = dropout
+ self.out_dim = query_dim
+ self.heads = heads
+
+ norm_eps = 1e-5
+ norm_elementwise_affine = True
+ self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+ self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+ self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+ self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+ self.to_out = torch.nn.ModuleList([])
+ self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(torch.nn.Dropout(dropout))
+
+ if processor is None:
+ processor = self._default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
class LTXVideoRotaryPosEmbed(nn.Module):
def __init__(
self,
@@ -231,7 +314,7 @@ def __init__(
super().__init__()
self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
- self.attn1 = Attention(
+ self.attn1 = LTXAttention(
query_dim=dim,
heads=num_attention_heads,
kv_heads=num_attention_heads,
@@ -240,11 +323,10 @@ def __init__(
cross_attention_dim=None,
out_bias=attention_out_bias,
qk_norm=qk_norm,
- processor=LTXVideoAttentionProcessor2_0(),
)
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
- self.attn2 = Attention(
+ self.attn2 = LTXAttention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
@@ -253,7 +335,6 @@ def __init__(
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
- processor=LTXVideoAttentionProcessor2_0(),
)
self.ff = FeedForward(dim, activation_fn=activation_fn)
@@ -272,7 +353,9 @@ def forward(
norm_hidden_states = self.norm1(hidden_states)
num_ada_params = self.scale_shift_table.shape[0]
- ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
+ ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
+ batch_size, temb.size(1), num_ada_params, -1
+ )
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
@@ -299,7 +382,9 @@ def forward(
@maybe_allow_in_graph
-class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin):
+class LTXVideoTransformer3DModel(
+ ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin
+):
r"""
A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
@@ -328,6 +413,19 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["norm"]
+ _repeated_blocks = ["LTXVideoTransformerBlock"]
+ _cp_plan = {
+ "": {
+ "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "encoder_attention_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+ },
+ "rope": {
+ 0: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True),
+ 1: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True),
+ },
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
@register_to_config
def __init__(
@@ -481,7 +579,7 @@ def forward(
def apply_rotary_emb(x, freqs):
cos, sin = freqs
- x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D // 2]
+ x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py
index a873a6ec9444..77121edb9fc9 100644
--- a/src/diffusers/models/transformers/transformer_lumina2.py
+++ b/src/diffusers/models/transformers/transformer_lumina2.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py
index e6532f080d72..63911fe7c10d 100644
--- a/src/diffusers/models/transformers/transformer_mochi.py
+++ b/src/diffusers/models/transformers/transformer_mochi.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The Genmo team and The HuggingFace Team.
+# Copyright 2025 The Genmo team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/src/diffusers/models/transformers/transformer_omnigen.py b/src/diffusers/models/transformers/transformer_omnigen.py
index 8d5d1b3f8fea..6939cac0a3a7 100644
--- a/src/diffusers/models/transformers/transformer_omnigen.py
+++ b/src/diffusers/models/transformers/transformer_omnigen.py
@@ -1,4 +1,4 @@
-# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 OmniGen team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -283,7 +283,7 @@ def forward(
class OmniGenTransformer2DModel(ModelMixin, ConfigMixin):
"""
- The Transformer model introduced in OmniGen (https://arxiv.org/pdf/2409.11340).
+ The Transformer model introduced in OmniGen (https://huggingface.co/papers/2409.11340).
Parameters:
in_channels (`int`, defaults to `4`):
diff --git a/src/diffusers/models/transformers/transformer_ovis_image.py b/src/diffusers/models/transformers/transformer_ovis_image.py
new file mode 100644
index 000000000000..0a09aa720b3f
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_ovis_image.py
@@ -0,0 +1,581 @@
+# Copyright 2025 Alibaba Ovis-Image Team and The HuggingFace. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import is_torch_npu_available, logging
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
+from ..cache_utils import CacheMixin
+from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def _get_projections(attn: "OvisImageAttention", hidden_states, encoder_hidden_states=None):
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ encoder_query = encoder_key = encoder_value = None
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+ return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_fused_projections(attn: "OvisImageAttention", hidden_states, encoder_hidden_states=None):
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+
+ encoder_query = encoder_key = encoder_value = (None,)
+ if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
+ encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
+
+ return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_qkv_projections(attn: "OvisImageAttention", hidden_states, encoder_hidden_states=None):
+ if attn.fused_projections:
+ return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
+ return _get_projections(attn, hidden_states, encoder_hidden_states)
+
+
+class OvisImageAttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
+
+ def __call__(
+ self,
+ attn: "OvisImageAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+ attn, hidden_states, encoder_hidden_states
+ )
+
+ query = query.unflatten(-1, (attn.heads, -1))
+ key = key.unflatten(-1, (attn.heads, -1))
+ value = value.unflatten(-1, (attn.heads, -1))
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ if attn.added_kv_proj_dim is not None:
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+ encoder_query = attn.norm_added_q(encoder_query)
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([encoder_query, query], dim=1)
+ key = torch.cat([encoder_key, key], dim=1)
+ value = torch.cat([encoder_value, value], dim=1)
+
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+ )
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class OvisImageAttention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = OvisImageAttnProcessor
+ _available_processors = [
+ OvisImageAttnProcessor,
+ ]
+
+ def __init__(
+ self,
+ query_dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ added_proj_bias: Optional[bool] = True,
+ out_bias: bool = True,
+ eps: float = 1e-5,
+ out_dim: int = None,
+ context_pre_only: Optional[bool] = None,
+ pre_only: bool = False,
+ elementwise_affine: bool = True,
+ processor=None,
+ ):
+ super().__init__()
+
+ self.head_dim = dim_head
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.query_dim = query_dim
+ self.use_bias = bias
+ self.dropout = dropout
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.context_pre_only = context_pre_only
+ self.pre_only = pre_only
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.added_proj_bias = added_proj_bias
+
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+ if not self.pre_only:
+ self.to_out = torch.nn.ModuleList([])
+ self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(torch.nn.Dropout(dropout))
+
+ if added_kv_proj_dim is not None:
+ self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
+ self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
+ self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
+
+ if processor is None:
+ processor = self._default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
+ unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
+@maybe_allow_in_graph
+class OvisImageSingleTransformerBlock(nn.Module):
+ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
+ super().__init__()
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
+
+ self.norm = AdaLayerNormZeroSingle(dim)
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim * 2)
+ self.act_mlp = nn.SiLU()
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
+
+ self.attn = OvisImageAttention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ bias=True,
+ processor=OvisImageAttnProcessor(),
+ eps=1e-6,
+ pre_only=True,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ text_seq_len = encoder_hidden_states.shape[1]
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ residual = hidden_states
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+ mlp_hidden_states, mlp_hidden_gate = torch.split(
+ self.proj_mlp(norm_hidden_states), [self.mlp_hidden_dim, self.mlp_hidden_dim], dim=-1
+ )
+ mlp_hidden_states = self.act_mlp(mlp_hidden_gate) * mlp_hidden_states
+ joint_attention_kwargs = joint_attention_kwargs or {}
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ gate = gate.unsqueeze(1)
+ hidden_states = gate * self.proj_out(hidden_states)
+ hidden_states = residual + hidden_states
+ if hidden_states.dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
+ return encoder_hidden_states, hidden_states
+
+
+@maybe_allow_in_graph
+class OvisImageTransformerBlock(nn.Module):
+ def __init__(
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+ ):
+ super().__init__()
+
+ self.norm1 = AdaLayerNormZero(dim)
+ self.norm1_context = AdaLayerNormZero(dim)
+
+ self.attn = OvisImageAttention(
+ query_dim=dim,
+ added_kv_proj_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ context_pre_only=False,
+ bias=True,
+ processor=OvisImageAttnProcessor(),
+ eps=eps,
+ )
+
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="swiglu")
+
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="swiglu")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ encoder_hidden_states, emb=temb
+ )
+ joint_attention_kwargs = joint_attention_kwargs or {}
+
+ # Attention.
+ attention_outputs = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+
+ if len(attention_outputs) == 2:
+ attn_output, context_attn_output = attention_outputs
+ elif len(attention_outputs) == 3:
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
+
+ # Process attention outputs for the `hidden_states`.
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = hidden_states + attn_output
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = hidden_states + ff_output
+ if len(attention_outputs) == 3:
+ hidden_states = hidden_states + ip_attn_output
+
+ # Process attention outputs for the `encoder_hidden_states`.
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+ if encoder_hidden_states.dtype == torch.float16:
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+ return encoder_hidden_states, hidden_states
+
+
+class OvisImagePosEmbed(nn.Module):
+ def __init__(self, theta: int, axes_dim: List[int]):
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ n_axes = ids.shape[-1]
+ cos_out = []
+ sin_out = []
+ pos = ids.float()
+ is_mps = ids.device.type == "mps"
+ is_npu = ids.device.type == "npu"
+ freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+ for i in range(n_axes):
+ cos, sin = get_1d_rotary_pos_embed(
+ self.axes_dim[i],
+ pos[:, i],
+ theta=self.theta,
+ repeat_interleave_real=True,
+ use_real=True,
+ freqs_dtype=freqs_dtype,
+ )
+ cos_out.append(cos)
+ sin_out.append(sin)
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+ return freqs_cos, freqs_sin
+
+
+class OvisImageTransformer2DModel(
+ ModelMixin,
+ ConfigMixin,
+ PeftAdapterMixin,
+ FromOriginalModelMixin,
+ CacheMixin,
+):
+ """
+ The Transformer model introduced in Ovis-Image.
+
+ Reference: https://github.com/AIDC-AI/Ovis-Image
+
+ Args:
+ patch_size (`int`, defaults to `1`):
+ Patch size to turn the input data into small patches.
+ in_channels (`int`, defaults to `64`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `None`):
+ The number of channels in the output. If not specified, it defaults to `in_channels`.
+ num_layers (`int`, defaults to `6`):
+ The number of layers of dual stream DiT blocks to use.
+ num_single_layers (`int`, defaults to `27`):
+ The number of layers of single stream DiT blocks to use.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of dimensions to use for each attention head.
+ num_attention_heads (`int`, defaults to `24`):
+ The number of attention heads to use.
+ joint_attention_dim (`int`, defaults to `2048`):
+ The number of dimensions to use for the joint attention (embedding/channel dimension of
+ `encoder_hidden_states`).
+ axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
+ The dimensions to use for the rotary positional embeddings.
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["OvisImageTransformerBlock", "OvisImageSingleTransformerBlock"]
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+ _repeated_blocks = ["OvisImageTransformerBlock", "OvisImageSingleTransformerBlock"]
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 1,
+ in_channels: int = 64,
+ out_channels: Optional[int] = 64,
+ num_layers: int = 6,
+ num_single_layers: int = 27,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 24,
+ joint_attention_dim: int = 2048,
+ axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
+ ):
+ super().__init__()
+ self.out_channels = out_channels or in_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
+
+ self.pos_embed = OvisImagePosEmbed(theta=10000, axes_dim=axes_dims_rope)
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim)
+
+ self.context_embedder_norm = nn.RMSNorm(joint_attention_dim, eps=1e-6)
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ OvisImageTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ self.single_transformer_blocks = nn.ModuleList(
+ [
+ OvisImageSingleTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ )
+ for _ in range(num_single_layers)
+ ]
+ )
+
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ return_dict: bool = True,
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ """
+ The [`OvisImageTransformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ timestep (`torch.LongTensor`):
+ Used to indicate denoising step.
+ img_ids: (`torch.Tensor`):
+ The position ids for image tokens.
+ txt_ids (`torch.Tensor`):
+ The position ids for text tokens.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ hidden_states = self.x_embedder(hidden_states)
+
+ timestep = timestep.to(hidden_states.dtype) * 1000
+
+ timesteps_proj = self.time_proj(timestep)
+ temb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
+
+ encoder_hidden_states = self.context_embedder_norm(encoder_hidden_states)
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ if txt_ids.ndim == 3:
+ logger.warning(
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ txt_ids = txt_ids[0]
+ if img_ids.ndim == 3:
+ logger.warning(
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
+ )
+ img_ids = img_ids[0]
+
+ ids = torch.cat((txt_ids, img_ids), dim=0)
+ if is_torch_npu_available():
+ freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
+ image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
+ else:
+ image_rotary_emb = self.pos_embed(ids)
+
+ for index_block, block in enumerate(self.transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ for index_block, block in enumerate(self.single_transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = self.norm_out(hidden_states, temb)
+ output = self.proj_out(hidden_states)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_prx.py b/src/diffusers/models/transformers/transformer_prx.py
new file mode 100644
index 000000000000..a87c120fdcd7
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_prx.py
@@ -0,0 +1,801 @@
+# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import logging
+from ..attention import AttentionMixin, AttentionModuleMixin
+from ..attention_dispatch import dispatch_attention_fn
+from ..embeddings import get_timestep_embedding
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import RMSNorm
+
+
+logger = logging.get_logger(__name__)
+
+
+def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, device: torch.device) -> torch.Tensor:
+ r"""
+ Generates 2D patch coordinate indices for a batch of images.
+
+ Args:
+ batch_size (`int`):
+ Number of images in the batch.
+ height (`int`):
+ Height of the input images (in pixels).
+ width (`int`):
+ Width of the input images (in pixels).
+ patch_size (`int`):
+ Size of the square patches that the image is divided into.
+ device (`torch.device`):
+ The device on which to create the tensor.
+
+ Returns:
+ `torch.Tensor`:
+ Tensor of shape `(batch_size, num_patches, 2)` containing the (row, col) coordinates of each patch in the
+ image grid.
+ """
+
+ img_ids = torch.zeros(height // patch_size, width // patch_size, 2, device=device)
+ img_ids[..., 0] = torch.arange(height // patch_size, device=device)[:, None]
+ img_ids[..., 1] = torch.arange(width // patch_size, device=device)[None, :]
+ return img_ids.reshape((height // patch_size) * (width // patch_size), 2).unsqueeze(0).repeat(batch_size, 1, 1)
+
+
+def apply_rope(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+ r"""
+ Applies rotary positional embeddings (RoPE) to a query tensor.
+
+ Args:
+ xq (`torch.Tensor`):
+ Input tensor of shape `(..., dim)` representing the queries.
+ freqs_cis (`torch.Tensor`):
+ Precomputed rotary frequency components of shape `(..., dim/2, 2)` containing cosine and sine pairs.
+
+ Returns:
+ `torch.Tensor`:
+ Tensor of the same shape as `xq` with rotary embeddings applied.
+ """
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+ # Ensure freqs_cis is on the same device as queries to avoid device mismatches with offloading
+ freqs_cis = freqs_cis.to(device=xq.device, dtype=xq_.dtype)
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+ return xq_out.reshape(*xq.shape).type_as(xq)
+
+
+class PRXAttnProcessor2_0:
+ r"""
+ Processor for implementing PRX-style attention with multi-source tokens and RoPE. Supports multiple attention
+ backends (Flash Attention, Sage Attention, etc.) via dispatch_attention_fn.
+ """
+
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
+ raise ImportError("PRXAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: "PRXAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ Apply PRX attention using PRXAttention module.
+
+ Args:
+ attn: PRXAttention module containing projection layers
+ hidden_states: Image tokens [B, L_img, D]
+ encoder_hidden_states: Text tokens [B, L_txt, D]
+ attention_mask: Boolean mask for text tokens [B, L_txt]
+ image_rotary_emb: Rotary positional embeddings [B, 1, L_img, head_dim//2, 2, 2]
+ """
+
+ if encoder_hidden_states is None:
+ raise ValueError("PRXAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.")
+
+ # Project image tokens to Q, K, V
+ img_qkv = attn.img_qkv_proj(hidden_states)
+ B, L_img, _ = img_qkv.shape
+ img_qkv = img_qkv.reshape(B, L_img, 3, attn.heads, attn.head_dim)
+ img_qkv = img_qkv.permute(2, 0, 3, 1, 4) # [3, B, H, L_img, D]
+ img_q, img_k, img_v = img_qkv[0], img_qkv[1], img_qkv[2]
+
+ # Apply QK normalization to image tokens
+ img_q = attn.norm_q(img_q)
+ img_k = attn.norm_k(img_k)
+
+ # Project text tokens to K, V
+ txt_kv = attn.txt_kv_proj(encoder_hidden_states)
+ B, L_txt, _ = txt_kv.shape
+ txt_kv = txt_kv.reshape(B, L_txt, 2, attn.heads, attn.head_dim)
+ txt_kv = txt_kv.permute(2, 0, 3, 1, 4) # [2, B, H, L_txt, D]
+ txt_k, txt_v = txt_kv[0], txt_kv[1]
+
+ # Apply K normalization to text tokens
+ txt_k = attn.norm_added_k(txt_k)
+
+ # Apply RoPE to image queries and keys
+ if image_rotary_emb is not None:
+ img_q = apply_rope(img_q, image_rotary_emb)
+ img_k = apply_rope(img_k, image_rotary_emb)
+
+ # Concatenate text and image keys/values
+ k = torch.cat((txt_k, img_k), dim=2) # [B, H, L_txt + L_img, D]
+ v = torch.cat((txt_v, img_v), dim=2) # [B, H, L_txt + L_img, D]
+
+ # Build attention mask if provided
+ attn_mask_tensor = None
+ if attention_mask is not None:
+ bs, _, l_img, _ = img_q.shape
+ l_txt = txt_k.shape[2]
+
+ if attention_mask.dim() != 2:
+ raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")
+ if attention_mask.shape[-1] != l_txt:
+ raise ValueError(f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}")
+
+ device = img_q.device
+ ones_img = torch.ones((bs, l_img), dtype=torch.bool, device=device)
+ attention_mask = attention_mask.to(device=device, dtype=torch.bool)
+ joint_mask = torch.cat([attention_mask, ones_img], dim=-1)
+ attn_mask_tensor = joint_mask[:, None, None, :].expand(-1, attn.heads, l_img, -1)
+
+ # Apply attention using dispatch_attention_fn for backend support
+ # Reshape to match dispatch_attention_fn expectations: [B, L, H, D]
+ query = img_q.transpose(1, 2) # [B, L_img, H, D]
+ key = k.transpose(1, 2) # [B, L_txt + L_img, H, D]
+ value = v.transpose(1, 2) # [B, L_txt + L_img, H, D]
+
+ attn_output = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attn_mask_tensor,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+
+ # Reshape from [B, L_img, H, D] to [B, L_img, H*D]
+ batch_size, seq_len, num_heads, head_dim = attn_output.shape
+ attn_output = attn_output.reshape(batch_size, seq_len, num_heads * head_dim)
+
+ # Apply output projection
+ attn_output = attn.to_out[0](attn_output)
+ if len(attn.to_out) > 1:
+ attn_output = attn.to_out[1](attn_output) # dropout if present
+
+ return attn_output
+
+
+class PRXAttention(nn.Module, AttentionModuleMixin):
+ r"""
+ PRX-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
+ PRX's architecture.
+ """
+
+ _default_processor_cls = PRXAttnProcessor2_0
+ _available_processors = [PRXAttnProcessor2_0]
+
+ def __init__(
+ self,
+ query_dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ bias: bool = False,
+ out_bias: bool = False,
+ eps: float = 1e-6,
+ processor=None,
+ ):
+ super().__init__()
+
+ self.heads = heads
+ self.head_dim = dim_head
+ self.inner_dim = dim_head * heads
+ self.query_dim = query_dim
+
+ self.img_qkv_proj = nn.Linear(query_dim, query_dim * 3, bias=bias)
+
+ self.norm_q = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
+ self.norm_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
+
+ self.txt_kv_proj = nn.Linear(query_dim, query_dim * 2, bias=bias)
+ self.norm_added_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
+
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(self.inner_dim, query_dim, bias=out_bias))
+ self.to_out.append(nn.Dropout(0.0))
+
+ if processor is None:
+ processor = self._default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ **kwargs,
+ )
+
+
+# inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
+class PRXEmbedND(nn.Module):
+ r"""
+ N-dimensional rotary positional embedding.
+
+ This module creates rotary embeddings (RoPE) across multiple axes, where each axis can have its own embedding
+ dimension. The embeddings are combined and returned as a single tensor
+
+ Args:
+ dim (int):
+ Base embedding dimension (must be even).
+ theta (int):
+ Scaling factor that controls the frequency spectrum of the rotary embeddings.
+ axes_dim (list[int]):
+ List of embedding dimensions for each axis (each must be even).
+ """
+
+ def __init__(self, dim: int, theta: int, axes_dim: List[int]):
+ super().__init__()
+ self.dim = dim
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
+ assert dim % 2 == 0
+
+ is_mps = pos.device.type == "mps"
+ is_npu = pos.device.type == "npu"
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+
+ scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
+ omega = 1.0 / (theta**scale)
+ out = pos.unsqueeze(-1) * omega.unsqueeze(0)
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
+ # Native PyTorch equivalent of: Rearrange("b n d (i j) -> b n d i j", i=2, j=2)
+ # out shape: (b, n, d, 4) -> reshape to (b, n, d, 2, 2)
+ out = out.reshape(*out.shape[:-1], 2, 2)
+ return out.float()
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ n_axes = ids.shape[-1]
+ emb = torch.cat(
+ [self.rope(ids[:, :, i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+ dim=-3,
+ )
+ return emb.unsqueeze(1)
+
+
+class MLPEmbedder(nn.Module):
+ r"""
+ A simple 2-layer MLP used for embedding inputs.
+
+ Args:
+ in_dim (`int`):
+ Dimensionality of the input features.
+ hidden_dim (`int`):
+ Dimensionality of the hidden and output embedding space.
+
+ Returns:
+ `torch.Tensor`:
+ Tensor of shape `(..., hidden_dim)` containing the embedded representations.
+ """
+
+ def __init__(self, in_dim: int, hidden_dim: int):
+ super().__init__()
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
+ self.silu = nn.SiLU()
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.out_layer(self.silu(self.in_layer(x)))
+
+
+class Modulation(nn.Module):
+ r"""
+ Modulation network that generates scale, shift, and gating parameters.
+
+ Given an input vector, the module projects it through a linear layer to produce six chunks, which are grouped into
+ two tuples `(shift, scale, gate)`.
+
+ Args:
+ dim (`int`):
+ Dimensionality of the input vector. The output will have `6 * dim` features internally.
+
+ Returns:
+ ((`torch.Tensor`, `torch.Tensor`, `torch.Tensor`), (`torch.Tensor`, `torch.Tensor`, `torch.Tensor`)):
+ Two tuples `(shift, scale, gate)`.
+ """
+
+ def __init__(self, dim: int):
+ super().__init__()
+ self.lin = nn.Linear(dim, 6 * dim, bias=True)
+ nn.init.constant_(self.lin.weight, 0)
+ nn.init.constant_(self.lin.bias, 0)
+
+ def forward(
+ self, vec: torch.Tensor
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1)
+ return tuple(out[:3]), tuple(out[3:])
+
+
+class PRXBlock(nn.Module):
+ r"""
+ Multimodal transformer block with text–image cross-attention, modulation, and MLP.
+
+ Args:
+ hidden_size (`int`):
+ Dimension of the hidden representations.
+ num_heads (`int`):
+ Number of attention heads.
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
+ Expansion ratio for the hidden dimension inside the MLP.
+ qk_scale (`float`, *optional*):
+ Scale factor for queries and keys. If not provided, defaults to ``head_dim**-0.5``.
+
+ Attributes:
+ img_pre_norm (`nn.LayerNorm`):
+ Pre-normalization applied to image tokens before attention.
+ attention (`PRXAttention`):
+ Multi-head attention module with built-in QKV projections and normalizations for cross-attention between
+ image and text tokens.
+ post_attention_layernorm (`nn.LayerNorm`):
+ Normalization applied after attention.
+ gate_proj / up_proj / down_proj (`nn.Linear`):
+ Feedforward layers forming the gated MLP.
+ mlp_act (`nn.GELU`):
+ Nonlinear activation used in the MLP.
+ modulation (`Modulation`):
+ Produces scale/shift/gating parameters for modulated layers.
+
+ Methods:
+ The forward method performs cross-attention and the MLP with modulation.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qk_scale: Optional[float] = None,
+ ):
+ super().__init__()
+
+ self.hidden_dim = hidden_size
+ self.num_heads = num_heads
+ self.head_dim = hidden_size // num_heads
+ self.scale = qk_scale or self.head_dim**-0.5
+
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.hidden_size = hidden_size
+
+ # Pre-attention normalization for image tokens
+ self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+
+ # PRXAttention module with built-in projections and norms
+ self.attention = PRXAttention(
+ query_dim=hidden_size,
+ heads=num_heads,
+ dim_head=self.head_dim,
+ bias=False,
+ out_bias=False,
+ eps=1e-6,
+ processor=PRXAttnProcessor2_0(),
+ )
+
+ # mlp
+ self.post_attention_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.gate_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
+ self.up_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
+ self.down_proj = nn.Linear(self.mlp_hidden_dim, hidden_size, bias=False)
+ self.mlp_act = nn.GELU(approximate="tanh")
+
+ self.modulation = Modulation(hidden_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Dict[str, Any],
+ ) -> torch.Tensor:
+ r"""
+ Runs modulation-gated cross-attention and MLP, with residual connections.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ Image tokens of shape `(B, L_img, hidden_size)`.
+ encoder_hidden_states (`torch.Tensor`):
+ Text tokens of shape `(B, L_txt, hidden_size)`.
+ temb (`torch.Tensor`):
+ Conditioning vector used by `Modulation` to produce scale/shift/gates, shape `(B, hidden_size)` (or
+ broadcastable).
+ image_rotary_emb (`torch.Tensor`):
+ Rotary positional embeddings applied inside attention.
+ attention_mask (`torch.Tensor`, *optional*):
+ Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding.
+ **kwargs:
+ Additional keyword arguments for API compatibility.
+
+ Returns:
+ `torch.Tensor`:
+ Updated image tokens of shape `(B, L_img, hidden_size)`.
+ """
+
+ mod_attn, mod_mlp = self.modulation(temb)
+ attn_shift, attn_scale, attn_gate = mod_attn
+ mlp_shift, mlp_scale, mlp_gate = mod_mlp
+
+ hidden_states_mod = (1 + attn_scale) * self.img_pre_norm(hidden_states) + attn_shift
+
+ attn_out = self.attention(
+ hidden_states=hidden_states_mod,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = hidden_states + attn_gate * attn_out
+
+ x = (1 + mlp_scale) * self.post_attention_layernorm(hidden_states) + mlp_shift
+ hidden_states = hidden_states + mlp_gate * (self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x)))
+ return hidden_states
+
+
+class FinalLayer(nn.Module):
+ r"""
+ Final projection layer with adaptive LayerNorm modulation.
+
+ This layer applies a normalized and modulated transformation to input tokens and projects them into patch-level
+ outputs.
+
+ Args:
+ hidden_size (`int`):
+ Dimensionality of the input tokens.
+ patch_size (`int`):
+ Size of the square image patches.
+ out_channels (`int`):
+ Number of output channels per pixel (e.g. RGB = 3).
+
+ Forward Inputs:
+ x (`torch.Tensor`):
+ Input tokens of shape `(B, L, hidden_size)`, where `L` is the number of patches.
+ vec (`torch.Tensor`):
+ Conditioning vector of shape `(B, hidden_size)` used to generate shift and scale parameters for adaptive
+ LayerNorm.
+
+ Returns:
+ `torch.Tensor`:
+ Projected patch outputs of shape `(B, L, patch_size * patch_size * out_channels)`.
+ """
+
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+ def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+ x = self.linear(x)
+ return x
+
+
+def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor:
+ r"""
+ Flattens an image tensor into a sequence of non-overlapping patches.
+
+ Args:
+ img (`torch.Tensor`):
+ Input image tensor of shape `(B, C, H, W)`.
+ patch_size (`int`):
+ Size of each square patch. Must evenly divide both `H` and `W`.
+
+ Returns:
+ `torch.Tensor`:
+ Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W
+ // patch_size)` is the number of patches.
+ """
+ b, c, h, w = img.shape
+ p = patch_size
+
+ # Reshape to (B, C, H//p, p, W//p, p) separating grid and patch dimensions
+ img = img.reshape(b, c, h // p, p, w // p, p)
+
+ # Permute to (B, H//p, W//p, C, p, p) using einsum
+ # n=batch, c=channels, h=grid_height, p=patch_height, w=grid_width, q=patch_width
+ img = torch.einsum("nchpwq->nhwcpq", img)
+
+ # Flatten to (B, L, C * p * p)
+ img = img.reshape(b, -1, c * p * p)
+ return img
+
+
+def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor:
+ r"""
+ Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`).
+
+ Args:
+ seq (`torch.Tensor`):
+ Patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W //
+ patch_size)`.
+ patch_size (`int`):
+ Size of each square patch.
+ shape (`tuple` or `torch.Tensor`):
+ The original image spatial shape `(H, W)`. If a tensor is provided, the first two values are interpreted as
+ height and width.
+
+ Returns:
+ `torch.Tensor`:
+ Reconstructed image tensor of shape `(B, C, H, W)`.
+ """
+ if isinstance(shape, tuple):
+ h, w = shape[-2:]
+ elif isinstance(shape, torch.Tensor):
+ h, w = (int(shape[0]), int(shape[1]))
+ else:
+ raise NotImplementedError(f"shape type {type(shape)} not supported")
+
+ b, l, d = seq.shape
+ p = patch_size
+ c = d // (p * p)
+
+ # Reshape back to grid structure: (B, H//p, W//p, C, p, p)
+ seq = seq.reshape(b, h // p, w // p, c, p, p)
+
+ # Permute back to image layout: (B, C, H//p, p, W//p, p)
+ # n=batch, h=grid_height, w=grid_width, c=channels, p=patch_height, q=patch_width
+ seq = torch.einsum("nhwcpq->nchpwq", seq)
+
+ # Final reshape to (B, C, H, W)
+ seq = seq.reshape(b, c, h, w)
+ return seq
+
+
+class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
+ r"""
+ Transformer-based 2D model for text to image generation.
+
+ Args:
+ in_channels (`int`, *optional*, defaults to 16):
+ Number of input channels in the latent image.
+ patch_size (`int`, *optional*, defaults to 2):
+ Size of the square patches used to flatten the input image.
+ context_in_dim (`int`, *optional*, defaults to 2304):
+ Dimensionality of the text conditioning input.
+ hidden_size (`int`, *optional*, defaults to 1792):
+ Dimension of the hidden representation.
+ mlp_ratio (`float`, *optional*, defaults to 3.5):
+ Expansion ratio for the hidden dimension inside MLP blocks.
+ num_heads (`int`, *optional*, defaults to 28):
+ Number of attention heads.
+ depth (`int`, *optional*, defaults to 16):
+ Number of transformer blocks.
+ axes_dim (`list[int]`, *optional*):
+ List of dimensions for each positional embedding axis. Defaults to `[32, 32]`.
+ theta (`int`, *optional*, defaults to 10000):
+ Frequency scaling factor for rotary embeddings.
+ time_factor (`float`, *optional*, defaults to 1000.0):
+ Scaling factor applied in timestep embeddings.
+ time_max_period (`int`, *optional*, defaults to 10000):
+ Maximum frequency period for timestep embeddings.
+
+ Attributes:
+ pe_embedder (`EmbedND`):
+ Multi-axis rotary embedding generator for positional encodings.
+ img_in (`nn.Linear`):
+ Projection layer for image patch tokens.
+ time_in (`MLPEmbedder`):
+ Embedding layer for timestep embeddings.
+ txt_in (`nn.Linear`):
+ Projection layer for text conditioning.
+ blocks (`nn.ModuleList`):
+ Stack of transformer blocks (`PRXBlock`).
+ final_layer (`LastLayer`):
+ Projection layer mapping hidden tokens back to patch outputs.
+
+ Methods:
+ attn_processors:
+ Returns a dictionary of all attention processors in the model.
+ set_attn_processor(processor):
+ Replaces attention processors across all attention layers.
+ process_inputs(image_latent, txt):
+ Converts inputs into patch tokens, encodes text, and produces positional encodings.
+ compute_timestep_embedding(timestep, dtype):
+ Creates a timestep embedding of dimension 256, scaled and projected.
+ forward_transformers(image_latent, cross_attn_conditioning, timestep, time_embedding, attention_mask,
+ **block_kwargs):
+ Runs the sequence of transformer blocks over image and text tokens.
+ forward(image_latent, timestep, cross_attn_conditioning, micro_conditioning, cross_attn_mask=None,
+ attention_kwargs=None, return_dict=True):
+ Full forward pass from latent input to reconstructed output image.
+
+ Returns:
+ `Transformer2DModelOutput` if `return_dict=True` (default), otherwise a tuple containing:
+ - `sample` (`torch.Tensor`): Reconstructed image of shape `(B, C, H, W)`.
+ """
+
+ config_name = "config.json"
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 16,
+ patch_size: int = 2,
+ context_in_dim: int = 2304,
+ hidden_size: int = 1792,
+ mlp_ratio: float = 3.5,
+ num_heads: int = 28,
+ depth: int = 16,
+ axes_dim: list = None,
+ theta: int = 10000,
+ time_factor: float = 1000.0,
+ time_max_period: int = 10000,
+ ):
+ super().__init__()
+
+ if axes_dim is None:
+ axes_dim = [32, 32]
+
+ # Store parameters directly
+ self.in_channels = in_channels
+ self.patch_size = patch_size
+ self.out_channels = self.in_channels * self.patch_size**2
+
+ self.time_factor = time_factor
+ self.time_max_period = time_max_period
+
+ if hidden_size % num_heads != 0:
+ raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")
+
+ pe_dim = hidden_size // num_heads
+
+ if sum(axes_dim) != pe_dim:
+ raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
+
+ self.hidden_size = hidden_size
+ self.num_heads = num_heads
+ self.pe_embedder = PRXEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
+ self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+ self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
+
+ self.blocks = nn.ModuleList(
+ [
+ PRXBlock(
+ self.hidden_size,
+ self.num_heads,
+ mlp_ratio=mlp_ratio,
+ )
+ for i in range(depth)
+ ]
+ )
+
+ self.final_layer = FinalLayer(self.hidden_size, 1, self.out_channels)
+
+ self.gradient_checkpointing = False
+
+ def _compute_timestep_embedding(self, timestep: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ return self.time_in(
+ get_timestep_embedding(
+ timesteps=timestep,
+ embedding_dim=256,
+ max_period=self.time_max_period,
+ scale=self.time_factor,
+ flip_sin_to_cos=True, # Match original cos, sin order
+ downscale_freq_shift=0.0,
+ ).to(dtype)
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
+ r"""
+ Forward pass of the PRXTransformer2DModel.
+
+ The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of
+ transformer blocks modulated by the timestep. The output is reconstructed into the latent image space.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ Input latent image tensor of shape `(B, C, H, W)`.
+ timestep (`torch.Tensor`):
+ Timestep tensor of shape `(B,)` or `(1,)`, used for temporal conditioning.
+ encoder_hidden_states (`torch.Tensor`):
+ Text conditioning tensor of shape `(B, L_txt, context_in_dim)`.
+ attention_mask (`torch.Tensor`, *optional*):
+ Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence.
+ attention_kwargs (`dict`, *optional*):
+ Additional arguments passed to attention layers.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a `Transformer2DModelOutput` or a tuple.
+
+ Returns:
+ `Transformer2DModelOutput` if `return_dict=True`, otherwise a tuple:
+
+ - `sample` (`torch.Tensor`): Output latent image of shape `(B, C, H, W)`.
+ """
+ # Process text conditioning
+ txt = self.txt_in(encoder_hidden_states)
+
+ # Convert image to sequence and embed
+ img = img2seq(hidden_states, self.patch_size)
+ img = self.img_in(img)
+
+ # Generate positional embeddings
+ bs, _, h, w = hidden_states.shape
+ img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=hidden_states.device)
+ pe = self.pe_embedder(img_ids)
+
+ # Compute time embedding
+ vec = self._compute_timestep_embedding(timestep, dtype=img.dtype)
+
+ # Apply transformer blocks
+ for block in self.blocks:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ img = self._gradient_checkpointing_func(
+ block.__call__,
+ img,
+ txt,
+ vec,
+ pe,
+ attention_mask,
+ )
+ else:
+ img = block(
+ hidden_states=img,
+ encoder_hidden_states=txt,
+ temb=vec,
+ image_rotary_emb=pe,
+ attention_mask=attention_mask,
+ )
+
+ # Final layer and convert back to image
+ img = self.final_layer(img, vec)
+ output = seq2img(img, self.patch_size, hidden_states.shape)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
new file mode 100644
index 000000000000..c0fa031b9faf
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -0,0 +1,673 @@
+# Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import math
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
+from ..attention import AttentionMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
+from ..attention_processor import Attention
+from ..cache_utils import CacheMixin
+from ..embeddings import TimestepEmbedding, Timesteps
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous, RMSNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def get_timestep_embedding(
+ timesteps: torch.Tensor,
+ embedding_dim: int,
+ flip_sin_to_cos: bool = False,
+ downscale_freq_shift: float = 1,
+ scale: float = 1,
+ max_period: int = 10000,
+) -> torch.Tensor:
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
+
+ Args
+ timesteps (torch.Tensor):
+ a 1-D Tensor of N indices, one per batch element. These may be fractional.
+ embedding_dim (int):
+ the dimension of the output.
+ flip_sin_to_cos (bool):
+ Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
+ downscale_freq_shift (float):
+ Controls the delta between frequencies between dimensions
+ scale (float):
+ Scaling factor applied to the embeddings.
+ max_period (int):
+ Controls the maximum frequency of the embeddings
+ Returns
+ torch.Tensor: an [N x dim] Tensor of positional embeddings.
+ """
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+ half_dim = embedding_dim // 2
+ exponent = -math.log(max_period) * torch.arange(
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
+ )
+ exponent = exponent / (half_dim - downscale_freq_shift)
+
+ emb = torch.exp(exponent).to(timesteps.dtype)
+ emb = timesteps[:, None].float() * emb[None, :]
+
+ # scale embeddings
+ emb = scale * emb
+
+ # concat sine and cosine embeddings
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+ # flip sine and cosine embeddings
+ if flip_sin_to_cos:
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+ # zero pad
+ if embedding_dim % 2 == 1:
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+def apply_rotary_emb_qwen(
+ x: torch.Tensor,
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+ use_real: bool = True,
+ use_real_unbind_dim: int = -1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
+ tensors contain rotary embeddings and are returned as real tensors.
+
+ Args:
+ x (`torch.Tensor`):
+ Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+ """
+ if use_real:
+ cos, sin = freqs_cis # [S, D]
+ cos = cos[None, None]
+ sin = sin[None, None]
+ cos, sin = cos.to(x.device), sin.to(x.device)
+
+ if use_real_unbind_dim == -1:
+ # Used for flux, cogvideox, hunyuan-dit
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+ elif use_real_unbind_dim == -2:
+ # Used for Stable Audio, OmniGen, CogView4 and Cosmos
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
+ else:
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
+
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+
+ return out
+ else:
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+ freqs_cis = freqs_cis.unsqueeze(1)
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
+
+ return x_out.type_as(x)
+
+
+class QwenTimestepProjEmbeddings(nn.Module):
+ def __init__(self, embedding_dim):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ def forward(self, timestep, hidden_states):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D)
+
+ conditioning = timesteps_emb
+
+ return conditioning
+
+
+class QwenEmbedRope(nn.Module):
+ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+ pos_index = torch.arange(4096)
+ neg_index = torch.arange(4096).flip(0) * -1 - 1
+ self.pos_freqs = torch.cat(
+ [
+ self.rope_params(pos_index, self.axes_dim[0], self.theta),
+ self.rope_params(pos_index, self.axes_dim[1], self.theta),
+ self.rope_params(pos_index, self.axes_dim[2], self.theta),
+ ],
+ dim=1,
+ )
+ self.neg_freqs = torch.cat(
+ [
+ self.rope_params(neg_index, self.axes_dim[0], self.theta),
+ self.rope_params(neg_index, self.axes_dim[1], self.theta),
+ self.rope_params(neg_index, self.axes_dim[2], self.theta),
+ ],
+ dim=1,
+ )
+
+ # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
+ self.scale_rope = scale_rope
+
+ def rope_params(self, index, dim, theta=10000):
+ """
+ Args:
+ index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
+ """
+ assert dim % 2 == 0
+ freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
+ return freqs
+
+ def forward(
+ self,
+ video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]],
+ txt_seq_lens: List[int],
+ device: torch.device,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):
+ A list of 3 integers [frame, height, width] representing the shape of the video.
+ txt_seq_lens (`List[int]`):
+ A list of integers of length batch_size representing the length of each text prompt.
+ device: (`torch.device`):
+ The device on which to perform the RoPE computation.
+ """
+ if self.pos_freqs.device != device:
+ self.pos_freqs = self.pos_freqs.to(device)
+ self.neg_freqs = self.neg_freqs.to(device)
+
+ if isinstance(video_fhw, list):
+ video_fhw = video_fhw[0]
+ if not isinstance(video_fhw, list):
+ video_fhw = [video_fhw]
+
+ vid_freqs = []
+ max_vid_index = 0
+ for idx, fhw in enumerate(video_fhw):
+ frame, height, width = fhw
+ # RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs
+ video_freq = self._compute_video_freqs(frame, height, width, idx)
+ video_freq = video_freq.to(device)
+ vid_freqs.append(video_freq)
+
+ if self.scale_rope:
+ max_vid_index = max(height // 2, width // 2, max_vid_index)
+ else:
+ max_vid_index = max(height, width, max_vid_index)
+
+ max_len = max(txt_seq_lens)
+ txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
+ vid_freqs = torch.cat(vid_freqs, dim=0)
+
+ return vid_freqs, txt_freqs
+
+ @functools.lru_cache(maxsize=128)
+ def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
+ seq_lens = frame * height * width
+ freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+ freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+
+ freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+ if self.scale_rope:
+ freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
+ freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
+ freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
+ freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
+ else:
+ freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+ freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+
+ freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
+ return freqs.clone().contiguous()
+
+
+class QwenDoubleStreamAttnProcessor2_0:
+ """
+ Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
+ implements joint attention computation where text and image streams are processed together.
+ """
+
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor, # Image stream
+ encoder_hidden_states: torch.FloatTensor = None, # Text stream
+ encoder_hidden_states_mask: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ if encoder_hidden_states is None:
+ raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")
+
+ seq_txt = encoder_hidden_states.shape[1]
+
+ # Compute QKV for image stream (sample projections)
+ img_query = attn.to_q(hidden_states)
+ img_key = attn.to_k(hidden_states)
+ img_value = attn.to_v(hidden_states)
+
+ # Compute QKV for text stream (context projections)
+ txt_query = attn.add_q_proj(encoder_hidden_states)
+ txt_key = attn.add_k_proj(encoder_hidden_states)
+ txt_value = attn.add_v_proj(encoder_hidden_states)
+
+ # Reshape for multi-head attention
+ img_query = img_query.unflatten(-1, (attn.heads, -1))
+ img_key = img_key.unflatten(-1, (attn.heads, -1))
+ img_value = img_value.unflatten(-1, (attn.heads, -1))
+
+ txt_query = txt_query.unflatten(-1, (attn.heads, -1))
+ txt_key = txt_key.unflatten(-1, (attn.heads, -1))
+ txt_value = txt_value.unflatten(-1, (attn.heads, -1))
+
+ # Apply QK normalization
+ if attn.norm_q is not None:
+ img_query = attn.norm_q(img_query)
+ if attn.norm_k is not None:
+ img_key = attn.norm_k(img_key)
+ if attn.norm_added_q is not None:
+ txt_query = attn.norm_added_q(txt_query)
+ if attn.norm_added_k is not None:
+ txt_key = attn.norm_added_k(txt_key)
+
+ # Apply RoPE
+ if image_rotary_emb is not None:
+ img_freqs, txt_freqs = image_rotary_emb
+ img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
+ img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
+ txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
+ txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
+
+ # Concatenate for joint attention
+ # Order: [text, image]
+ joint_query = torch.cat([txt_query, img_query], dim=1)
+ joint_key = torch.cat([txt_key, img_key], dim=1)
+ joint_value = torch.cat([txt_value, img_value], dim=1)
+
+ # Compute joint attention
+ joint_hidden_states = dispatch_attention_fn(
+ joint_query,
+ joint_key,
+ joint_value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+
+ # Reshape back
+ joint_hidden_states = joint_hidden_states.flatten(2, 3)
+ joint_hidden_states = joint_hidden_states.to(joint_query.dtype)
+
+ # Split attention outputs back
+ txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part
+ img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part
+
+ # Apply output projections
+ img_attn_output = attn.to_out[0](img_attn_output)
+ if len(attn.to_out) > 1:
+ img_attn_output = attn.to_out[1](img_attn_output) # dropout
+
+ txt_attn_output = attn.to_add_out(txt_attn_output)
+
+ return img_attn_output, txt_attn_output
+
+
+@maybe_allow_in_graph
+class QwenImageTransformerBlock(nn.Module):
+ def __init__(
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+ ):
+ super().__init__()
+
+ self.dim = dim
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+
+ # Image processing modules
+ self.img_mod = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2
+ )
+ self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+ self.attn = Attention(
+ query_dim=dim,
+ cross_attention_dim=None, # Enable cross attention for joint computation
+ added_kv_proj_dim=dim, # Enable added KV projections for text stream
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ context_pre_only=False,
+ bias=True,
+ processor=QwenDoubleStreamAttnProcessor2_0(),
+ qk_norm=qk_norm,
+ eps=eps,
+ )
+ self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+ self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ # Text processing modules
+ self.txt_mod = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2
+ )
+ self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+ # Text doesn't need separate attention - it's handled by img_attn joint computation
+ self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
+ self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ def _modulate(self, x, mod_params):
+ """Apply modulation to input tensor"""
+ shift, scale, gate = mod_params.chunk(3, dim=-1)
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_mask: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Get modulation parameters for both streams
+ img_mod_params = self.img_mod(temb) # [B, 6*dim]
+ txt_mod_params = self.txt_mod(temb) # [B, 6*dim]
+
+ # Split modulation parameters for norm1 and norm2
+ img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
+ txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
+
+ # Process image stream - norm1 + modulation
+ img_normed = self.img_norm1(hidden_states)
+ img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
+
+ # Process text stream - norm1 + modulation
+ txt_normed = self.txt_norm1(encoder_hidden_states)
+ txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
+
+ # Use QwenAttnProcessor2_0 for joint attention computation
+ # This directly implements the DoubleStreamLayerMegatron logic:
+ # 1. Computes QKV for both streams
+ # 2. Applies QK normalization and RoPE
+ # 3. Concatenates and runs joint attention
+ # 4. Splits results back to separate streams
+ joint_attention_kwargs = joint_attention_kwargs or {}
+ attn_output = self.attn(
+ hidden_states=img_modulated, # Image stream (will be processed as "sample")
+ encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context")
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+
+ # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
+ img_attn_output, txt_attn_output = attn_output
+
+ # Apply attention gates and add residual (like in Megatron)
+ hidden_states = hidden_states + img_gate1 * img_attn_output
+ encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
+
+ # Process image stream - norm2 + MLP
+ img_normed2 = self.img_norm2(hidden_states)
+ img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
+ img_mlp_output = self.img_mlp(img_modulated2)
+ hidden_states = hidden_states + img_gate2 * img_mlp_output
+
+ # Process text stream - norm2 + MLP
+ txt_normed2 = self.txt_norm2(encoder_hidden_states)
+ txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
+ txt_mlp_output = self.txt_mlp(txt_modulated2)
+ encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
+
+ # Clip to prevent overflow for fp16
+ if encoder_hidden_states.dtype == torch.float16:
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+ if hidden_states.dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ return encoder_hidden_states, hidden_states
+
+
+class QwenImageTransformer2DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
+ """
+ The Transformer model introduced in Qwen.
+
+ Args:
+ patch_size (`int`, defaults to `2`):
+ Patch size to turn the input data into small patches.
+ in_channels (`int`, defaults to `64`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `None`):
+ The number of channels in the output. If not specified, it defaults to `in_channels`.
+ num_layers (`int`, defaults to `60`):
+ The number of layers of dual stream DiT blocks to use.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of dimensions to use for each attention head.
+ num_attention_heads (`int`, defaults to `24`):
+ The number of attention heads to use.
+ joint_attention_dim (`int`, defaults to `3584`):
+ The number of dimensions to use for the joint attention (embedding/channel dimension of
+ `encoder_hidden_states`).
+ guidance_embeds (`bool`, defaults to `False`):
+ Whether to use guidance embeddings for guidance-distilled variant of the model.
+ axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
+ The dimensions to use for the rotary positional embeddings.
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["QwenImageTransformerBlock"]
+ _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+ _repeated_blocks = ["QwenImageTransformerBlock"]
+ _cp_plan = {
+ "": {
+ "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+ },
+ "pos_embed": {
+ 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+ 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+ },
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_channels: int = 64,
+ out_channels: Optional[int] = 16,
+ num_layers: int = 60,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 24,
+ joint_attention_dim: int = 3584,
+ guidance_embeds: bool = False, # TODO: this should probably be removed
+ axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
+ ):
+ super().__init__()
+ self.out_channels = out_channels or in_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
+
+ self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
+
+ self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
+
+ self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
+
+ self.img_in = nn.Linear(in_channels, self.inner_dim)
+ self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ QwenImageTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ encoder_hidden_states_mask: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_shapes: Optional[List[Tuple[int, int, int]]] = None,
+ txt_seq_lens: Optional[List[int]] = None,
+ guidance: torch.Tensor = None, # TODO: this should probably be removed
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_block_samples=None,
+ return_dict: bool = True,
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ """
+ The [`QwenTransformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
+ Mask of the input conditions.
+ timestep ( `torch.LongTensor`):
+ Used to indicate denoising step.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ hidden_states = self.img_in(hidden_states)
+
+ timestep = timestep.to(hidden_states.dtype)
+ encoder_hidden_states = self.txt_norm(encoder_hidden_states)
+ encoder_hidden_states = self.txt_in(encoder_hidden_states)
+
+ if guidance is not None:
+ guidance = guidance.to(hidden_states.dtype) * 1000
+
+ temb = (
+ self.time_text_embed(timestep, hidden_states)
+ if guidance is None
+ else self.time_text_embed(timestep, guidance, hidden_states)
+ )
+
+ image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+
+ for index_block, block in enumerate(self.transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ encoder_hidden_states_mask,
+ temb,
+ image_rotary_emb,
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ joint_attention_kwargs=attention_kwargs,
+ )
+
+ # controlnet residual
+ if controlnet_block_samples is not None:
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+
+ # Use only the image part (hidden_states) from the dual-stream blocks
+ hidden_states = self.norm_out(hidden_states, temb)
+ output = self.proj_out(hidden_states)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py
new file mode 100644
index 000000000000..a4f90342631a
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_sana_video.py
@@ -0,0 +1,705 @@
+# Copyright 2025 The HuggingFace Team and SANA-Video Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import AttentionMixin
+from ..attention_dispatch import dispatch_attention_fn
+from ..attention_processor import Attention
+from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormSingle, RMSNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class GLUMBTempConv(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ expand_ratio: float = 4,
+ norm_type: Optional[str] = None,
+ residual_connection: bool = True,
+ ) -> None:
+ super().__init__()
+
+ hidden_channels = int(expand_ratio * in_channels)
+ self.norm_type = norm_type
+ self.residual_connection = residual_connection
+
+ self.nonlinearity = nn.SiLU()
+ self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
+ self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
+ self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
+
+ self.norm = None
+ if norm_type == "rms_norm":
+ self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
+
+ self.conv_temp = nn.Conv2d(
+ out_channels, out_channels, kernel_size=(3, 1), stride=1, padding=(1, 0), bias=False
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if self.residual_connection:
+ residual = hidden_states
+ batch_size, num_frames, height, width, num_channels = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size * num_frames, height, width, num_channels).permute(0, 3, 1, 2)
+
+ hidden_states = self.conv_inverted(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.conv_depth(hidden_states)
+ hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
+ hidden_states = hidden_states * self.nonlinearity(gate)
+
+ hidden_states = self.conv_point(hidden_states)
+
+ # Temporal aggregation
+ hidden_states_temporal = hidden_states.view(batch_size, num_frames, num_channels, height * width).permute(
+ 0, 2, 1, 3
+ )
+ hidden_states = hidden_states_temporal + self.conv_temp(hidden_states_temporal)
+ hidden_states = hidden_states.permute(0, 2, 3, 1).view(batch_size, num_frames, height, width, num_channels)
+
+ if self.norm_type == "rms_norm":
+ # move channel to the last dimension so we apply RMSnorm across channel dimension
+ hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
+
+ if self.residual_connection:
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class SanaLinearAttnProcessor3_0:
+ r"""
+ Processor for implementing scaled dot-product linear attention.
+ """
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ original_dtype = hidden_states.dtype
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+ # B,N,H,C
+
+ query = F.relu(query)
+ key = F.relu(key)
+
+ if rotary_emb is not None:
+
+ def apply_rotary_emb(
+ hidden_states: torch.Tensor,
+ freqs_cos: torch.Tensor,
+ freqs_sin: torch.Tensor,
+ ):
+ x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+ cos = freqs_cos[..., 0::2]
+ sin = freqs_sin[..., 1::2]
+ out = torch.empty_like(hidden_states)
+ out[..., 0::2] = x1 * cos - x2 * sin
+ out[..., 1::2] = x1 * sin + x2 * cos
+ return out.type_as(hidden_states)
+
+ query_rotate = apply_rotary_emb(query, *rotary_emb)
+ key_rotate = apply_rotary_emb(key, *rotary_emb)
+
+ # B,H,C,N
+ query = query.permute(0, 2, 3, 1)
+ key = key.permute(0, 2, 3, 1)
+ query_rotate = query_rotate.permute(0, 2, 3, 1)
+ key_rotate = key_rotate.permute(0, 2, 3, 1)
+ value = value.permute(0, 2, 3, 1)
+
+ query_rotate, key_rotate, value = query_rotate.float(), key_rotate.float(), value.float()
+
+ z = 1 / (key.sum(dim=-1, keepdim=True).transpose(-2, -1) @ query + 1e-15)
+
+ scores = torch.matmul(value, key_rotate.transpose(-1, -2))
+ hidden_states = torch.matmul(scores, query_rotate)
+
+ hidden_states = hidden_states * z
+ # B,H,C,N
+ hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
+ hidden_states = hidden_states.to(original_dtype)
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+
+class WanRotaryPosEmbed(nn.Module):
+ def __init__(
+ self,
+ attention_head_dim: int,
+ patch_size: Tuple[int, int, int],
+ max_seq_len: int,
+ theta: float = 10000.0,
+ ):
+ super().__init__()
+
+ self.attention_head_dim = attention_head_dim
+ self.patch_size = patch_size
+ self.max_seq_len = max_seq_len
+
+ h_dim = w_dim = 2 * (attention_head_dim // 6)
+ t_dim = attention_head_dim - h_dim - w_dim
+
+ self.t_dim = t_dim
+ self.h_dim = h_dim
+ self.w_dim = w_dim
+
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+ freqs_cos = []
+ freqs_sin = []
+
+ for dim in [t_dim, h_dim, w_dim]:
+ freq_cos, freq_sin = get_1d_rotary_pos_embed(
+ dim,
+ max_seq_len,
+ theta,
+ use_real=True,
+ repeat_interleave_real=True,
+ freqs_dtype=freqs_dtype,
+ )
+ freqs_cos.append(freq_cos)
+ freqs_sin.append(freq_sin)
+
+ self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+ self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.patch_size
+ ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
+
+ split_sizes = [self.t_dim, self.h_dim, self.w_dim]
+
+ freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+ freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+ freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+ freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+
+ return freqs_cos, freqs_sin
+
+
+class SanaModulatedNorm(nn.Module):
+ def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
+ super().__init__()
+ self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps)
+
+ def forward(
+ self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
+ ) -> torch.Tensor:
+ hidden_states = self.norm(hidden_states)
+ shift, scale = (scale_shift_table[None, None] + temb[:, :, None].to(scale_shift_table.device)).unbind(dim=2)
+ hidden_states = hidden_states * (1 + scale) + shift
+ return hidden_states
+
+
+class SanaCombinedTimestepGuidanceEmbeddings(nn.Module):
+ def __init__(self, embedding_dim):
+ super().__init__()
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.guidance_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
+
+ def forward(self, timestep: torch.Tensor, guidance: torch.Tensor = None, hidden_dtype: torch.dtype = None):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
+
+ guidance_proj = self.guidance_condition_proj(guidance)
+ guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=hidden_dtype))
+ conditioning = timesteps_emb + guidance_emb
+
+ return self.linear(self.silu(conditioning)), conditioning
+
+
+class SanaAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, attn.heads, head_dim)
+ value = value.view(batch_size, -1, attn.heads, head_dim)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class SanaVideoTransformerBlock(nn.Module):
+ r"""
+ Transformer block introduced in [Sana-Video](https://huggingface.co/papers/2509.24695).
+ """
+
+ def __init__(
+ self,
+ dim: int = 2240,
+ num_attention_heads: int = 20,
+ attention_head_dim: int = 112,
+ dropout: float = 0.0,
+ num_cross_attention_heads: Optional[int] = 20,
+ cross_attention_head_dim: Optional[int] = 112,
+ cross_attention_dim: Optional[int] = 2240,
+ attention_bias: bool = True,
+ norm_elementwise_affine: bool = False,
+ norm_eps: float = 1e-6,
+ attention_out_bias: bool = True,
+ mlp_ratio: float = 3.0,
+ qk_norm: Optional[str] = "rms_norm_across_heads",
+ rope_max_seq_len: int = 1024,
+ ) -> None:
+ super().__init__()
+
+ # 1. Self Attention
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps)
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ kv_heads=num_attention_heads if qk_norm is not None else None,
+ qk_norm=qk_norm,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=None,
+ processor=SanaLinearAttnProcessor3_0(),
+ )
+
+ # 2. Cross Attention
+ if cross_attention_dim is not None:
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+ self.attn2 = Attention(
+ query_dim=dim,
+ qk_norm=qk_norm,
+ kv_heads=num_cross_attention_heads if qk_norm is not None else None,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_cross_attention_heads,
+ dim_head=cross_attention_head_dim,
+ dropout=dropout,
+ bias=True,
+ out_bias=attention_out_bias,
+ processor=SanaAttnProcessor2_0(),
+ )
+
+ # 3. Feed-forward
+ self.ff = GLUMBTempConv(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)
+
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ frames: int = None,
+ height: int = None,
+ width: int = None,
+ rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size = hidden_states.shape[0]
+
+ # 1. Modulation
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None, None] + timestep.reshape(batch_size, timestep.shape[1], 6, -1)
+ ).unbind(dim=2)
+
+ # 2. Self Attention
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+ norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)
+
+ attn_output = self.attn1(norm_hidden_states, rotary_emb=rotary_emb)
+ hidden_states = hidden_states + gate_msa * attn_output
+
+ # 3. Cross Attention
+ if self.attn2 is not None:
+ attn_output = self.attn2(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 4. Feed-forward
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+ norm_hidden_states = norm_hidden_states.unflatten(1, (frames, height, width))
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = ff_output.flatten(1, 3)
+ hidden_states = hidden_states + gate_mlp * ff_output
+
+ return hidden_states
+
+
+class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, AttentionMixin):
+ r"""
+ A 3D Transformer model introduced in [Sana-Video](https://huggingface.co/papers/2509.24695) family of models.
+
+ Args:
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `16`):
+ The number of channels in the output.
+ num_attention_heads (`int`, defaults to `20`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `112`):
+ The number of channels in each head.
+ num_layers (`int`, defaults to `20`):
+ The number of layers of Transformer blocks to use.
+ num_cross_attention_heads (`int`, *optional*, defaults to `20`):
+ The number of heads to use for cross-attention.
+ cross_attention_head_dim (`int`, *optional*, defaults to `112`):
+ The number of channels in each head for cross-attention.
+ cross_attention_dim (`int`, *optional*, defaults to `2240`):
+ The number of channels in the cross-attention output.
+ caption_channels (`int`, defaults to `2304`):
+ The number of channels in the caption embeddings.
+ mlp_ratio (`float`, defaults to `2.5`):
+ The expansion ratio to use in the GLUMBConv layer.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability.
+ attention_bias (`bool`, defaults to `False`):
+ Whether to use bias in the attention layer.
+ sample_size (`int`, defaults to `32`):
+ The base size of the input latent.
+ patch_size (`int`, defaults to `1`):
+ The size of the patches to use in the patch embedding layer.
+ norm_elementwise_affine (`bool`, defaults to `False`):
+ Whether to use elementwise affinity in the normalization layer.
+ norm_eps (`float`, defaults to `1e-6`):
+ The epsilon value for the normalization layer.
+ qk_norm (`str`, *optional*, defaults to `None`):
+ The normalization to use for the query and key.
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["SanaVideoTransformerBlock", "SanaModulatedNorm"]
+ _skip_layerwise_casting_patterns = ["patch_embedding", "norm"]
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 16,
+ out_channels: Optional[int] = 16,
+ num_attention_heads: int = 20,
+ attention_head_dim: int = 112,
+ num_layers: int = 20,
+ num_cross_attention_heads: Optional[int] = 20,
+ cross_attention_head_dim: Optional[int] = 112,
+ cross_attention_dim: Optional[int] = 2240,
+ caption_channels: int = 2304,
+ mlp_ratio: float = 2.5,
+ dropout: float = 0.0,
+ attention_bias: bool = False,
+ sample_size: int = 30,
+ patch_size: Tuple[int, int, int] = (1, 2, 2),
+ norm_elementwise_affine: bool = False,
+ norm_eps: float = 1e-6,
+ interpolation_scale: Optional[int] = None,
+ guidance_embeds: bool = False,
+ guidance_embeds_scale: float = 0.1,
+ qk_norm: Optional[str] = "rms_norm_across_heads",
+ rope_max_seq_len: int = 1024,
+ ) -> None:
+ super().__init__()
+
+ out_channels = out_channels or in_channels
+ inner_dim = num_attention_heads * attention_head_dim
+
+ # 1. Patch & position embedding
+ self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
+ self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+
+ # 2. Additional condition embeddings
+ if guidance_embeds:
+ self.time_embed = SanaCombinedTimestepGuidanceEmbeddings(inner_dim)
+ else:
+ self.time_embed = AdaLayerNormSingle(inner_dim)
+
+ self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
+ self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
+
+ # 3. Transformer blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ SanaVideoTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ num_cross_attention_heads=num_cross_attention_heads,
+ cross_attention_head_dim=cross_attention_head_dim,
+ cross_attention_dim=cross_attention_dim,
+ attention_bias=attention_bias,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ mlp_ratio=mlp_ratio,
+ qk_norm=qk_norm,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Output blocks
+ self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
+ self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out = nn.Linear(inner_dim, math.prod(patch_size) * out_channels)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ timestep: torch.Tensor,
+ guidance: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
+ return_dict: bool = True,
+ ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None and attention_mask.ndim == 2:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 1. Input
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.config.patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+
+ rotary_emb = self.rope(hidden_states)
+
+ hidden_states = self.patch_embedding(hidden_states)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+ if guidance is not None:
+ timestep, embedded_timestep = self.time_embed(
+ timestep.flatten(), guidance=guidance, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ timestep, embedded_timestep = self.time_embed(
+ timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype
+ )
+
+ timestep = timestep.view(batch_size, -1, timestep.size(-1))
+ embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
+
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+ encoder_hidden_states = self.caption_norm(encoder_hidden_states)
+
+ # 2. Transformer blocks
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for index_block, block in enumerate(self.transformer_blocks):
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ timestep,
+ post_patch_num_frames,
+ post_patch_height,
+ post_patch_width,
+ rotary_emb,
+ )
+ if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
+ hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
+
+ else:
+ for index_block, block in enumerate(self.transformer_blocks):
+ hidden_states = block(
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ timestep,
+ post_patch_num_frames,
+ post_patch_height,
+ post_patch_width,
+ rotary_emb,
+ )
+ if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
+ hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
+
+ # 3. Normalization
+ hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
+
+ hidden_states = self.proj_out(hidden_states)
+
+ # 5. Unpatchify
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+ )
+ hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+ output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index e41fad220de6..05391e047b7a 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
+# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,19 +18,18 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
-from ...models.attention import FeedForward, JointTransformerBlock
-from ...models.attention_processor import (
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import AttentionMixin, FeedForward, JointTransformerBlock
+from ..attention_processor import (
Attention,
- AttentionProcessor,
FusedJointAttnProcessor2_0,
JointAttnProcessor2_0,
)
-from ...models.modeling_utils import ModelMixin
-from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ...utils.torch_utils import maybe_allow_in_graph
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -78,7 +77,7 @@ def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
class SD3Transformer2DModel(
- ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
+ ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
):
"""
The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
@@ -214,77 +213,13 @@ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int
for module in self.children():
fn_recursive_feed_forward(module, None, 0)
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -304,11 +239,7 @@ def fuse_qkv_projections(self):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py
new file mode 100644
index 000000000000..2b9fc5b8d9fb
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py
@@ -0,0 +1,784 @@
+# Copyright 2025 The SkyReels Team, The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
+from ..cache_utils import CacheMixin
+from ..embeddings import (
+ PixArtAlphaTextProjection,
+ TimestepEmbedding,
+ get_1d_rotary_pos_embed,
+ get_1d_sincos_pos_embed_from_grid,
+)
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin, get_parameter_dtype
+from ..normalization import FP32LayerNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def _get_qkv_projections(
+ attn: "SkyReelsV2Attention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
+):
+ # encoder_hidden_states is only passed for cross-attention
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ if attn.fused_projections:
+ if attn.cross_attention_dim_head is None:
+ # In self-attention layers, we can fuse the entire QKV projection into a single linear
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+ else:
+ # In cross-attention layers, we can only fuse the KV projections into a single linear
+ query = attn.to_q(hidden_states)
+ key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+ else:
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+ return query, key, value
+
+
+def _get_added_kv_projections(attn: "SkyReelsV2Attention", encoder_hidden_states_img: torch.Tensor):
+ if attn.fused_projections:
+ key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+ else:
+ key_img = attn.add_k_proj(encoder_hidden_states_img)
+ value_img = attn.add_v_proj(encoder_hidden_states_img)
+ return key_img, value_img
+
+
+class SkyReelsV2AttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "SkyReelsV2AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: "SkyReelsV2Attention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ encoder_hidden_states_img = None
+ if attn.add_k_proj is not None:
+ # 512 is the context length of the text encoder, hardcoded for now
+ image_context_length = encoder_hidden_states.shape[1] - 512
+ encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
+ encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
+
+ query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+
+ if rotary_emb is not None:
+
+ def apply_rotary_emb(
+ hidden_states: torch.Tensor,
+ freqs_cos: torch.Tensor,
+ freqs_sin: torch.Tensor,
+ ):
+ x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+ cos = freqs_cos[..., 0::2]
+ sin = freqs_sin[..., 1::2]
+ out = torch.empty_like(hidden_states)
+ out[..., 0::2] = x1 * cos - x2 * sin
+ out[..., 1::2] = x1 * sin + x2 * cos
+ return out.type_as(hidden_states)
+
+ query = apply_rotary_emb(query, *rotary_emb)
+ key = apply_rotary_emb(key, *rotary_emb)
+
+ # I2V task
+ hidden_states_img = None
+ if encoder_hidden_states_img is not None:
+ key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
+ key_img = attn.norm_added_k(key_img)
+
+ key_img = key_img.unflatten(2, (attn.heads, -1))
+ value_img = value_img.unflatten(2, (attn.heads, -1))
+
+ hidden_states_img = dispatch_attention_fn(
+ query,
+ key_img,
+ value_img,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states_img = hidden_states_img.flatten(2, 3)
+ hidden_states_img = hidden_states_img.type_as(query)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+
+ if hidden_states_img is not None:
+ hidden_states = hidden_states + hidden_states_img
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+class SkyReelsV2AttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = (
+ "The SkyReelsV2AttnProcessor2_0 class is deprecated and will be removed in a future version. "
+ "Please use SkyReelsV2AttnProcessor instead. "
+ )
+ deprecate("SkyReelsV2AttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+ return SkyReelsV2AttnProcessor(*args, **kwargs)
+
+
+class SkyReelsV2Attention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = SkyReelsV2AttnProcessor
+ _available_processors = [SkyReelsV2AttnProcessor]
+
+ def __init__(
+ self,
+ dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ eps: float = 1e-5,
+ dropout: float = 0.0,
+ added_kv_proj_dim: Optional[int] = None,
+ cross_attention_dim_head: Optional[int] = None,
+ processor=None,
+ is_cross_attention=None,
+ ):
+ super().__init__()
+
+ self.inner_dim = dim_head * heads
+ self.heads = heads
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.cross_attention_dim_head = cross_attention_dim_head
+ self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+ self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+ self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_out = torch.nn.ModuleList(
+ [
+ torch.nn.Linear(self.inner_dim, dim, bias=True),
+ torch.nn.Dropout(dropout),
+ ]
+ )
+ self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+ self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+
+ self.add_k_proj = self.add_v_proj = None
+ if added_kv_proj_dim is not None:
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+ self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+ self.is_cross_attention = cross_attention_dim_head is not None
+
+ self.set_processor(processor)
+
+ def fuse_projections(self):
+ if getattr(self, "fused_projections", False):
+ return
+
+ if self.cross_attention_dim_head is None:
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_qkv = nn.Linear(in_features, out_features, bias=True)
+ self.to_qkv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+ else:
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_kv = nn.Linear(in_features, out_features, bias=True)
+ self.to_kv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+
+ if self.added_kv_proj_dim is not None:
+ concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+ concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+ self.to_added_kv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+
+ self.fused_projections = True
+
+ @torch.no_grad()
+ def unfuse_projections(self):
+ if not getattr(self, "fused_projections", False):
+ return
+
+ if hasattr(self, "to_qkv"):
+ delattr(self, "to_qkv")
+ if hasattr(self, "to_kv"):
+ delattr(self, "to_kv")
+ if hasattr(self, "to_added_kv"):
+ delattr(self, "to_added_kv")
+
+ self.fused_projections = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
+
+
+class SkyReelsV2ImageEmbedding(torch.nn.Module):
+ def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
+ super().__init__()
+
+ self.norm1 = FP32LayerNorm(in_features)
+ self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
+ self.norm2 = FP32LayerNorm(out_features)
+ if pos_embed_seq_len is not None:
+ self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
+ else:
+ self.pos_embed = None
+
+ def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
+ if self.pos_embed is not None:
+ batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
+ encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
+ encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
+
+ hidden_states = self.norm1(encoder_hidden_states_image)
+ hidden_states = self.ff(hidden_states)
+ hidden_states = self.norm2(hidden_states)
+ return hidden_states
+
+
+class SkyReelsV2Timesteps(nn.Module):
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, output_type: str = "pt"):
+ super().__init__()
+ self.num_channels = num_channels
+ self.output_type = output_type
+ self.flip_sin_to_cos = flip_sin_to_cos
+
+ def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
+ original_shape = timesteps.shape
+ t_emb = get_1d_sincos_pos_embed_from_grid(
+ self.num_channels,
+ timesteps,
+ output_type=self.output_type,
+ flip_sin_to_cos=self.flip_sin_to_cos,
+ )
+ # Reshape back to maintain batch structure
+ if len(original_shape) > 1:
+ t_emb = t_emb.reshape(*original_shape, self.num_channels)
+ return t_emb
+
+
+class SkyReelsV2TimeTextImageEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ time_freq_dim: int,
+ time_proj_dim: int,
+ text_embed_dim: int,
+ image_embed_dim: Optional[int] = None,
+ pos_embed_seq_len: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.timesteps_proj = SkyReelsV2Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True)
+ self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
+ self.act_fn = nn.SiLU()
+ self.time_proj = nn.Linear(dim, time_proj_dim)
+ self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
+
+ self.image_embedder = None
+ if image_embed_dim is not None:
+ self.image_embedder = SkyReelsV2ImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ ):
+ timestep = self.timesteps_proj(timestep)
+
+ time_embedder_dtype = get_parameter_dtype(self.time_embedder)
+ if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+ timestep = timestep.to(time_embedder_dtype)
+ temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+ timestep_proj = self.time_proj(self.act_fn(temb))
+
+ encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
+
+ return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
+
+
+class SkyReelsV2RotaryPosEmbed(nn.Module):
+ def __init__(
+ self,
+ attention_head_dim: int,
+ patch_size: Tuple[int, int, int],
+ max_seq_len: int,
+ theta: float = 10000.0,
+ ):
+ super().__init__()
+
+ self.attention_head_dim = attention_head_dim
+ self.patch_size = patch_size
+ self.max_seq_len = max_seq_len
+
+ h_dim = w_dim = 2 * (attention_head_dim // 6)
+ t_dim = attention_head_dim - h_dim - w_dim
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+ self.t_dim = t_dim
+ self.h_dim = h_dim
+ self.w_dim = w_dim
+
+ freqs_cos = []
+ freqs_sin = []
+
+ for dim in [t_dim, h_dim, w_dim]:
+ freq_cos, freq_sin = get_1d_rotary_pos_embed(
+ dim,
+ max_seq_len,
+ theta,
+ use_real=True,
+ repeat_interleave_real=True,
+ freqs_dtype=freqs_dtype,
+ )
+ freqs_cos.append(freq_cos)
+ freqs_sin.append(freq_sin)
+
+ self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+ self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.patch_size
+ ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
+
+ split_sizes = [self.t_dim, self.h_dim, self.w_dim]
+
+ freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+ freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+ freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+ freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+
+ return freqs_cos, freqs_sin
+
+
+@maybe_allow_in_graph
+class SkyReelsV2TransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ ffn_dim: int,
+ num_heads: int,
+ qk_norm: str = "rms_norm_across_heads",
+ cross_attn_norm: bool = False,
+ eps: float = 1e-6,
+ added_kv_proj_dim: Optional[int] = None,
+ ):
+ super().__init__()
+
+ # 1. Self-attention
+ self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+ self.attn1 = SkyReelsV2Attention(
+ dim=dim,
+ heads=num_heads,
+ dim_head=dim // num_heads,
+ eps=eps,
+ cross_attention_dim_head=None,
+ processor=SkyReelsV2AttnProcessor(),
+ )
+
+ # 2. Cross-attention
+ self.attn2 = SkyReelsV2Attention(
+ dim=dim,
+ heads=num_heads,
+ dim_head=dim // num_heads,
+ eps=eps,
+ added_kv_proj_dim=added_kv_proj_dim,
+ cross_attention_dim_head=dim // num_heads,
+ processor=SkyReelsV2AttnProcessor(),
+ )
+ self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
+
+ # 3. Feed-forward
+ self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
+ self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ rotary_emb: torch.Tensor,
+ attention_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ if temb.dim() == 3:
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table + temb.float()
+ ).chunk(6, dim=1)
+ elif temb.dim() == 4:
+ # For 4D temb in Diffusion Forcing framework, we assume the shape is (b, 6, f * pp_h * pp_w, inner_dim)
+ e = (self.scale_shift_table.unsqueeze(2) + temb.float()).chunk(6, dim=1)
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e]
+
+ # 1. Self-attention
+ norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
+ attn_output = self.attn1(norm_hidden_states, None, attention_mask, rotary_emb)
+ hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
+
+ # 2. Cross-attention
+ norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
+ hidden_states = hidden_states + attn_output
+
+ # 3. Feed-forward
+ norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
+ hidden_states
+ )
+ ff_output = self.ffn(norm_hidden_states)
+ hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
+
+ return hidden_states
+
+
+class SkyReelsV2Transformer3DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
+ r"""
+ A Transformer model for video-like data used in the Wan-based SkyReels-V2 model.
+
+ Args:
+ patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
+ num_attention_heads (`int`, defaults to `16`):
+ Fixed length for text embeddings.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, defaults to `16`):
+ The number of channels in the output.
+ text_dim (`int`, defaults to `4096`):
+ Input dimension for text embeddings.
+ freq_dim (`int`, defaults to `256`):
+ Dimension for sinusoidal time embeddings.
+ ffn_dim (`int`, defaults to `8192`):
+ Intermediate dimension in feed-forward network.
+ num_layers (`int`, defaults to `32`):
+ The number of layers of transformer blocks to use.
+ window_size (`Tuple[int]`, defaults to `(-1, -1)`):
+ Window size for local attention (-1 indicates global attention).
+ cross_attn_norm (`bool`, defaults to `True`):
+ Enable cross-attention normalization.
+ qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
+ Enable query/key normalization.
+ eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ inject_sample_info (`bool`, defaults to `False`):
+ Whether to inject sample information into the model.
+ image_dim (`int`, *optional*):
+ The dimension of the image embeddings.
+ added_kv_proj_dim (`int`, *optional*):
+ The dimension of the added key/value projection.
+ rope_max_seq_len (`int`, defaults to `1024`):
+ The maximum sequence length for the rotary embeddings.
+ pos_embed_seq_len (`int`, *optional*):
+ The sequence length for the positional embeddings.
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
+ _no_split_modules = ["SkyReelsV2TransformerBlock"]
+ _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
+ _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+ _repeated_blocks = ["SkyReelsV2TransformerBlock"]
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: Tuple[int, ...] = (1, 2, 2),
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 128,
+ in_channels: int = 16,
+ out_channels: int = 16,
+ text_dim: int = 4096,
+ freq_dim: int = 256,
+ ffn_dim: int = 8192,
+ num_layers: int = 32,
+ cross_attn_norm: bool = True,
+ qk_norm: Optional[str] = "rms_norm_across_heads",
+ eps: float = 1e-6,
+ image_dim: Optional[int] = None,
+ added_kv_proj_dim: Optional[int] = None,
+ rope_max_seq_len: int = 1024,
+ pos_embed_seq_len: Optional[int] = None,
+ inject_sample_info: bool = False,
+ num_frame_per_block: int = 1,
+ ) -> None:
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels or in_channels
+
+ # 1. Patch & position embedding
+ self.rope = SkyReelsV2RotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
+ self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+
+ # 2. Condition embeddings
+ # image_embedding_dim=1280 for I2V model
+ self.condition_embedder = SkyReelsV2TimeTextImageEmbedding(
+ dim=inner_dim,
+ time_freq_dim=freq_dim,
+ time_proj_dim=inner_dim * 6,
+ text_embed_dim=text_dim,
+ image_embed_dim=image_dim,
+ pos_embed_seq_len=pos_embed_seq_len,
+ )
+
+ # 3. Transformer blocks
+ self.blocks = nn.ModuleList(
+ [
+ SkyReelsV2TransformerBlock(
+ inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Output norm & projection
+ self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+ self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
+
+ if inject_sample_info:
+ self.fps_embedding = nn.Embedding(2, inner_dim)
+ self.fps_projection = FeedForward(inner_dim, inner_dim * 6, mult=1, activation_fn="linear-silu")
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ enable_diffusion_forcing: bool = False,
+ fps: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.config.patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+
+ rotary_emb = self.rope(hidden_states)
+
+ hidden_states = self.patch_embedding(hidden_states)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+ causal_mask = None
+ if self.config.num_frame_per_block > 1:
+ block_num = post_patch_num_frames // self.config.num_frame_per_block
+ range_tensor = torch.arange(block_num, device=hidden_states.device).repeat_interleave(
+ self.config.num_frame_per_block
+ )
+ causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f
+ causal_mask = causal_mask.view(post_patch_num_frames, 1, 1, post_patch_num_frames, 1, 1)
+ causal_mask = causal_mask.repeat(
+ 1, post_patch_height, post_patch_width, 1, post_patch_height, post_patch_width
+ )
+ causal_mask = causal_mask.reshape(
+ post_patch_num_frames * post_patch_height * post_patch_width,
+ post_patch_num_frames * post_patch_height * post_patch_width,
+ )
+ causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
+
+ temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+ timestep, encoder_hidden_states, encoder_hidden_states_image
+ )
+
+ timestep_proj = timestep_proj.unflatten(-1, (6, -1))
+
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+
+ if self.config.inject_sample_info:
+ fps = torch.tensor(fps, dtype=torch.long, device=hidden_states.device)
+
+ fps_emb = self.fps_embedding(fps)
+ if enable_diffusion_forcing:
+ timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1)).repeat(
+ timestep.shape[1], 1, 1
+ )
+ else:
+ timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1))
+
+ if enable_diffusion_forcing:
+ b, f = timestep.shape
+ temb = temb.view(b, f, 1, 1, -1)
+ timestep_proj = timestep_proj.view(b, f, 1, 1, 6, -1) # (b, f, 1, 1, 6, inner_dim)
+ temb = temb.repeat(1, 1, post_patch_height, post_patch_width, 1).flatten(1, 3)
+ timestep_proj = timestep_proj.repeat(1, 1, post_patch_height, post_patch_width, 1, 1).flatten(
+ 1, 3
+ ) # (b, f, pp_h, pp_w, 6, inner_dim) -> (b, f * pp_h * pp_w, 6, inner_dim)
+ timestep_proj = timestep_proj.transpose(1, 2).contiguous() # (b, 6, f * pp_h * pp_w, inner_dim)
+
+ # 4. Transformer blocks
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for block in self.blocks:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ timestep_proj,
+ rotary_emb,
+ causal_mask,
+ )
+ else:
+ for block in self.blocks:
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states,
+ timestep_proj,
+ rotary_emb,
+ causal_mask,
+ )
+
+ if temb.dim() == 2:
+ # If temb is 2D, we assume it has time 1-D time embedding values for each batch.
+ # For models:
+ # - Skywork/SkyReels-V2-T2V-14B-540P-Diffusers
+ # - Skywork/SkyReels-V2-T2V-14B-720P-Diffusers
+ # - Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers
+ # - Skywork/SkyReels-V2-I2V-14B-540P-Diffusers
+ # - Skywork/SkyReels-V2-I2V-14B-720P-Diffusers
+ shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+ elif temb.dim() == 3:
+ # If temb is 3D, we assume it has 2-D time embedding values for each batch.
+ # Each time embedding tensor includes values for each latent frame; thus Diffusion Forcing.
+ # For models:
+ # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers
+ # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers
+ # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers
+ shift, scale = (self.scale_shift_table.unsqueeze(2) + temb.unsqueeze(1)).chunk(2, dim=1)
+ shift, scale = shift.squeeze(1), scale.squeeze(1)
+
+ # Move the shift and scale tensors to the same device as hidden_states.
+ # When using multi-GPU inference via accelerate these will be on the
+ # first device rather than the last device, which hidden_states ends up
+ # on.
+ shift = shift.to(hidden_states.device)
+ scale = scale.to(hidden_states.device)
+
+ hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+
+ hidden_states = self.proj_out(hidden_states)
+
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+ )
+ hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+ output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
+
+ def _set_ar_attention(self, causal_block_size: int):
+ self.register_to_config(num_frame_per_block=causal_block_size)
diff --git a/src/diffusers/models/transformers/transformer_temporal.py b/src/diffusers/models/transformers/transformer_temporal.py
index 5580d0f70f9f..ffaf31d04570 100644
--- a/src/diffusers/models/transformers/transformer_temporal.py
+++ b/src/diffusers/models/transformers/transformer_temporal.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index aa03e97093aa..f7693ec5d3ac 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -21,9 +21,11 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ..attention import FeedForward
-from ..attention_processor import Attention
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
@@ -34,69 +36,120 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-class WanAttnProcessor2_0:
+def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+ # encoder_hidden_states is only passed for cross-attention
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ if attn.fused_projections:
+ if attn.cross_attention_dim_head is None:
+ # In self-attention layers, we can fuse the entire QKV projection into a single linear
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+ else:
+ # In cross-attention layers, we can only fuse the KV projections into a single linear
+ query = attn.to_q(hidden_states)
+ key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+ else:
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+ return query, key, value
+
+
+def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
+ if attn.fused_projections:
+ key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+ else:
+ key_img = attn.add_k_proj(encoder_hidden_states_img)
+ value_img = attn.add_v_proj(encoder_hidden_states_img)
+ return key_img, value_img
+
+
+class WanAttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+ raise ImportError(
+ "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
+ )
def __call__(
self,
- attn: Attention,
+ attn: "WanAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
- rotary_emb: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
encoder_hidden_states_img = None
if attn.add_k_proj is not None:
- encoder_hidden_states_img = encoder_hidden_states[:, :257]
- encoder_hidden_states = encoder_hidden_states[:, 257:]
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
+ # 512 is the context length of the text encoder, hardcoded for now
+ image_context_length = encoder_hidden_states.shape[1] - 512
+ encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
+ encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
- query = attn.to_q(hidden_states)
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
+ query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
- query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
if rotary_emb is not None:
- def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
- x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
- x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
- return x_out.type_as(hidden_states)
-
- query = apply_rotary_emb(query, rotary_emb)
- key = apply_rotary_emb(key, rotary_emb)
+ def apply_rotary_emb(
+ hidden_states: torch.Tensor,
+ freqs_cos: torch.Tensor,
+ freqs_sin: torch.Tensor,
+ ):
+ x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+ cos = freqs_cos[..., 0::2]
+ sin = freqs_sin[..., 1::2]
+ out = torch.empty_like(hidden_states)
+ out[..., 0::2] = x1 * cos - x2 * sin
+ out[..., 1::2] = x1 * sin + x2 * cos
+ return out.type_as(hidden_states)
+
+ query = apply_rotary_emb(query, *rotary_emb)
+ key = apply_rotary_emb(key, *rotary_emb)
# I2V task
hidden_states_img = None
if encoder_hidden_states_img is not None:
- key_img = attn.add_k_proj(encoder_hidden_states_img)
+ key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
key_img = attn.norm_added_k(key_img)
- value_img = attn.add_v_proj(encoder_hidden_states_img)
- key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-
- hidden_states_img = F.scaled_dot_product_attention(
- query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+ key_img = key_img.unflatten(2, (attn.heads, -1))
+ value_img = value_img.unflatten(2, (attn.heads, -1))
+
+ hidden_states_img = dispatch_attention_fn(
+ query,
+ key_img,
+ value_img,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
)
- hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+ hidden_states_img = hidden_states_img.flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
)
- hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+ hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.type_as(query)
if hidden_states_img is not None:
@@ -107,15 +160,140 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
return hidden_states
+class WanAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = (
+ "The WanAttnProcessor2_0 class is deprecated and will be removed in a future version. "
+ "Please use WanAttnProcessor instead. "
+ )
+ deprecate("WanAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+ return WanAttnProcessor(*args, **kwargs)
+
+
+class WanAttention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = WanAttnProcessor
+ _available_processors = [WanAttnProcessor]
+
+ def __init__(
+ self,
+ dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ eps: float = 1e-5,
+ dropout: float = 0.0,
+ added_kv_proj_dim: Optional[int] = None,
+ cross_attention_dim_head: Optional[int] = None,
+ processor=None,
+ is_cross_attention=None,
+ ):
+ super().__init__()
+
+ self.inner_dim = dim_head * heads
+ self.heads = heads
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.cross_attention_dim_head = cross_attention_dim_head
+ self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+ self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+ self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_out = torch.nn.ModuleList(
+ [
+ torch.nn.Linear(self.inner_dim, dim, bias=True),
+ torch.nn.Dropout(dropout),
+ ]
+ )
+ self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+ self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+
+ self.add_k_proj = self.add_v_proj = None
+ if added_kv_proj_dim is not None:
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+ self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+ self.is_cross_attention = cross_attention_dim_head is not None
+
+ self.set_processor(processor)
+
+ def fuse_projections(self):
+ if getattr(self, "fused_projections", False):
+ return
+
+ if self.cross_attention_dim_head is None:
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_qkv = nn.Linear(in_features, out_features, bias=True)
+ self.to_qkv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+ else:
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_kv = nn.Linear(in_features, out_features, bias=True)
+ self.to_kv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+
+ if self.added_kv_proj_dim is not None:
+ concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+ concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+ self.to_added_kv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+
+ self.fused_projections = True
+
+ @torch.no_grad()
+ def unfuse_projections(self):
+ if not getattr(self, "fused_projections", False):
+ return
+
+ if hasattr(self, "to_qkv"):
+ delattr(self, "to_qkv")
+ if hasattr(self, "to_kv"):
+ delattr(self, "to_kv")
+ if hasattr(self, "to_added_kv"):
+ delattr(self, "to_added_kv")
+
+ self.fused_projections = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
+
+
class WanImageEmbedding(torch.nn.Module):
- def __init__(self, in_features: int, out_features: int):
+ def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
super().__init__()
self.norm1 = FP32LayerNorm(in_features)
self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
self.norm2 = FP32LayerNorm(out_features)
+ if pos_embed_seq_len is not None:
+ self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
+ else:
+ self.pos_embed = None
def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
+ if self.pos_embed is not None:
+ batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
+ encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
+ encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
+
hidden_states = self.norm1(encoder_hidden_states_image)
hidden_states = self.ff(hidden_states)
hidden_states = self.norm2(hidden_states)
@@ -130,6 +308,7 @@ def __init__(
time_proj_dim: int,
text_embed_dim: int,
image_embed_dim: Optional[int] = None,
+ pos_embed_seq_len: Optional[int] = None,
):
super().__init__()
@@ -141,15 +320,18 @@ def __init__(
self.image_embedder = None
if image_embed_dim is not None:
- self.image_embedder = WanImageEmbedding(image_embed_dim, dim)
+ self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
def forward(
self,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ timestep_seq_len: Optional[int] = None,
):
timestep = self.timesteps_proj(timestep)
+ if timestep_seq_len is not None:
+ timestep = timestep.unflatten(0, (-1, timestep_seq_len))
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
@@ -166,7 +348,11 @@ def forward(
class WanRotaryPosEmbed(nn.Module):
def __init__(
- self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
+ self,
+ attention_head_dim: int,
+ patch_size: Tuple[int, int, int],
+ max_seq_len: int,
+ theta: float = 10000.0,
):
super().__init__()
@@ -177,36 +363,55 @@ def __init__(
h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
- freqs = []
+ self.t_dim = t_dim
+ self.h_dim = h_dim
+ self.w_dim = w_dim
+
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+ freqs_cos = []
+ freqs_sin = []
+
for dim in [t_dim, h_dim, w_dim]:
- freq = get_1d_rotary_pos_embed(
- dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
+ freq_cos, freq_sin = get_1d_rotary_pos_embed(
+ dim,
+ max_seq_len,
+ theta,
+ use_real=True,
+ repeat_interleave_real=True,
+ freqs_dtype=freqs_dtype,
)
- freqs.append(freq)
- self.freqs = torch.cat(freqs, dim=1)
+ freqs_cos.append(freq_cos)
+ freqs_sin.append(freq_sin)
+
+ self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+ self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
- self.freqs = self.freqs.to(hidden_states.device)
- freqs = self.freqs.split_with_sizes(
- [
- self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
- self.attention_head_dim // 6,
- self.attention_head_dim // 6,
- ],
- dim=1,
- )
+ split_sizes = [self.t_dim, self.h_dim, self.w_dim]
+
+ freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+ freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
- freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
- freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
- freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
- freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
- return freqs
+ freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+ freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+
+ return freqs_cos, freqs_sin
+
+
+@maybe_allow_in_graph
class WanTransformerBlock(nn.Module):
def __init__(
self,
@@ -222,33 +427,24 @@ def __init__(
# 1. Self-attention
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
- self.attn1 = Attention(
- query_dim=dim,
+ self.attn1 = WanAttention(
+ dim=dim,
heads=num_heads,
- kv_heads=num_heads,
dim_head=dim // num_heads,
- qk_norm=qk_norm,
eps=eps,
- bias=True,
- cross_attention_dim=None,
- out_bias=True,
- processor=WanAttnProcessor2_0(),
+ cross_attention_dim_head=None,
+ processor=WanAttnProcessor(),
)
# 2. Cross-attention
- self.attn2 = Attention(
- query_dim=dim,
+ self.attn2 = WanAttention(
+ dim=dim,
heads=num_heads,
- kv_heads=num_heads,
dim_head=dim // num_heads,
- qk_norm=qk_norm,
eps=eps,
- bias=True,
- cross_attention_dim=None,
- out_bias=True,
added_kv_proj_dim=added_kv_proj_dim,
- added_proj_bias=True,
- processor=WanAttnProcessor2_0(),
+ cross_attention_dim_head=dim // num_heads,
+ processor=WanAttnProcessor(),
)
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
@@ -265,18 +461,32 @@ def forward(
temb: torch.Tensor,
rotary_emb: torch.Tensor,
) -> torch.Tensor:
- shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
- self.scale_shift_table + temb.float()
- ).chunk(6, dim=1)
+ if temb.ndim == 4:
+ # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table.unsqueeze(0) + temb.float()
+ ).chunk(6, dim=2)
+ # batch_size, seq_len, 1, inner_dim
+ shift_msa = shift_msa.squeeze(2)
+ scale_msa = scale_msa.squeeze(2)
+ gate_msa = gate_msa.squeeze(2)
+ c_shift_msa = c_shift_msa.squeeze(2)
+ c_scale_msa = c_scale_msa.squeeze(2)
+ c_gate_msa = c_gate_msa.squeeze(2)
+ else:
+ # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table + temb.float()
+ ).chunk(6, dim=1)
# 1. Self-attention
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
- attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+ attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
# 2. Cross-attention
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
- attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
hidden_states = hidden_states + attn_output
# 3. Feed-forward
@@ -289,7 +499,9 @@ def forward(
return hidden_states
-class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+class WanTransformer3DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
r"""
A Transformer model for video-like data used in the Wan model.
@@ -331,11 +543,28 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
_no_split_modules = ["WanTransformerBlock"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+ _repeated_blocks = ["WanTransformerBlock"]
+ _cp_plan = {
+ "rope": {
+ 0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
+ 1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
+ },
+ "blocks.0": {
+ "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "blocks.*": {
+ "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ "": {
+ "timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+ },
+ }
@register_to_config
def __init__(
self,
- patch_size: Tuple[int] = (1, 2, 2),
+ patch_size: Tuple[int, ...] = (1, 2, 2),
num_attention_heads: int = 40,
attention_head_dim: int = 128,
in_channels: int = 16,
@@ -350,6 +579,7 @@ def __init__(
image_dim: Optional[int] = None,
added_kv_proj_dim: Optional[int] = None,
rope_max_seq_len: int = 1024,
+ pos_embed_seq_len: Optional[int] = None,
) -> None:
super().__init__()
@@ -368,6 +598,7 @@ def __init__(
time_proj_dim=inner_dim * 6,
text_embed_dim=text_dim,
image_embed_dim=image_dim,
+ pos_embed_seq_len=pos_embed_seq_len,
)
# 3. Transformer blocks
@@ -422,10 +653,22 @@ def forward(
hidden_states = self.patch_embedding(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
+ # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
+ if timestep.ndim == 2:
+ ts_seq_len = timestep.shape[1]
+ timestep = timestep.flatten() # batch_size * seq_len
+ else:
+ ts_seq_len = None
+
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
- timestep, encoder_hidden_states, encoder_hidden_states_image
+ timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
)
- timestep_proj = timestep_proj.unflatten(1, (6, -1))
+ if ts_seq_len is not None:
+ # batch_size, seq_len, 6, inner_dim
+ timestep_proj = timestep_proj.unflatten(2, (6, -1))
+ else:
+ # batch_size, 6, inner_dim
+ timestep_proj = timestep_proj.unflatten(1, (6, -1))
if encoder_hidden_states_image is not None:
encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
@@ -441,7 +684,14 @@ def forward(
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
# 5. Output norm, projection & unpatchify
- shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+ if temb.ndim == 3:
+ # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
+ shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
+ shift = shift.squeeze(2)
+ scale = scale.squeeze(2)
+ else:
+ # batch_size, inner_dim
+ shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
# Move the shift and scale tensors to the same device as hidden_states.
# When using multi-GPU inference via accelerate these will be on the
diff --git a/src/diffusers/models/transformers/transformer_wan_animate.py b/src/diffusers/models/transformers/transformer_wan_animate.py
new file mode 100644
index 000000000000..6a47a67385a3
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_wan_animate.py
@@ -0,0 +1,1298 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
+from ..cache_utils import CacheMixin
+from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import FP32LayerNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES = {
+ "4": 512,
+ "8": 512,
+ "16": 512,
+ "32": 512,
+ "64": 256,
+ "128": 128,
+ "256": 64,
+ "512": 32,
+ "1024": 16,
+}
+
+
+# Copied from diffusers.models.transformers.transformer_wan._get_qkv_projections
+def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+ # encoder_hidden_states is only passed for cross-attention
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ if attn.fused_projections:
+ if attn.cross_attention_dim_head is None:
+ # In self-attention layers, we can fuse the entire QKV projection into a single linear
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+ else:
+ # In cross-attention layers, we can only fuse the KV projections into a single linear
+ query = attn.to_q(hidden_states)
+ key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+ else:
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+ return query, key, value
+
+
+# Copied from diffusers.models.transformers.transformer_wan._get_added_kv_projections
+def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
+ if attn.fused_projections:
+ key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+ else:
+ key_img = attn.add_k_proj(encoder_hidden_states_img)
+ value_img = attn.add_v_proj(encoder_hidden_states_img)
+ return key_img, value_img
+
+
+class FusedLeakyReLU(nn.Module):
+ """
+ Fused LeakyRelu with scale factor and channel-wise bias.
+ """
+
+ def __init__(self, negative_slope: float = 0.2, scale: float = 2**0.5, bias_channels: Optional[int] = None):
+ super().__init__()
+ self.negative_slope = negative_slope
+ self.scale = scale
+ self.channels = bias_channels
+
+ if self.channels is not None:
+ self.bias = nn.Parameter(
+ torch.zeros(
+ self.channels,
+ )
+ )
+ else:
+ self.bias = None
+
+ def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
+ if self.bias is not None:
+ # Expand self.bias to have all singleton dims except at self.channel_dim
+ expanded_shape = [1] * x.ndim
+ expanded_shape[channel_dim] = self.bias.shape[0]
+ bias = self.bias.reshape(*expanded_shape)
+ x = x + bias
+ return F.leaky_relu(x, self.negative_slope) * self.scale
+
+
+class MotionConv2d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ padding: int = 0,
+ bias: bool = True,
+ blur_kernel: Optional[Tuple[int, ...]] = None,
+ blur_upsample_factor: int = 1,
+ use_activation: bool = True,
+ ):
+ super().__init__()
+ self.use_activation = use_activation
+ self.in_channels = in_channels
+
+ # Handle blurring (applying a FIR filter with the given kernel) if available
+ self.blur = False
+ if blur_kernel is not None:
+ p = (len(blur_kernel) - stride) + (kernel_size - 1)
+ self.blur_padding = ((p + 1) // 2, p // 2)
+
+ kernel = torch.tensor(blur_kernel)
+ # Convert kernel to 2D if necessary
+ if kernel.ndim == 1:
+ kernel = kernel[None, :] * kernel[:, None]
+ # Normalize kernel
+ kernel = kernel / kernel.sum()
+ if blur_upsample_factor > 1:
+ kernel = kernel * (blur_upsample_factor**2)
+ self.register_buffer("blur_kernel", kernel, persistent=False)
+ self.blur = True
+
+ # Main Conv2d parameters (with scale factor)
+ self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
+
+ self.stride = stride
+ self.padding = padding
+
+ # If using an activation function, the bias will be fused into the activation
+ if bias and not self.use_activation:
+ self.bias = nn.Parameter(torch.zeros(out_channels))
+ else:
+ self.bias = None
+
+ if self.use_activation:
+ self.act_fn = FusedLeakyReLU(bias_channels=out_channels)
+ else:
+ self.act_fn = None
+
+ def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
+ # Apply blur if using
+ if self.blur:
+ # NOTE: the original implementation uses a 2D upfirdn operation with the upsampling and downsampling rates
+ # set to 1, which should be equivalent to a 2D convolution
+ expanded_kernel = self.blur_kernel[None, None, :, :].expand(self.in_channels, 1, -1, -1)
+ x = F.conv2d(x, expanded_kernel, padding=self.blur_padding, groups=self.in_channels)
+
+ # Main Conv2D with scaling
+ x = F.conv2d(x, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
+
+ # Activation with fused bias, if using
+ if self.use_activation:
+ x = self.act_fn(x, channel_dim=channel_dim)
+ return x
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
+ f" kernel_size={self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
+ )
+
+
+class MotionLinear(nn.Module):
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ bias: bool = True,
+ use_activation: bool = False,
+ ):
+ super().__init__()
+ self.use_activation = use_activation
+
+ # Linear weight with scale factor
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
+ self.scale = 1 / math.sqrt(in_dim)
+
+ # If an activation is present, the bias will be fused to it
+ if bias and not self.use_activation:
+ self.bias = nn.Parameter(torch.zeros(out_dim))
+ else:
+ self.bias = None
+
+ if self.use_activation:
+ self.act_fn = FusedLeakyReLU(bias_channels=out_dim)
+ else:
+ self.act_fn = None
+
+ def forward(self, input: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
+ out = F.linear(input, self.weight * self.scale, bias=self.bias)
+ if self.use_activation:
+ out = self.act_fn(out, channel_dim=channel_dim)
+ return out
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__name__}(in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]},"
+ f" bias={self.bias is not None})"
+ )
+
+
+class MotionEncoderResBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int = 3,
+ kernel_size_skip: int = 1,
+ blur_kernel: Tuple[int, ...] = (1, 3, 3, 1),
+ downsample_factor: int = 2,
+ ):
+ super().__init__()
+ self.downsample_factor = downsample_factor
+
+ # 3 x 3 Conv + fused leaky ReLU
+ self.conv1 = MotionConv2d(
+ in_channels,
+ in_channels,
+ kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ use_activation=True,
+ )
+
+ # 3 x 3 Conv that downsamples 2x + fused leaky ReLU
+ self.conv2 = MotionConv2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=self.downsample_factor,
+ padding=0,
+ blur_kernel=blur_kernel,
+ use_activation=True,
+ )
+
+ # 1 x 1 Conv that downsamples 2x in skip connection
+ self.conv_skip = MotionConv2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size_skip,
+ stride=self.downsample_factor,
+ padding=0,
+ bias=False,
+ blur_kernel=blur_kernel,
+ use_activation=False,
+ )
+
+ def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
+ x_out = self.conv1(x, channel_dim)
+ x_out = self.conv2(x_out, channel_dim)
+
+ x_skip = self.conv_skip(x, channel_dim)
+
+ x_out = (x_out + x_skip) / math.sqrt(2)
+ return x_out
+
+
+class WanAnimateMotionEncoder(nn.Module):
+ def __init__(
+ self,
+ size: int = 512,
+ style_dim: int = 512,
+ motion_dim: int = 20,
+ out_dim: int = 512,
+ motion_blocks: int = 5,
+ channels: Optional[Dict[str, int]] = None,
+ ):
+ super().__init__()
+ self.size = size
+
+ # Appearance encoder: conv layers
+ if channels is None:
+ channels = WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES
+
+ self.conv_in = MotionConv2d(3, channels[str(size)], 1, use_activation=True)
+
+ self.res_blocks = nn.ModuleList()
+ in_channels = channels[str(size)]
+ log_size = int(math.log(size, 2))
+ for i in range(log_size, 2, -1):
+ out_channels = channels[str(2 ** (i - 1))]
+ self.res_blocks.append(MotionEncoderResBlock(in_channels, out_channels))
+ in_channels = out_channels
+
+ self.conv_out = MotionConv2d(in_channels, style_dim, 4, padding=0, bias=False, use_activation=False)
+
+ # Motion encoder: linear layers
+ # NOTE: there are no activations in between the linear layers here, which is weird but I believe matches the
+ # original code.
+ linears = [MotionLinear(style_dim, style_dim) for _ in range(motion_blocks - 1)]
+ linears.append(MotionLinear(style_dim, motion_dim))
+ self.motion_network = nn.ModuleList(linears)
+
+ self.motion_synthesis_weight = nn.Parameter(torch.randn(out_dim, motion_dim))
+
+ def forward(self, face_image: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
+ if (face_image.shape[-2] != self.size) or (face_image.shape[-1] != self.size):
+ raise ValueError(
+ f"Face pixel values has resolution ({face_image.shape[-1]}, {face_image.shape[-2]}) but is expected"
+ f" to have resolution ({self.size}, {self.size})"
+ )
+
+ # Appearance encoding through convs
+ face_image = self.conv_in(face_image, channel_dim)
+ for block in self.res_blocks:
+ face_image = block(face_image, channel_dim)
+ face_image = self.conv_out(face_image, channel_dim)
+ motion_feat = face_image.squeeze(-1).squeeze(-1)
+
+ # Motion feature extraction
+ for linear_layer in self.motion_network:
+ motion_feat = linear_layer(motion_feat, channel_dim=channel_dim)
+
+ # Motion synthesis via Linear Motion Decomposition
+ weight = self.motion_synthesis_weight + 1e-8
+ # Upcast the QR orthogonalization operation to FP32
+ original_motion_dtype = motion_feat.dtype
+ motion_feat = motion_feat.to(torch.float32)
+ weight = weight.to(torch.float32)
+
+ Q = torch.linalg.qr(weight)[0].to(device=motion_feat.device)
+
+ motion_feat_diag = torch.diag_embed(motion_feat) # Alpha, diagonal matrix
+ motion_decomposition = torch.matmul(motion_feat_diag, Q.T)
+ motion_vec = torch.sum(motion_decomposition, dim=1)
+
+ motion_vec = motion_vec.to(dtype=original_motion_dtype)
+
+ return motion_vec
+
+
+class WanAnimateFaceEncoder(nn.Module):
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ hidden_dim: int = 1024,
+ num_heads: int = 4,
+ kernel_size: int = 3,
+ eps: float = 1e-6,
+ pad_mode: str = "replicate",
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ self.time_causal_padding = (kernel_size - 1, 0)
+ self.pad_mode = pad_mode
+
+ self.act = nn.SiLU()
+
+ self.conv1_local = nn.Conv1d(in_dim, hidden_dim * num_heads, kernel_size=kernel_size, stride=1)
+ self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size, stride=2)
+ self.conv3 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size, stride=2)
+
+ self.norm1 = nn.LayerNorm(hidden_dim, eps, elementwise_affine=False)
+ self.norm2 = nn.LayerNorm(hidden_dim, eps, elementwise_affine=False)
+ self.norm3 = nn.LayerNorm(hidden_dim, eps, elementwise_affine=False)
+
+ self.out_proj = nn.Linear(hidden_dim, out_dim)
+
+ self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, out_dim))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ batch_size = x.shape[0]
+
+ # Reshape to channels-first to apply causal Conv1d over frame dim
+ x = x.permute(0, 2, 1)
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
+ x = self.conv1_local(x) # [B, C, T_padded] --> [B, N * C, T]
+ x = x.unflatten(1, (self.num_heads, -1)).flatten(0, 1) # [B, N * C, T] --> [B * N, C, T]
+ # Reshape back to channels-last to apply LayerNorm over channel dim
+ x = x.permute(0, 2, 1)
+ x = self.norm1(x)
+ x = self.act(x)
+
+ x = x.permute(0, 2, 1)
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
+ x = self.conv2(x)
+ x = x.permute(0, 2, 1)
+ x = self.norm2(x)
+ x = self.act(x)
+
+ x = x.permute(0, 2, 1)
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
+ x = self.conv3(x)
+ x = x.permute(0, 2, 1)
+ x = self.norm3(x)
+ x = self.act(x)
+
+ x = self.out_proj(x)
+ x = x.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) # [B * N, T, C_out] --> [B, T, N, C_out]
+
+ padding = self.padding_tokens.repeat(batch_size, x.shape[1], 1, 1).to(device=x.device)
+ x = torch.cat([x, padding], dim=-2) # [B, T, N, C_out] --> [B, T, N + 1, C_out]
+
+ return x
+
+
+class WanAnimateFaceBlockAttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ f"{self.__class__.__name__} requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or"
+ f" higher."
+ )
+
+ def __call__(
+ self,
+ attn: "WanAnimateFaceBlockCrossAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ # encoder_hidden_states corresponds to the motion vec
+ # attention_mask corresponds to the motion mask (if any)
+ hidden_states = attn.pre_norm_q(hidden_states)
+ encoder_hidden_states = attn.pre_norm_kv(encoder_hidden_states)
+
+ # B --> batch_size, T --> reduced inference segment len, N --> face_encoder_num_heads + 1, C --> attn.dim
+ B, T, N, C = encoder_hidden_states.shape
+
+ query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
+
+ query = query.unflatten(2, (attn.heads, -1)) # [B, S, H * D] --> [B, S, H, D]
+ key = key.view(B, T, N, attn.heads, -1) # [B, T, N, H * D_kv] --> [B, T, N, H, D_kv]
+ value = value.view(B, T, N, attn.heads, -1)
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ # NOTE: the below line (which follows the official code) means that in practice, the number of frames T in
+ # encoder_hidden_states (the motion vector after applying the face encoder) must evenly divide the
+ # post-patchify sequence length S of the transformer hidden_states. Is it possible to remove this dependency?
+ query = query.unflatten(1, (T, -1)).flatten(0, 1) # [B, S, H, D] --> [B * T, S / T, H, D]
+ key = key.flatten(0, 1) # [B, T, N, H, D_kv] --> [B * T, N, H, D_kv]
+ value = value.flatten(0, 1)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+ hidden_states = hidden_states.unflatten(0, (B, T)).flatten(1, 2)
+
+ hidden_states = attn.to_out(hidden_states)
+
+ if attention_mask is not None:
+ # NOTE: attention_mask is assumed to be a multiplicative mask
+ attention_mask = attention_mask.flatten(start_dim=1)
+ hidden_states = hidden_states * attention_mask
+
+ return hidden_states
+
+
+class WanAnimateFaceBlockCrossAttention(nn.Module, AttentionModuleMixin):
+ """
+ Temporally-aligned cross attention with the face motion signal in the Wan Animate Face Blocks.
+ """
+
+ _default_processor_cls = WanAnimateFaceBlockAttnProcessor
+ _available_processors = [WanAnimateFaceBlockAttnProcessor]
+
+ def __init__(
+ self,
+ dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ eps: float = 1e-6,
+ cross_attention_dim_head: Optional[int] = None,
+ processor=None,
+ ):
+ super().__init__()
+ self.inner_dim = dim_head * heads
+ self.heads = heads
+ self.cross_attention_head_dim = cross_attention_dim_head
+ self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+ # 1. Pre-Attention Norms for the hidden_states (video latents) and encoder_hidden_states (motion vector).
+ # NOTE: this is not used in "vanilla" WanAttention
+ self.pre_norm_q = nn.LayerNorm(dim, eps, elementwise_affine=False)
+ self.pre_norm_kv = nn.LayerNorm(dim, eps, elementwise_affine=False)
+
+ # 2. QKV and Output Projections
+ self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+ self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_out = torch.nn.Linear(self.inner_dim, dim, bias=True)
+
+ # 3. QK Norm
+ # NOTE: this is applied after the reshape, so only over dim_head rather than dim_head * heads
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=True)
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=True)
+
+ # 4. Set attention processor
+ if processor is None:
+ processor = self._default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask)
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanAttnProcessor
+class WanAttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
+ )
+
+ def __call__(
+ self,
+ attn: "WanAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ encoder_hidden_states_img = None
+ if attn.add_k_proj is not None:
+ # 512 is the context length of the text encoder, hardcoded for now
+ image_context_length = encoder_hidden_states.shape[1] - 512
+ encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
+ encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
+
+ query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+
+ if rotary_emb is not None:
+
+ def apply_rotary_emb(
+ hidden_states: torch.Tensor,
+ freqs_cos: torch.Tensor,
+ freqs_sin: torch.Tensor,
+ ):
+ x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+ cos = freqs_cos[..., 0::2]
+ sin = freqs_sin[..., 1::2]
+ out = torch.empty_like(hidden_states)
+ out[..., 0::2] = x1 * cos - x2 * sin
+ out[..., 1::2] = x1 * sin + x2 * cos
+ return out.type_as(hidden_states)
+
+ query = apply_rotary_emb(query, *rotary_emb)
+ key = apply_rotary_emb(key, *rotary_emb)
+
+ # I2V task
+ hidden_states_img = None
+ if encoder_hidden_states_img is not None:
+ key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
+ key_img = attn.norm_added_k(key_img)
+
+ key_img = key_img.unflatten(2, (attn.heads, -1))
+ value_img = value_img.unflatten(2, (attn.heads, -1))
+
+ hidden_states_img = dispatch_attention_fn(
+ query,
+ key_img,
+ value_img,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states_img = hidden_states_img.flatten(2, 3)
+ hidden_states_img = hidden_states_img.type_as(query)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+
+ if hidden_states_img is not None:
+ hidden_states = hidden_states + hidden_states_img
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanAttention
+class WanAttention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = WanAttnProcessor
+ _available_processors = [WanAttnProcessor]
+
+ def __init__(
+ self,
+ dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ eps: float = 1e-5,
+ dropout: float = 0.0,
+ added_kv_proj_dim: Optional[int] = None,
+ cross_attention_dim_head: Optional[int] = None,
+ processor=None,
+ is_cross_attention=None,
+ ):
+ super().__init__()
+
+ self.inner_dim = dim_head * heads
+ self.heads = heads
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.cross_attention_dim_head = cross_attention_dim_head
+ self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+ self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+ self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_out = torch.nn.ModuleList(
+ [
+ torch.nn.Linear(self.inner_dim, dim, bias=True),
+ torch.nn.Dropout(dropout),
+ ]
+ )
+ self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+ self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+
+ self.add_k_proj = self.add_v_proj = None
+ if added_kv_proj_dim is not None:
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+ self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+ self.is_cross_attention = cross_attention_dim_head is not None
+
+ self.set_processor(processor)
+
+ def fuse_projections(self):
+ if getattr(self, "fused_projections", False):
+ return
+
+ if self.cross_attention_dim_head is None:
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_qkv = nn.Linear(in_features, out_features, bias=True)
+ self.to_qkv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+ else:
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_kv = nn.Linear(in_features, out_features, bias=True)
+ self.to_kv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+
+ if self.added_kv_proj_dim is not None:
+ concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+ concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+ self.to_added_kv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+
+ self.fused_projections = True
+
+ @torch.no_grad()
+ def unfuse_projections(self):
+ if not getattr(self, "fused_projections", False):
+ return
+
+ if hasattr(self, "to_qkv"):
+ delattr(self, "to_qkv")
+ if hasattr(self, "to_kv"):
+ delattr(self, "to_kv")
+ if hasattr(self, "to_added_kv"):
+ delattr(self, "to_added_kv")
+
+ self.fused_projections = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanImageEmbedding
+class WanImageEmbedding(torch.nn.Module):
+ def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
+ super().__init__()
+
+ self.norm1 = FP32LayerNorm(in_features)
+ self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
+ self.norm2 = FP32LayerNorm(out_features)
+ if pos_embed_seq_len is not None:
+ self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
+ else:
+ self.pos_embed = None
+
+ def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
+ if self.pos_embed is not None:
+ batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
+ encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
+ encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
+
+ hidden_states = self.norm1(encoder_hidden_states_image)
+ hidden_states = self.ff(hidden_states)
+ hidden_states = self.norm2(hidden_states)
+ return hidden_states
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding
+class WanTimeTextImageEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ time_freq_dim: int,
+ time_proj_dim: int,
+ text_embed_dim: int,
+ image_embed_dim: Optional[int] = None,
+ pos_embed_seq_len: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
+ self.act_fn = nn.SiLU()
+ self.time_proj = nn.Linear(dim, time_proj_dim)
+ self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
+
+ self.image_embedder = None
+ if image_embed_dim is not None:
+ self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ timestep_seq_len: Optional[int] = None,
+ ):
+ timestep = self.timesteps_proj(timestep)
+ if timestep_seq_len is not None:
+ timestep = timestep.unflatten(0, (-1, timestep_seq_len))
+
+ time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
+ if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+ timestep = timestep.to(time_embedder_dtype)
+ temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+ timestep_proj = self.time_proj(self.act_fn(temb))
+
+ encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
+
+ return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanRotaryPosEmbed
+class WanRotaryPosEmbed(nn.Module):
+ def __init__(
+ self,
+ attention_head_dim: int,
+ patch_size: Tuple[int, int, int],
+ max_seq_len: int,
+ theta: float = 10000.0,
+ ):
+ super().__init__()
+
+ self.attention_head_dim = attention_head_dim
+ self.patch_size = patch_size
+ self.max_seq_len = max_seq_len
+
+ h_dim = w_dim = 2 * (attention_head_dim // 6)
+ t_dim = attention_head_dim - h_dim - w_dim
+
+ self.t_dim = t_dim
+ self.h_dim = h_dim
+ self.w_dim = w_dim
+
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+ freqs_cos = []
+ freqs_sin = []
+
+ for dim in [t_dim, h_dim, w_dim]:
+ freq_cos, freq_sin = get_1d_rotary_pos_embed(
+ dim,
+ max_seq_len,
+ theta,
+ use_real=True,
+ repeat_interleave_real=True,
+ freqs_dtype=freqs_dtype,
+ )
+ freqs_cos.append(freq_cos)
+ freqs_sin.append(freq_sin)
+
+ self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+ self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.patch_size
+ ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
+
+ split_sizes = [self.t_dim, self.h_dim, self.w_dim]
+
+ freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+ freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+ freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+ freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+
+ return freqs_cos, freqs_sin
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanTransformerBlock
+class WanTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ ffn_dim: int,
+ num_heads: int,
+ qk_norm: str = "rms_norm_across_heads",
+ cross_attn_norm: bool = False,
+ eps: float = 1e-6,
+ added_kv_proj_dim: Optional[int] = None,
+ ):
+ super().__init__()
+
+ # 1. Self-attention
+ self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+ self.attn1 = WanAttention(
+ dim=dim,
+ heads=num_heads,
+ dim_head=dim // num_heads,
+ eps=eps,
+ cross_attention_dim_head=None,
+ processor=WanAttnProcessor(),
+ )
+
+ # 2. Cross-attention
+ self.attn2 = WanAttention(
+ dim=dim,
+ heads=num_heads,
+ dim_head=dim // num_heads,
+ eps=eps,
+ added_kv_proj_dim=added_kv_proj_dim,
+ cross_attention_dim_head=dim // num_heads,
+ processor=WanAttnProcessor(),
+ )
+ self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
+
+ # 3. Feed-forward
+ self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
+ self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ rotary_emb: torch.Tensor,
+ ) -> torch.Tensor:
+ if temb.ndim == 4:
+ # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table.unsqueeze(0) + temb.float()
+ ).chunk(6, dim=2)
+ # batch_size, seq_len, 1, inner_dim
+ shift_msa = shift_msa.squeeze(2)
+ scale_msa = scale_msa.squeeze(2)
+ gate_msa = gate_msa.squeeze(2)
+ c_shift_msa = c_shift_msa.squeeze(2)
+ c_scale_msa = c_scale_msa.squeeze(2)
+ c_gate_msa = c_gate_msa.squeeze(2)
+ else:
+ # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table + temb.float()
+ ).chunk(6, dim=1)
+
+ # 1. Self-attention
+ norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
+ attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
+ hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
+
+ # 2. Cross-attention
+ norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
+ hidden_states = hidden_states + attn_output
+
+ # 3. Feed-forward
+ norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
+ hidden_states
+ )
+ ff_output = self.ffn(norm_hidden_states)
+ hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
+
+ return hidden_states
+
+
+class WanAnimateTransformer3DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
+ r"""
+ A Transformer model for video-like data used in the WanAnimate model.
+
+ Args:
+ patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
+ num_attention_heads (`int`, defaults to `40`):
+ Fixed length for text embeddings.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, defaults to `16`):
+ The number of channels in the output.
+ text_dim (`int`, defaults to `512`):
+ Input dimension for text embeddings.
+ freq_dim (`int`, defaults to `256`):
+ Dimension for sinusoidal time embeddings.
+ ffn_dim (`int`, defaults to `13824`):
+ Intermediate dimension in feed-forward network.
+ num_layers (`int`, defaults to `40`):
+ The number of layers of transformer blocks to use.
+ window_size (`Tuple[int]`, defaults to `(-1, -1)`):
+ Window size for local attention (-1 indicates global attention).
+ cross_attn_norm (`bool`, defaults to `True`):
+ Enable cross-attention normalization.
+ qk_norm (`bool`, defaults to `True`):
+ Enable query/key normalization.
+ eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ image_dim (`int`, *optional*, defaults to `1280`):
+ The number of channels to use for the image embedding. If `None`, no projection is used.
+ added_kv_proj_dim (`int`, *optional*, defaults to `5120`):
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
+ _no_split_modules = ["WanTransformerBlock", "MotionEncoderResBlock"]
+ _keep_in_fp32_modules = [
+ "time_embedder",
+ "scale_shift_table",
+ "norm1",
+ "norm2",
+ "norm3",
+ "motion_synthesis_weight",
+ ]
+ _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+ _repeated_blocks = ["WanTransformerBlock"]
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: Tuple[int] = (1, 2, 2),
+ num_attention_heads: int = 40,
+ attention_head_dim: int = 128,
+ in_channels: Optional[int] = 36,
+ latent_channels: Optional[int] = 16,
+ out_channels: Optional[int] = 16,
+ text_dim: int = 4096,
+ freq_dim: int = 256,
+ ffn_dim: int = 13824,
+ num_layers: int = 40,
+ cross_attn_norm: bool = True,
+ qk_norm: Optional[str] = "rms_norm_across_heads",
+ eps: float = 1e-6,
+ image_dim: Optional[int] = 1280,
+ added_kv_proj_dim: Optional[int] = None,
+ rope_max_seq_len: int = 1024,
+ pos_embed_seq_len: Optional[int] = None,
+ motion_encoder_channel_sizes: Optional[Dict[str, int]] = None, # Start of Wan Animate-specific args
+ motion_encoder_size: int = 512,
+ motion_style_dim: int = 512,
+ motion_dim: int = 20,
+ motion_encoder_dim: int = 512,
+ face_encoder_hidden_dim: int = 1024,
+ face_encoder_num_heads: int = 4,
+ inject_face_latents_blocks: int = 5,
+ motion_encoder_batch_size: int = 8,
+ ) -> None:
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+ # Allow either only in_channels or only latent_channels to be set for convenience
+ if in_channels is None and latent_channels is not None:
+ in_channels = 2 * latent_channels + 4
+ elif in_channels is not None and latent_channels is None:
+ latent_channels = (in_channels - 4) // 2
+ elif in_channels is not None and latent_channels is not None:
+ # TODO: should this always be true?
+ assert in_channels == 2 * latent_channels + 4, "in_channels should be 2 * latent_channels + 4"
+ else:
+ raise ValueError("At least one of `in_channels` and `latent_channels` must be supplied.")
+ out_channels = out_channels or latent_channels
+
+ # 1. Patch & position embedding
+ self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
+ self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+ self.pose_patch_embedding = nn.Conv3d(latent_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+
+ # 2. Condition embeddings
+ self.condition_embedder = WanTimeTextImageEmbedding(
+ dim=inner_dim,
+ time_freq_dim=freq_dim,
+ time_proj_dim=inner_dim * 6,
+ text_embed_dim=text_dim,
+ image_embed_dim=image_dim,
+ pos_embed_seq_len=pos_embed_seq_len,
+ )
+
+ # Motion encoder
+ self.motion_encoder = WanAnimateMotionEncoder(
+ size=motion_encoder_size,
+ style_dim=motion_style_dim,
+ motion_dim=motion_dim,
+ out_dim=motion_encoder_dim,
+ channels=motion_encoder_channel_sizes,
+ )
+
+ # Face encoder
+ self.face_encoder = WanAnimateFaceEncoder(
+ in_dim=motion_encoder_dim,
+ out_dim=inner_dim,
+ hidden_dim=face_encoder_hidden_dim,
+ num_heads=face_encoder_num_heads,
+ )
+
+ # 3. Transformer blocks
+ self.blocks = nn.ModuleList(
+ [
+ WanTransformerBlock(
+ dim=inner_dim,
+ ffn_dim=ffn_dim,
+ num_heads=num_attention_heads,
+ qk_norm=qk_norm,
+ cross_attn_norm=cross_attn_norm,
+ eps=eps,
+ added_kv_proj_dim=added_kv_proj_dim,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ self.face_adapter = nn.ModuleList(
+ [
+ WanAnimateFaceBlockCrossAttention(
+ dim=inner_dim,
+ heads=num_attention_heads,
+ dim_head=inner_dim // num_attention_heads,
+ eps=eps,
+ cross_attention_dim_head=inner_dim // num_attention_heads,
+ processor=WanAnimateFaceBlockAttnProcessor(),
+ )
+ for _ in range(num_layers // inject_face_latents_blocks)
+ ]
+ )
+
+ # 4. Output norm & projection
+ self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+ self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ pose_hidden_states: Optional[torch.Tensor] = None,
+ face_pixel_values: Optional[torch.Tensor] = None,
+ motion_encode_batch_size: Optional[int] = None,
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ """
+ Forward pass of Wan2.2-Animate transformer model.
+
+ Args:
+ hidden_states (`torch.Tensor` of shape `(B, 2C + 4, T + 1, H, W)`):
+ Input noisy video latents of shape `(B, 2C + 4, T + 1, H, W)`, where B is the batch size, C is the
+ number of latent channels (16 for Wan VAE), T is the number of latent frames in an inference segment, H
+ is the latent height, and W is the latent width.
+ timestep: (`torch.LongTensor`):
+ The current timestep in the denoising loop.
+ encoder_hidden_states (`torch.Tensor`):
+ Text embeddings from the text encoder (umT5 for Wan Animate).
+ encoder_hidden_states_image (`torch.Tensor`):
+ CLIP visual features of the reference (character) image.
+ pose_hidden_states (`torch.Tensor` of shape `(B, C, T, H, W)`):
+ Pose video latents. TODO: description
+ face_pixel_values (`torch.Tensor` of shape `(B, C', S, H', W')`):
+ Face video in pixel space (not latent space). Typically C' = 3 and H' and W' are the height/width of
+ the face video in pixels. Here S is the inference segment length, usually set to 77.
+ motion_encode_batch_size (`int`, *optional*):
+ The batch size for batched encoding of the face video via the motion encoder. Will default to
+ `self.config.motion_encoder_batch_size` if not set.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return the output as a dict or tuple.
+ """
+
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ # Check that shapes match up
+ if pose_hidden_states is not None and pose_hidden_states.shape[2] + 1 != hidden_states.shape[2]:
+ raise ValueError(
+ f"pose_hidden_states frame dim (dim 2) is {pose_hidden_states.shape[2]} but must be one less than the"
+ f" hidden_states's corresponding frame dim: {hidden_states.shape[2]}"
+ )
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.config.patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+
+ # 1. Rotary position embedding
+ rotary_emb = self.rope(hidden_states)
+
+ # 2. Patch embedding
+ hidden_states = self.patch_embedding(hidden_states)
+ pose_hidden_states = self.pose_patch_embedding(pose_hidden_states)
+ # Add pose embeddings to hidden states
+ hidden_states[:, :, 1:] = hidden_states[:, :, 1:] + pose_hidden_states
+ # Calling contiguous() here is important so that we don't recompile when performing regional compilation
+ hidden_states = hidden_states.flatten(2).transpose(1, 2).contiguous()
+
+ # 3. Condition embeddings (time, text, image)
+ # Wan Animate is based on Wan 2.1 and thus uses Wan 2.1's timestep logic
+ temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+ timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=None
+ )
+
+ # batch_size, 6, inner_dim
+ timestep_proj = timestep_proj.unflatten(1, (6, -1))
+
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+
+ # 4. Get motion features from the face video
+ # Motion vector computation from face pixel values
+ batch_size, channels, num_face_frames, height, width = face_pixel_values.shape
+ # Rearrange from (B, C, T, H, W) to (B*T, C, H, W)
+ face_pixel_values = face_pixel_values.permute(0, 2, 1, 3, 4).reshape(-1, channels, height, width)
+
+ # Extract motion features using motion encoder
+ # Perform batched motion encoder inference to allow trading off inference speed for memory usage
+ motion_encode_batch_size = motion_encode_batch_size or self.config.motion_encoder_batch_size
+ face_batches = torch.split(face_pixel_values, motion_encode_batch_size)
+ motion_vec_batches = []
+ for face_batch in face_batches:
+ motion_vec_batch = self.motion_encoder(face_batch)
+ motion_vec_batches.append(motion_vec_batch)
+ motion_vec = torch.cat(motion_vec_batches)
+ motion_vec = motion_vec.view(batch_size, num_face_frames, -1)
+
+ # Now get face features from the motion vector
+ motion_vec = self.face_encoder(motion_vec)
+
+ # Add padding at the beginning (prepend zeros)
+ pad_face = torch.zeros_like(motion_vec[:, :1])
+ motion_vec = torch.cat([pad_face, motion_vec], dim=1)
+
+ # 5. Transformer blocks with face adapter integration
+ for block_idx, block in enumerate(self.blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
+ )
+ else:
+ hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+
+ # Face adapter integration: apply after every 5th block (0, 5, 10, 15, ...)
+ if block_idx % self.config.inject_face_latents_blocks == 0:
+ face_adapter_block_idx = block_idx // self.config.inject_face_latents_blocks
+ face_adapter_output = self.face_adapter[face_adapter_block_idx](hidden_states, motion_vec)
+ # In case the face adapter and main transformer blocks are on different devices, which can happen when
+ # using model parallelism
+ face_adapter_output = face_adapter_output.to(device=hidden_states.device)
+ hidden_states = face_adapter_output + hidden_states
+
+ # 6. Output norm, projection & unpatchify
+ # batch_size, inner_dim
+ shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
+
+ hidden_states_original_dtype = hidden_states.dtype
+ hidden_states = self.norm_out(hidden_states.float())
+ # Move the shift and scale tensors to the same device as hidden_states.
+ # When using multi-GPU inference via accelerate these will be on the
+ # first device rather than the last device, which hidden_states ends up
+ # on.
+ shift = shift.to(hidden_states.device)
+ scale = scale.to(hidden_states.device)
+ hidden_states = (hidden_states * (1 + scale) + shift).to(dtype=hidden_states_original_dtype)
+
+ hidden_states = self.proj_out(hidden_states)
+
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+ )
+ hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+ output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py
new file mode 100644
index 000000000000..1be4f73e33e2
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_wan_vace.py
@@ -0,0 +1,389 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import AttentionMixin, FeedForward
+from ..cache_utils import CacheMixin
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import FP32LayerNorm
+from .transformer_wan import (
+ WanAttention,
+ WanAttnProcessor,
+ WanRotaryPosEmbed,
+ WanTimeTextImageEmbedding,
+ WanTransformerBlock,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class WanVACETransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ ffn_dim: int,
+ num_heads: int,
+ qk_norm: str = "rms_norm_across_heads",
+ cross_attn_norm: bool = False,
+ eps: float = 1e-6,
+ added_kv_proj_dim: Optional[int] = None,
+ apply_input_projection: bool = False,
+ apply_output_projection: bool = False,
+ ):
+ super().__init__()
+
+ # 1. Input projection
+ self.proj_in = None
+ if apply_input_projection:
+ self.proj_in = nn.Linear(dim, dim)
+
+ # 2. Self-attention
+ self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+ self.attn1 = WanAttention(
+ dim=dim,
+ heads=num_heads,
+ dim_head=dim // num_heads,
+ eps=eps,
+ processor=WanAttnProcessor(),
+ )
+
+ # 3. Cross-attention
+ self.attn2 = WanAttention(
+ dim=dim,
+ heads=num_heads,
+ dim_head=dim // num_heads,
+ eps=eps,
+ added_kv_proj_dim=added_kv_proj_dim,
+ processor=WanAttnProcessor(),
+ )
+ self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
+
+ # 4. Feed-forward
+ self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
+ self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+
+ # 5. Output projection
+ self.proj_out = None
+ if apply_output_projection:
+ self.proj_out = nn.Linear(dim, dim)
+
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ control_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ rotary_emb: torch.Tensor,
+ ) -> torch.Tensor:
+ if self.proj_in is not None:
+ control_hidden_states = self.proj_in(control_hidden_states)
+ control_hidden_states = control_hidden_states + hidden_states
+
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table.to(temb.device) + temb.float()
+ ).chunk(6, dim=1)
+
+ # 1. Self-attention
+ norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(
+ control_hidden_states
+ )
+ attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
+ control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states)
+
+ # 2. Cross-attention
+ norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states)
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
+ control_hidden_states = control_hidden_states + attn_output
+
+ # 3. Feed-forward
+ norm_hidden_states = (self.norm3(control_hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
+ control_hidden_states
+ )
+ ff_output = self.ffn(norm_hidden_states)
+ control_hidden_states = (control_hidden_states.float() + ff_output.float() * c_gate_msa).type_as(
+ control_hidden_states
+ )
+
+ conditioning_states = None
+ if self.proj_out is not None:
+ conditioning_states = self.proj_out(control_hidden_states)
+
+ return conditioning_states, control_hidden_states
+
+
+class WanVACETransformer3DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
+ r"""
+ A Transformer model for video-like data used in the Wan model.
+
+ Args:
+ patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
+ num_attention_heads (`int`, defaults to `40`):
+ Fixed length for text embeddings.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, defaults to `16`):
+ The number of channels in the output.
+ text_dim (`int`, defaults to `512`):
+ Input dimension for text embeddings.
+ freq_dim (`int`, defaults to `256`):
+ Dimension for sinusoidal time embeddings.
+ ffn_dim (`int`, defaults to `13824`):
+ Intermediate dimension in feed-forward network.
+ num_layers (`int`, defaults to `40`):
+ The number of layers of transformer blocks to use.
+ window_size (`Tuple[int]`, defaults to `(-1, -1)`):
+ Window size for local attention (-1 indicates global attention).
+ cross_attn_norm (`bool`, defaults to `True`):
+ Enable cross-attention normalization.
+ qk_norm (`bool`, defaults to `True`):
+ Enable query/key normalization.
+ eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ add_img_emb (`bool`, defaults to `False`):
+ Whether to use img_emb.
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["patch_embedding", "vace_patch_embedding", "condition_embedder", "norm"]
+ _no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock"]
+ _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
+ _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: Tuple[int, ...] = (1, 2, 2),
+ num_attention_heads: int = 40,
+ attention_head_dim: int = 128,
+ in_channels: int = 16,
+ out_channels: int = 16,
+ text_dim: int = 4096,
+ freq_dim: int = 256,
+ ffn_dim: int = 13824,
+ num_layers: int = 40,
+ cross_attn_norm: bool = True,
+ qk_norm: Optional[str] = "rms_norm_across_heads",
+ eps: float = 1e-6,
+ image_dim: Optional[int] = None,
+ added_kv_proj_dim: Optional[int] = None,
+ rope_max_seq_len: int = 1024,
+ pos_embed_seq_len: Optional[int] = None,
+ vace_layers: List[int] = [0, 5, 10, 15, 20, 25, 30, 35],
+ vace_in_channels: int = 96,
+ ) -> None:
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels or in_channels
+
+ if max(vace_layers) >= num_layers:
+ raise ValueError(f"VACE layers {vace_layers} exceed the number of transformer layers {num_layers}.")
+ if 0 not in vace_layers:
+ raise ValueError("VACE layers must include layer 0.")
+
+ # 1. Patch & position embedding
+ self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
+ self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+ self.vace_patch_embedding = nn.Conv3d(vace_in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+
+ # 2. Condition embeddings
+ # image_embedding_dim=1280 for I2V model
+ self.condition_embedder = WanTimeTextImageEmbedding(
+ dim=inner_dim,
+ time_freq_dim=freq_dim,
+ time_proj_dim=inner_dim * 6,
+ text_embed_dim=text_dim,
+ image_embed_dim=image_dim,
+ pos_embed_seq_len=pos_embed_seq_len,
+ )
+
+ # 3. Transformer blocks
+ self.blocks = nn.ModuleList(
+ [
+ WanTransformerBlock(
+ inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ self.vace_blocks = nn.ModuleList(
+ [
+ WanVACETransformerBlock(
+ inner_dim,
+ ffn_dim,
+ num_attention_heads,
+ qk_norm,
+ cross_attn_norm,
+ eps,
+ added_kv_proj_dim,
+ apply_input_projection=i == 0, # Layer 0 always has input projection and is in vace_layers
+ apply_output_projection=True,
+ )
+ for i in range(len(vace_layers))
+ ]
+ )
+
+ # 4. Output norm & projection
+ self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+ self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ control_hidden_states: torch.Tensor = None,
+ control_hidden_states_scale: torch.Tensor = None,
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.config.patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+
+ if control_hidden_states_scale is None:
+ control_hidden_states_scale = control_hidden_states.new_ones(len(self.config.vace_layers))
+ control_hidden_states_scale = torch.unbind(control_hidden_states_scale)
+ if len(control_hidden_states_scale) != len(self.config.vace_layers):
+ raise ValueError(
+ f"Length of `control_hidden_states_scale` {len(control_hidden_states_scale)} should be "
+ f"equal to {len(self.config.vace_layers)}."
+ )
+
+ # 1. Rotary position embedding
+ rotary_emb = self.rope(hidden_states)
+
+ # 2. Patch embedding
+ hidden_states = self.patch_embedding(hidden_states)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+ control_hidden_states = self.vace_patch_embedding(control_hidden_states)
+ control_hidden_states = control_hidden_states.flatten(2).transpose(1, 2)
+ control_hidden_states_padding = control_hidden_states.new_zeros(
+ batch_size, hidden_states.size(1) - control_hidden_states.size(1), control_hidden_states.size(2)
+ )
+ control_hidden_states = torch.cat([control_hidden_states, control_hidden_states_padding], dim=1)
+
+ # 3. Time embedding
+ temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+ timestep, encoder_hidden_states, encoder_hidden_states_image
+ )
+ timestep_proj = timestep_proj.unflatten(1, (6, -1))
+
+ # 4. Image embedding
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+
+ # 5. Transformer blocks
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ # Prepare VACE hints
+ control_hidden_states_list = []
+ for i, block in enumerate(self.vace_blocks):
+ conditioning_states, control_hidden_states = self._gradient_checkpointing_func(
+ block, hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
+ )
+ control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
+ control_hidden_states_list = control_hidden_states_list[::-1]
+
+ for i, block in enumerate(self.blocks):
+ hidden_states = self._gradient_checkpointing_func(
+ block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
+ )
+ if i in self.config.vace_layers:
+ control_hint, scale = control_hidden_states_list.pop()
+ hidden_states = hidden_states + control_hint * scale
+ else:
+ # Prepare VACE hints
+ control_hidden_states_list = []
+ for i, block in enumerate(self.vace_blocks):
+ conditioning_states, control_hidden_states = block(
+ hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
+ )
+ control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
+ control_hidden_states_list = control_hidden_states_list[::-1]
+
+ for i, block in enumerate(self.blocks):
+ hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+ if i in self.config.vace_layers:
+ control_hint, scale = control_hidden_states_list.pop()
+ hidden_states = hidden_states + control_hint * scale
+
+ # 6. Output norm, projection & unpatchify
+ shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
+
+ # Move the shift and scale tensors to the same device as hidden_states.
+ # When using multi-GPU inference via accelerate these will be on the
+ # first device rather than the last device, which hidden_states ends up
+ # on.
+ shift = shift.to(hidden_states.device)
+ scale = scale.to(hidden_states.device)
+
+ hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+ hidden_states = self.proj_out(hidden_states)
+
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+ )
+ hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+ output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py
new file mode 100644
index 000000000000..5c401b9d202b
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_z_image.py
@@ -0,0 +1,653 @@
+# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...models.attention_processor import Attention
+from ...models.modeling_utils import ModelMixin
+from ...models.normalization import RMSNorm
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention_dispatch import dispatch_attention_fn
+from ..modeling_outputs import Transformer2DModelOutput
+
+
+ADALN_EMBED_DIM = 256
+SEQ_MULTI_OF = 32
+
+
+class TimestepEmbedder(nn.Module):
+ def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
+ super().__init__()
+ if mid_size is None:
+ mid_size = out_size
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, mid_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(mid_size, out_size, bias=True),
+ )
+
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ with torch.amp.autocast("cuda", enabled=False):
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
+ )
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ weight_dtype = self.mlp[0].weight.dtype
+ compute_dtype = getattr(self.mlp[0], "compute_dtype", None)
+ if weight_dtype.is_floating_point:
+ t_freq = t_freq.to(weight_dtype)
+ elif compute_dtype is not None:
+ t_freq = t_freq.to(compute_dtype)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class ZSingleStreamAttnProcessor:
+ """
+ Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the
+ original Z-ImageAttention module.
+ """
+
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ freqs_cis: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ query = query.unflatten(-1, (attn.heads, -1))
+ key = key.unflatten(-1, (attn.heads, -1))
+ value = value.unflatten(-1, (attn.heads, -1))
+
+ # Apply Norms
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE
+ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+ with torch.amp.autocast("cuda", enabled=False):
+ x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
+ freqs_cis = freqs_cis.unsqueeze(2)
+ x_out = torch.view_as_real(x * freqs_cis).flatten(3)
+ return x_out.type_as(x_in) # todo
+
+ if freqs_cis is not None:
+ query = apply_rotary_emb(query, freqs_cis)
+ key = apply_rotary_emb(key, freqs_cis)
+
+ # Cast to correct dtype
+ dtype = query.dtype
+ query, key = query.to(dtype), key.to(dtype)
+
+ # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
+ if attention_mask is not None and attention_mask.ndim == 2:
+ attention_mask = attention_mask[:, None, None, :]
+
+ # Compute joint attention
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+
+ # Reshape back
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(dtype)
+
+ output = attn.to_out[0](hidden_states)
+ if len(attn.to_out) > 1: # dropout
+ output = attn.to_out[1](output)
+
+ return output
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim: int, hidden_dim: int):
+ super().__init__()
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+ def _forward_silu_gating(self, x1, x3):
+ return F.silu(x1) * x3
+
+ def forward(self, x):
+ return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
+
+
+@maybe_allow_in_graph
+class ZImageTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ layer_id: int,
+ dim: int,
+ n_heads: int,
+ n_kv_heads: int,
+ norm_eps: float,
+ qk_norm: bool,
+ modulation=True,
+ ):
+ super().__init__()
+ self.dim = dim
+ self.head_dim = dim // n_heads
+
+ # Refactored to use diffusers Attention with custom processor
+ # Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm
+ self.attention = Attention(
+ query_dim=dim,
+ cross_attention_dim=None,
+ dim_head=dim // n_heads,
+ heads=n_heads,
+ qk_norm="rms_norm" if qk_norm else None,
+ eps=1e-5,
+ bias=False,
+ out_bias=False,
+ processor=ZSingleStreamAttnProcessor(),
+ )
+
+ self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
+ self.layer_id = layer_id
+
+ self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
+ self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
+
+ self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
+ self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
+
+ self.modulation = modulation
+ if modulation:
+ self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True))
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ attn_mask: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ adaln_input: Optional[torch.Tensor] = None,
+ ):
+ if self.modulation:
+ assert adaln_input is not None
+ scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
+ gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
+ scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
+
+ # Attention block
+ attn_out = self.attention(
+ self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
+ )
+ x = x + gate_msa * self.attention_norm2(attn_out)
+
+ # FFN block
+ x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
+ else:
+ # Attention block
+ attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)
+ x = x + self.attention_norm2(attn_out)
+
+ # FFN block
+ x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
+
+ return x
+
+
+class FinalLayer(nn.Module):
+ def __init__(self, hidden_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
+
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
+ )
+
+ def forward(self, x, c):
+ scale = 1.0 + self.adaLN_modulation(c)
+ x = self.norm_final(x) * scale.unsqueeze(1)
+ x = self.linear(x)
+ return x
+
+
+class RopeEmbedder:
+ def __init__(
+ self,
+ theta: float = 256.0,
+ axes_dims: List[int] = (16, 56, 56),
+ axes_lens: List[int] = (64, 128, 128),
+ ):
+ self.theta = theta
+ self.axes_dims = axes_dims
+ self.axes_lens = axes_lens
+ assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
+ self.freqs_cis = None
+
+ @staticmethod
+ def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
+ with torch.device("cpu"):
+ freqs_cis = []
+ for i, (d, e) in enumerate(zip(dim, end)):
+ freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
+ timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
+ freqs = torch.outer(timestep, freqs).float()
+ freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
+ freqs_cis.append(freqs_cis_i)
+
+ return freqs_cis
+
+ def __call__(self, ids: torch.Tensor):
+ assert ids.ndim == 2
+ assert ids.shape[-1] == len(self.axes_dims)
+ device = ids.device
+
+ if self.freqs_cis is None:
+ self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
+ self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
+ else:
+ # Ensure freqs_cis are on the same device as ids
+ if self.freqs_cis[0].device != device:
+ self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
+
+ result = []
+ for i in range(len(self.axes_dims)):
+ index = ids[:, i]
+ result.append(self.freqs_cis[i][index])
+ return torch.cat(result, dim=-1)
+
+
+class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["ZImageTransformerBlock"]
+ _repeated_blocks = ["ZImageTransformerBlock"]
+ _skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"] # precision sensitive layers
+
+ @register_to_config
+ def __init__(
+ self,
+ all_patch_size=(2,),
+ all_f_patch_size=(1,),
+ in_channels=16,
+ dim=3840,
+ n_layers=30,
+ n_refiner_layers=2,
+ n_heads=30,
+ n_kv_heads=30,
+ norm_eps=1e-5,
+ qk_norm=True,
+ cap_feat_dim=2560,
+ rope_theta=256.0,
+ t_scale=1000.0,
+ axes_dims=[32, 48, 48],
+ axes_lens=[1024, 512, 512],
+ ) -> None:
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.all_patch_size = all_patch_size
+ self.all_f_patch_size = all_f_patch_size
+ self.dim = dim
+ self.n_heads = n_heads
+
+ self.rope_theta = rope_theta
+ self.t_scale = t_scale
+ self.gradient_checkpointing = False
+
+ assert len(all_patch_size) == len(all_f_patch_size)
+
+ all_x_embedder = {}
+ all_final_layer = {}
+ for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
+ x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
+ all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
+
+ final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)
+ all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer
+
+ self.all_x_embedder = nn.ModuleDict(all_x_embedder)
+ self.all_final_layer = nn.ModuleDict(all_final_layer)
+ self.noise_refiner = nn.ModuleList(
+ [
+ ZImageTransformerBlock(
+ 1000 + layer_id,
+ dim,
+ n_heads,
+ n_kv_heads,
+ norm_eps,
+ qk_norm,
+ modulation=True,
+ )
+ for layer_id in range(n_refiner_layers)
+ ]
+ )
+ self.context_refiner = nn.ModuleList(
+ [
+ ZImageTransformerBlock(
+ layer_id,
+ dim,
+ n_heads,
+ n_kv_heads,
+ norm_eps,
+ qk_norm,
+ modulation=False,
+ )
+ for layer_id in range(n_refiner_layers)
+ ]
+ )
+ self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
+ self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True))
+
+ self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
+ self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
+
+ self.layers = nn.ModuleList(
+ [
+ ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm)
+ for layer_id in range(n_layers)
+ ]
+ )
+ head_dim = dim // n_heads
+ assert head_dim == sum(axes_dims)
+ self.axes_dims = axes_dims
+ self.axes_lens = axes_lens
+
+ self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
+
+ def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:
+ pH = pW = patch_size
+ pF = f_patch_size
+ bsz = len(x)
+ assert len(size) == bsz
+ for i in range(bsz):
+ F, H, W = size[i]
+ ori_len = (F // pF) * (H // pH) * (W // pW)
+ # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
+ x[i] = (
+ x[i][:ori_len]
+ .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
+ .permute(6, 0, 3, 1, 4, 2, 5)
+ .reshape(self.out_channels, F, H, W)
+ )
+ return x
+
+ @staticmethod
+ def create_coordinate_grid(size, start=None, device=None):
+ if start is None:
+ start = (0 for _ in size)
+
+ axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
+ grids = torch.meshgrid(axes, indexing="ij")
+ return torch.stack(grids, dim=-1)
+
+ def patchify_and_embed(
+ self,
+ all_image: List[torch.Tensor],
+ all_cap_feats: List[torch.Tensor],
+ patch_size: int,
+ f_patch_size: int,
+ ):
+ pH = pW = patch_size
+ pF = f_patch_size
+ device = all_image[0].device
+
+ all_image_out = []
+ all_image_size = []
+ all_image_pos_ids = []
+ all_image_pad_mask = []
+ all_cap_pos_ids = []
+ all_cap_pad_mask = []
+ all_cap_feats_out = []
+
+ for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):
+ ### Process Caption
+ cap_ori_len = len(cap_feat)
+ cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
+ # padded position ids
+ cap_padded_pos_ids = self.create_coordinate_grid(
+ size=(cap_ori_len + cap_padding_len, 1, 1),
+ start=(1, 0, 0),
+ device=device,
+ ).flatten(0, 2)
+ all_cap_pos_ids.append(cap_padded_pos_ids)
+ # pad mask
+ cap_pad_mask = torch.cat(
+ [
+ torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
+ torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
+ ],
+ dim=0,
+ )
+ all_cap_pad_mask.append(
+ cap_pad_mask if cap_padding_len > 0 else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device)
+ )
+
+ # padded feature
+ cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0)
+ all_cap_feats_out.append(cap_padded_feat)
+
+ ### Process Image
+ C, F, H, W = image.size()
+ all_image_size.append((F, H, W))
+ F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
+
+ image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
+ # "c f pf h ph w pw -> (f h w) (pf ph pw c)"
+ image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
+
+ image_ori_len = len(image)
+ image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
+
+ image_ori_pos_ids = self.create_coordinate_grid(
+ size=(F_tokens, H_tokens, W_tokens),
+ start=(cap_ori_len + cap_padding_len + 1, 0, 0),
+ device=device,
+ ).flatten(0, 2)
+ image_padded_pos_ids = torch.cat(
+ [
+ image_ori_pos_ids,
+ self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
+ .flatten(0, 2)
+ .repeat(image_padding_len, 1),
+ ],
+ dim=0,
+ )
+ all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids)
+ # pad mask
+ image_pad_mask = torch.cat(
+ [
+ torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
+ torch.ones((image_padding_len,), dtype=torch.bool, device=device),
+ ],
+ dim=0,
+ )
+ all_image_pad_mask.append(
+ image_pad_mask
+ if image_padding_len > 0
+ else torch.zeros((image_ori_len,), dtype=torch.bool, device=device)
+ )
+ # padded feature
+ image_padded_feat = torch.cat(
+ [image, image[-1:].repeat(image_padding_len, 1)],
+ dim=0,
+ )
+ all_image_out.append(image_padded_feat if image_padding_len > 0 else image)
+
+ return (
+ all_image_out,
+ all_cap_feats_out,
+ all_image_size,
+ all_image_pos_ids,
+ all_cap_pos_ids,
+ all_image_pad_mask,
+ all_cap_pad_mask,
+ )
+
+ def forward(
+ self,
+ x: List[torch.Tensor],
+ t,
+ cap_feats: List[torch.Tensor],
+ patch_size=2,
+ f_patch_size=1,
+ return_dict: bool = True,
+ ):
+ assert patch_size in self.all_patch_size
+ assert f_patch_size in self.all_f_patch_size
+
+ bsz = len(x)
+ device = x[0].device
+ t = t * self.t_scale
+ t = self.t_embedder(t)
+
+ (
+ x,
+ cap_feats,
+ x_size,
+ x_pos_ids,
+ cap_pos_ids,
+ x_inner_pad_mask,
+ cap_inner_pad_mask,
+ ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
+
+ # x embed & refine
+ x_item_seqlens = [len(_) for _ in x]
+ assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
+ x_max_item_seqlen = max(x_item_seqlens)
+
+ x = torch.cat(x, dim=0)
+ x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
+
+ # Match t_embedder output dtype to x for layerwise casting compatibility
+ adaln_input = t.type_as(x)
+ x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
+ x = list(x.split(x_item_seqlens, dim=0))
+ x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0))
+
+ x = pad_sequence(x, batch_first=True, padding_value=0.0)
+ x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
+ # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
+ x_freqs_cis = x_freqs_cis[:, : x.shape[1]]
+
+ x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
+ for i, seq_len in enumerate(x_item_seqlens):
+ x_attn_mask[i, :seq_len] = 1
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for layer in self.noise_refiner:
+ x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
+ else:
+ for layer in self.noise_refiner:
+ x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
+
+ # cap embed & refine
+ cap_item_seqlens = [len(_) for _ in cap_feats]
+ cap_max_item_seqlen = max(cap_item_seqlens)
+
+ cap_feats = torch.cat(cap_feats, dim=0)
+ cap_feats = self.cap_embedder(cap_feats)
+ cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
+ cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
+ cap_freqs_cis = list(
+ self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0)
+ )
+
+ cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
+ cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
+ # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
+ cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]]
+
+ cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
+ for i, seq_len in enumerate(cap_item_seqlens):
+ cap_attn_mask[i, :seq_len] = 1
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for layer in self.context_refiner:
+ cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
+ else:
+ for layer in self.context_refiner:
+ cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
+
+ # unified
+ unified = []
+ unified_freqs_cis = []
+ for i in range(bsz):
+ x_len = x_item_seqlens[i]
+ cap_len = cap_item_seqlens[i]
+ unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
+ unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
+ unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
+ assert unified_item_seqlens == [len(_) for _ in unified]
+ unified_max_item_seqlen = max(unified_item_seqlens)
+
+ unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
+ unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
+ unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
+ for i, seq_len in enumerate(unified_item_seqlens):
+ unified_attn_mask[i, :seq_len] = 1
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for layer in self.layers:
+ unified = self._gradient_checkpointing_func(
+ layer, unified, unified_attn_mask, unified_freqs_cis, adaln_input
+ )
+ else:
+ for layer in self.layers:
+ unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)
+
+ unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
+ unified = list(unified.unbind(dim=0))
+ x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
+
+ if not return_dict:
+ return (x,)
+
+ return Transformer2DModelOutput(sample=x)
diff --git a/src/diffusers/models/unets/unet_1d.py b/src/diffusers/models/unets/unet_1d.py
index ce496fd6baf8..a027c553ed06 100644
--- a/src/diffusers/models/unets/unet_1d.py
+++ b/src/diffusers/models/unets/unet_1d.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -82,14 +82,15 @@ def __init__(
out_channels: int = 2,
extra_in_channels: int = 0,
time_embedding_type: str = "fourier",
+ time_embedding_dim: Optional[int] = None,
flip_sin_to_cos: bool = True,
use_timestep_embedding: bool = False,
freq_shift: float = 0.0,
- down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
- up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
- mid_block_type: Tuple[str] = "UNetMidBlock1D",
+ down_block_types: Tuple[str, ...] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
+ up_block_types: Tuple[str, ...] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
+ mid_block_type: str = "UNetMidBlock1D",
out_block_type: str = None,
- block_out_channels: Tuple[int] = (32, 32, 64),
+ block_out_channels: Tuple[int, ...] = (32, 32, 64),
act_fn: str = None,
norm_num_groups: int = 8,
layers_per_block: int = 1,
@@ -100,15 +101,23 @@ def __init__(
# time
if time_embedding_type == "fourier":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+ if time_embed_dim % 2 != 0:
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
self.time_proj = GaussianFourierProjection(
- embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ embedding_size=time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
)
- timestep_input_dim = 2 * block_out_channels[0]
+ timestep_input_dim = time_embed_dim
elif time_embedding_type == "positional":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
self.time_proj = Timesteps(
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
)
timestep_input_dim = block_out_channels[0]
+ else:
+ raise ValueError(
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+ )
if use_timestep_embedding:
time_embed_dim = block_out_channels[0] * 4
diff --git a/src/diffusers/models/unets/unet_1d_blocks.py b/src/diffusers/models/unets/unet_1d_blocks.py
index f08e6070845e..58cbdfd005b6 100644
--- a/src/diffusers/models/unets/unet_1d_blocks.py
+++ b/src/diffusers/models/unets/unet_1d_blocks.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py
index 448ec051a032..2588a9c518bd 100644
--- a/src/diffusers/models/unets/unet_2d.py
+++ b/src/diffusers/models/unets/unet_2d.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py
index e082d524e766..94a9245e567c 100644
--- a/src/diffusers/models/unets/unet_2d_blocks.py
+++ b/src/diffusers/models/unets/unet_2d_blocks.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/models/unets/unet_2d_blocks_flax.py b/src/diffusers/models/unets/unet_2d_blocks_flax.py
index a4585dbc8823..6e6005afdc31 100644
--- a/src/diffusers/models/unets/unet_2d_blocks_flax.py
+++ b/src/diffusers/models/unets/unet_2d_blocks_flax.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,14 +15,18 @@
import flax.linen as nn
import jax.numpy as jnp
+from ...utils import logging
from ..attention_flax import FlaxTransformer2DModel
from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
+logger = logging.get_logger(__name__)
+
+
class FlaxCrossAttnDownBlock2D(nn.Module):
r"""
Cross Attention 2D Downsizing block - original architecture from Unet transformers:
- https://arxiv.org/abs/2103.06104
+ https://huggingface.co/papers/2103.06104
Parameters:
in_channels (:obj:`int`):
@@ -38,7 +42,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
add_downsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add downsampling layer before each final output
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
- enable memory efficient attention https://arxiv.org/abs/2112.05682
+ enable memory efficient attention https://huggingface.co/papers/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
@@ -60,6 +64,11 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
transformer_layers_per_block: int = 1
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
resnets = []
attentions = []
@@ -135,6 +144,11 @@ class FlaxDownBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
resnets = []
for i in range(self.num_layers):
@@ -169,7 +183,7 @@ def __call__(self, hidden_states, temb, deterministic=True):
class FlaxCrossAttnUpBlock2D(nn.Module):
r"""
Cross Attention 2D Upsampling block - original architecture from Unet transformers:
- https://arxiv.org/abs/2103.06104
+ https://huggingface.co/papers/2103.06104
Parameters:
in_channels (:obj:`int`):
@@ -185,7 +199,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
add_upsample (:obj:`bool`, *optional*, defaults to `True`):
Whether to add upsampling layer before each final output
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
- enable memory efficient attention https://arxiv.org/abs/2112.05682
+ enable memory efficient attention https://huggingface.co/papers/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
@@ -208,6 +222,11 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
transformer_layers_per_block: int = 1
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
resnets = []
attentions = []
@@ -288,6 +307,11 @@ class FlaxUpBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
resnets = []
for i in range(self.num_layers):
@@ -324,7 +348,8 @@ def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=T
class FlaxUNetMidBlock2DCrossAttn(nn.Module):
r"""
- Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104
+ Cross Attention 2D Mid-level block - original architecture from Unet transformers:
+ https://huggingface.co/papers/2103.06104
Parameters:
in_channels (:obj:`int`):
@@ -336,7 +361,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
num_attention_heads (:obj:`int`, *optional*, defaults to 1):
Number of attention heads of each spatial transformer block
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
- enable memory efficient attention https://arxiv.org/abs/2112.05682
+ enable memory efficient attention https://huggingface.co/papers/2112.05682
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
@@ -355,6 +380,11 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
transformer_layers_per_block: int = 1
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
# there is always at least one resnet
resnets = [
FlaxResnetBlock2D(
diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py
index 2fd15f6f91e0..0bd135b57358 100644
--- a/src/diffusers/models/unets/unet_2d_condition.py
+++ b/src/diffusers/models/unets/unet_2d_condition.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,7 +16,6 @@
import torch
import torch.nn as nn
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
@@ -30,11 +29,11 @@
unscale_lora_layers,
)
from ..activations import get_activation
+from ..attention import AttentionMixin
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
- AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
FusedAttnProcessor2_0,
@@ -71,11 +70,7 @@ class UNet2DConditionOutput(BaseOutput):
class UNet2DConditionModel(
- ModelMixin,
- ConfigMixin,
- FromOriginalModelMixin,
- UNet2DConditionLoadersMixin,
- PeftAdapterMixin,
+ ModelMixin, AttentionMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
):
r"""
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
@@ -171,8 +166,9 @@ class conditioning with `class_embed_type` equal to `None`.
"""
_supports_gradient_checkpointing = True
- _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
+ _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"]
_skip_layerwise_casting_patterns = ["norm"]
+ _repeated_blocks = ["BasicTransformerBlock"]
@register_to_config
def __init__(
@@ -183,21 +179,21 @@ def __init__(
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
- down_block_types: Tuple[str] = (
+ down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
- up_block_types: Tuple[str] = (
+ up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
@@ -519,10 +515,10 @@ def __init__(
def _check_config(
self,
- down_block_types: Tuple[str],
- up_block_types: Tuple[str],
+ down_block_types: Tuple[str, ...],
+ up_block_types: Tuple[str, ...],
only_cross_attention: Union[bool, Tuple[bool]],
- block_out_channels: Tuple[int],
+ block_out_channels: Tuple[int, ...],
layers_per_block: Union[int, Tuple[int]],
cross_attention_dim: Union[int, Tuple[int]],
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
@@ -773,70 +769,6 @@ def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: i
feature_type=feature_type,
)
- @property
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(
- name: str,
- module: torch.nn.Module,
- processors: Dict[str, AttentionProcessor],
- ):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- def set_attn_processor(
- self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
- ):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
@@ -930,7 +862,7 @@ def fn_recursive_set_attention_slice(
fn_recursive_set_attention_slice(module, reversed_slice_size)
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
- r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+ r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497.
The suffixes after the scaling factors represent the stage blocks where they are being applied.
@@ -969,11 +901,7 @@ def fuse_qkv_projections(self):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -994,11 +922,7 @@ def fuse_qkv_projections(self):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/unets/unet_2d_condition_flax.py b/src/diffusers/models/unets/unet_2d_condition_flax.py
index edbbcbaeda73..8d9a309afbcc 100644
--- a/src/diffusers/models/unets/unet_2d_condition_flax.py
+++ b/src/diffusers/models/unets/unet_2d_condition_flax.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,7 +20,7 @@
from flax.core.frozen_dict import FrozenDict
from ...configuration_utils import ConfigMixin, flax_register_to_config
-from ...utils import BaseOutput
+from ...utils import BaseOutput, logging
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from ..modeling_flax_utils import FlaxModelMixin
from .unet_2d_blocks_flax import (
@@ -32,6 +32,9 @@
)
+logger = logging.get_logger(__name__)
+
+
@flax.struct.dataclass
class FlaxUNet2DConditionOutput(BaseOutput):
"""
@@ -94,7 +97,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
- Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
+ Enable memory efficient attention as described [here](https://huggingface.co/papers/2112.05682).
split_head_dim (`bool`, *optional*, defaults to `False`):
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
@@ -163,6 +166,11 @@ def init_weights(self, rng: jax.Array) -> FrozenDict:
return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]
def setup(self) -> None:
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
block_out_channels = self.block_out_channels
time_embed_dim = block_out_channels[0] * 4
diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py
index 8d7614a23383..53c0f4bae38b 100644
--- a/src/diffusers/models/unets/unet_3d_blocks.py
+++ b/src/diffusers/models/unets/unet_3d_blocks.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py
index a148cf6cbe06..26dc50f84acd 100644
--- a/src/diffusers/models/unets/unet_3d_condition.py
+++ b/src/diffusers/models/unets/unet_3d_condition.py
@@ -1,5 +1,5 @@
-# Copyright 2024 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
-# Copyright 2024 The ModelScope Team.
+# Copyright 2025 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
+# Copyright 2025 The ModelScope Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,17 +18,16 @@
import torch
import torch.nn as nn
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
from ...utils import BaseOutput, logging
from ..activations import get_activation
+from ..attention import AttentionMixin
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
- AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
FusedAttnProcessor2_0,
@@ -59,7 +58,7 @@ class UNet3DConditionOutput(BaseOutput):
sample: torch.Tensor
-class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+class UNet3DConditionModel(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r"""
A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
shaped output.
@@ -286,31 +285,6 @@ def __init__(
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
)
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
r"""
@@ -377,41 +351,6 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
Sets the attention processor to use [feed forward
@@ -470,7 +409,7 @@ def set_default_attn_processor(self):
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1, s2, b1, b2):
- r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+ r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497.
The suffixes after the scaling factors represent the stage blocks where they are being applied.
@@ -508,11 +447,7 @@ def fuse_qkv_projections(self):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -532,11 +467,7 @@ def fuse_qkv_projections(self):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py
index c275e16744f4..0ada264417dd 100644
--- a/src/diffusers/models/unets/unet_i2vgen_xl.py
+++ b/src/diffusers/models/unets/unet_i2vgen_xl.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,17 +16,15 @@
import torch
import torch.nn as nn
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
from ...utils import logging
from ..activations import get_activation
-from ..attention import Attention, FeedForward
+from ..attention import Attention, AttentionMixin, FeedForward
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
- AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
FusedAttnProcessor2_0,
@@ -94,7 +92,7 @@ def forward(
return hidden_states
-class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+class I2VGenXLUNet(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r"""
I2VGenXL UNet. It is a conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and
returns a sample-shaped output.
@@ -154,7 +152,7 @@ def __init__(
# of that, we used `num_attention_heads` for arguments that actually denote attention head dimension. This
# is why we ignore `num_attention_heads` and calculate it from `attention_head_dims` below.
# This is still an incorrect way of calculating `num_attention_heads` but we need to stick to it
- # without running proper depcrecation cycles for the {down,mid,up} blocks which are a
+ # without running proper deprecation cycles for the {down,mid,up} blocks which are a
# part of the public API.
num_attention_heads = attention_head_dim
@@ -314,66 +312,6 @@ def __init__(
self.conv_act = get_activation("silu")
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
@@ -434,7 +372,7 @@ def set_default_attn_processor(self):
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1, s2, b1, b2):
- r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+ r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497.
The suffixes after the scaling factors represent the stage blocks where they are being applied.
@@ -472,11 +410,7 @@ def fuse_qkv_projections(self):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -496,11 +430,7 @@ def fuse_qkv_projections(self):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/unets/unet_kandinsky3.py b/src/diffusers/models/unets/unet_kandinsky3.py
index 73bf0020b481..13f4641a4c50 100644
--- a/src/diffusers/models/unets/unet_kandinsky3.py
+++ b/src/diffusers/models/unets/unet_kandinsky3.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,15 +13,15 @@
# limitations under the License.
from dataclasses import dataclass
-from typing import Dict, Tuple, Union
+from typing import Tuple, Union
import torch
-import torch.utils.checkpoint
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput, logging
-from ..attention_processor import Attention, AttentionProcessor, AttnProcessor
+from ..attention import AttentionMixin
+from ..attention_processor import Attention, AttnProcessor
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
@@ -46,7 +46,7 @@ def forward(self, x):
return x
-class Kandinsky3UNet(ModelMixin, ConfigMixin):
+class Kandinsky3UNet(ModelMixin, AttentionMixin, ConfigMixin):
@register_to_config
def __init__(
self,
@@ -55,7 +55,7 @@ def __init__(
groups: int = 32,
attention_head_dim: int = 64,
layers_per_block: Union[int, Tuple[int]] = 3,
- block_out_channels: Tuple[int] = (384, 768, 1536, 3072),
+ block_out_channels: Tuple[int, ...] = (384, 768, 1536, 3072),
cross_attention_dim: Union[int, Tuple[int]] = 4096,
encoder_hid_dim: int = 4096,
):
@@ -141,64 +141,6 @@ def __init__(
self.conv_act_out = nn.SiLU()
self.conv_out = nn.Conv2d(init_channels, out_channels, kernel_size=3, padding=1)
- @property
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "set_processor"):
- processors[f"{name}.processor"] = module.processor
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py
index bd83024c9b7c..5a93541501d3 100644
--- a/src/diffusers/models/unets/unet_motion_model.py
+++ b/src/diffusers/models/unets/unet_motion_model.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,18 +18,16 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import BaseOutput, deprecate, logging
from ...utils.torch_utils import apply_freeu
-from ..attention import BasicTransformerBlock
+from ..attention import AttentionMixin, BasicTransformerBlock
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
- AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
AttnProcessor2_0,
@@ -1196,7 +1194,7 @@ def forward(self, sample):
pass
-class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
+class UNetMotionModel(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
r"""
A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a
sample shaped output.
@@ -1755,66 +1753,6 @@ def save_motion_modules(
**kwargs,
)
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
"""
Sets the attention processor to use [feed forward
@@ -1873,7 +1811,7 @@ def set_default_attn_processor(self) -> None:
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
- r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+ r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497.
The suffixes after the scaling factors represent the stage blocks where they are being applied.
@@ -1911,11 +1849,7 @@ def fuse_qkv_projections(self):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -1935,11 +1869,7 @@ def fuse_qkv_projections(self):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/unets/unet_spatio_temporal_condition.py b/src/diffusers/models/unets/unet_spatio_temporal_condition.py
index 059a6e807c8e..c0cd5fbdd489 100644
--- a/src/diffusers/models/unets/unet_spatio_temporal_condition.py
+++ b/src/diffusers/models/unets/unet_spatio_temporal_condition.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -7,7 +7,8 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
from ...utils import BaseOutput, logging
-from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
+from ..attention import AttentionMixin
+from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttnProcessor
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
@@ -29,7 +30,7 @@ class UNetSpatioTemporalConditionOutput(BaseOutput):
sample: torch.Tensor = None
-class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+class UNetSpatioTemporalConditionModel(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r"""
A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and
returns a sample shaped output.
@@ -73,25 +74,25 @@ def __init__(
sample_size: Optional[int] = None,
in_channels: int = 8,
out_channels: int = 4,
- down_block_types: Tuple[str] = (
+ down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal",
),
- up_block_types: Tuple[str] = (
+ up_block_types: Tuple[str, ...] = (
"UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
),
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
addition_time_embed_dim: int = 256,
projection_class_embeddings_input_dim: int = 768,
layers_per_block: Union[int, Tuple[int]] = 2,
cross_attention_dim: Union[int, Tuple[int]] = 1024,
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
- num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20),
+ num_attention_heads: Union[int, Tuple[int, ...]] = (5, 10, 20, 20),
num_frames: int = 25,
):
super().__init__()
@@ -245,68 +246,6 @@ def __init__(
padding=1,
)
- @property
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(
- name: str,
- module: torch.nn.Module,
- processors: Dict[str, AttentionProcessor],
- ):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
diff --git a/src/diffusers/models/unets/unet_stable_cascade.py b/src/diffusers/models/unets/unet_stable_cascade.py
index f57754435fdc..23d358c1bf51 100644
--- a/src/diffusers/models/unets/unet_stable_cascade.py
+++ b/src/diffusers/models/unets/unet_stable_cascade.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -145,10 +145,10 @@ def __init__(
timestep_ratio_embedding_dim: int = 64,
patch_size: int = 1,
conditioning_dim: int = 2048,
- block_out_channels: Tuple[int] = (2048, 2048),
- num_attention_heads: Tuple[int] = (32, 32),
- down_num_layers_per_block: Tuple[int] = (8, 24),
- up_num_layers_per_block: Tuple[int] = (24, 8),
+ block_out_channels: Tuple[int, ...] = (2048, 2048),
+ num_attention_heads: Tuple[int, ...] = (32, 32),
+ down_num_layers_per_block: Tuple[int, ...] = (8, 24),
+ up_num_layers_per_block: Tuple[int, ...] = (24, 8),
down_blocks_repeat_mappers: Optional[Tuple[int]] = (
1,
1,
@@ -167,7 +167,7 @@ def __init__(
kernel_size=3,
dropout: Union[float, Tuple[float]] = (0.1, 0.1),
self_attn: Union[bool, Tuple[bool]] = True,
- timestep_conditioning_type: Tuple[str] = ("sca", "crp"),
+ timestep_conditioning_type: Tuple[str, ...] = ("sca", "crp"),
switch_level: Optional[Tuple[bool]] = None,
):
"""
diff --git a/src/diffusers/models/unets/uvit_2d.py b/src/diffusers/models/unets/uvit_2d.py
index 94b39c84f055..4c99ef88ca19 100644
--- a/src/diffusers/models/unets/uvit_2d.py
+++ b/src/diffusers/models/unets/uvit_2d.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, Union
import torch
import torch.nn.functional as F
@@ -22,11 +21,10 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
-from ..attention import BasicTransformerBlock, SkipFFTransformerBlock
+from ..attention import AttentionMixin, BasicTransformerBlock, SkipFFTransformerBlock
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
- AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
@@ -36,7 +34,7 @@
from ..resnet import Downsample2D, Upsample2D
-class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class UVit2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin):
_supports_gradient_checkpointing = True
@register_to_config
@@ -209,66 +207,6 @@ def layer_(*args):
return logits
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py
index af04ae4b93cf..8a47c69f1264 100644
--- a/src/diffusers/models/upsampling.py
+++ b/src/diffusers/models/upsampling.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -358,7 +358,7 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
class CogVideoXUpsample3D(nn.Module):
r"""
- A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase.
+ A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper release.
Args:
in_channels (`int`):
diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py
index 5027f4230e3b..5aad386a89e8 100644
--- a/src/diffusers/models/vae_flax.py
+++ b/src/diffusers/models/vae_flax.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,10 +25,13 @@
from flax.core.frozen_dict import FrozenDict
from ..configuration_utils import ConfigMixin, flax_register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, logging
from .modeling_flax_utils import FlaxModelMixin
+logger = logging.get_logger(__name__)
+
+
@flax.struct.dataclass
class FlaxDecoderOutput(BaseOutput):
"""
@@ -73,6 +76,10 @@ class FlaxUpsample2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
self.conv = nn.Conv(
self.in_channels,
kernel_size=(3, 3),
@@ -107,6 +114,11 @@ class FlaxDownsample2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
self.conv = nn.Conv(
self.in_channels,
kernel_size=(3, 3),
@@ -149,6 +161,11 @@ class FlaxResnetBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
out_channels = self.in_channels if self.out_channels is None else self.out_channels
self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
@@ -221,6 +238,11 @@ class FlaxAttentionBlock(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1
dense = partial(nn.Dense, self.channels, dtype=self.dtype)
@@ -302,6 +324,11 @@ class FlaxDownEncoderBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
resnets = []
for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels
@@ -359,6 +386,11 @@ class FlaxUpDecoderBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
resnets = []
for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels
@@ -413,6 +445,11 @@ class FlaxUNetMidBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)
# there is always at least one resnet
@@ -495,8 +532,8 @@ class FlaxEncoder(nn.Module):
in_channels: int = 3
out_channels: int = 3
- down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
- block_out_channels: Tuple[int] = (64,)
+ down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",)
+ block_out_channels: Tuple[int, ...] = (64,)
layers_per_block: int = 2
norm_num_groups: int = 32
act_fn: str = "silu"
@@ -504,6 +541,11 @@ class FlaxEncoder(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
block_out_channels = self.block_out_channels
# in
self.conv_in = nn.Conv(
@@ -608,14 +650,19 @@ class FlaxDecoder(nn.Module):
in_channels: int = 3
out_channels: int = 3
- up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
- block_out_channels: int = (64,)
+ up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",)
+ block_out_channels: Tuple[int, ...] = (64,)
layers_per_block: int = 2
norm_num_groups: int = 32
act_fn: str = "silu"
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
block_out_channels = self.block_out_channels
# z to block_in
@@ -769,16 +816,16 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
- Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper.
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
The `dtype` of the parameters.
"""
in_channels: int = 3
out_channels: int = 3
- down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
- up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
- block_out_channels: Tuple[int] = (64,)
+ down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",)
+ up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",)
+ block_out_channels: Tuple[int, ...] = (64,)
layers_per_block: int = 1
act_fn: str = "silu"
latent_channels: int = 4
@@ -788,6 +835,11 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
self.encoder = FlaxEncoder(
in_channels=self.config.in_channels,
out_channels=self.config.latent_channels,
diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py
index f946e4634476..63dee5eec554 100644
--- a/src/diffusers/models/vq_model.py
+++ b/src/diffusers/models/vq_model.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py
new file mode 100644
index 000000000000..252b9f33dfe8
--- /dev/null
+++ b/src/diffusers/modular_pipelines/__init__.py
@@ -0,0 +1,104 @@
+from typing import TYPE_CHECKING
+
+from ..utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+ logging,
+)
+
+
+logger = logging.get_logger(__name__)
+logger.warning(
+ "Modular Diffusers is currently an experimental feature under active development. The API is subject to breaking changes in future releases."
+)
+
+# These modules contain pipelines from multiple libraries/frameworks
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ..utils import dummy_pt_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_pt_objects))
+else:
+ _import_structure["modular_pipeline"] = [
+ "ModularPipelineBlocks",
+ "ModularPipeline",
+ "AutoPipelineBlocks",
+ "SequentialPipelineBlocks",
+ "LoopSequentialPipelineBlocks",
+ "PipelineState",
+ "BlockState",
+ ]
+ _import_structure["modular_pipeline_utils"] = [
+ "ComponentSpec",
+ "ConfigSpec",
+ "InputParam",
+ "OutputParam",
+ "InsertableDict",
+ ]
+ _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
+ _import_structure["wan"] = ["WanAutoBlocks", "Wan22AutoBlocks", "WanModularPipeline"]
+ _import_structure["flux"] = [
+ "FluxAutoBlocks",
+ "FluxModularPipeline",
+ "FluxKontextAutoBlocks",
+ "FluxKontextModularPipeline",
+ ]
+ _import_structure["qwenimage"] = [
+ "QwenImageAutoBlocks",
+ "QwenImageModularPipeline",
+ "QwenImageEditModularPipeline",
+ "QwenImageEditAutoBlocks",
+ "QwenImageEditPlusModularPipeline",
+ "QwenImageEditPlusAutoBlocks",
+ ]
+ _import_structure["components_manager"] = ["ComponentsManager"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ..utils.dummy_pt_objects import * # noqa F403
+ else:
+ from .components_manager import ComponentsManager
+ from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
+ from .modular_pipeline import (
+ AutoPipelineBlocks,
+ BlockState,
+ LoopSequentialPipelineBlocks,
+ ModularPipeline,
+ ModularPipelineBlocks,
+ PipelineState,
+ SequentialPipelineBlocks,
+ )
+ from .modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, InsertableDict, OutputParam
+ from .qwenimage import (
+ QwenImageAutoBlocks,
+ QwenImageEditAutoBlocks,
+ QwenImageEditModularPipeline,
+ QwenImageEditPlusAutoBlocks,
+ QwenImageEditPlusModularPipeline,
+ QwenImageModularPipeline,
+ )
+ from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
+ from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py
new file mode 100644
index 000000000000..cb7e8fb73697
--- /dev/null
+++ b/src/diffusers/modular_pipelines/components_manager.py
@@ -0,0 +1,1077 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import time
+from collections import OrderedDict
+from itertools import combinations
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+
+from ..hooks import ModelHook
+from ..utils import (
+ is_accelerate_available,
+ logging,
+)
+from ..utils.torch_utils import get_device
+
+
+if is_accelerate_available():
+ from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+ from accelerate.state import PartialState
+ from accelerate.utils import send_to_device
+ from accelerate.utils.memory import clear_device_cache
+ from accelerate.utils.modeling import convert_file_size_to_int
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class CustomOffloadHook(ModelHook):
+ """
+ A hook that offloads a model on the CPU until its forward pass is called. It ensures the model and its inputs are
+ on the given device. Optionally offloads other models to the CPU before the forward pass is called.
+
+ Args:
+ execution_device(`str`, `int` or `torch.device`, *optional*):
+ The device on which the model should be executed. Will default to the MPS device if it's available, then
+ GPU 0 if there is a GPU, and finally to the CPU.
+ """
+
+ no_grad = False
+
+ def __init__(
+ self,
+ execution_device: Optional[Union[str, int, torch.device]] = None,
+ other_hooks: Optional[List["UserCustomOffloadHook"]] = None,
+ offload_strategy: Optional["AutoOffloadStrategy"] = None,
+ ):
+ self.execution_device = execution_device if execution_device is not None else PartialState().default_device
+ self.other_hooks = other_hooks
+ self.offload_strategy = offload_strategy
+ self.model_id = None
+
+ def set_strategy(self, offload_strategy: "AutoOffloadStrategy"):
+ self.offload_strategy = offload_strategy
+
+ def add_other_hook(self, hook: "UserCustomOffloadHook"):
+ """
+ Add a hook to the list of hooks to consider for offloading.
+ """
+ if self.other_hooks is None:
+ self.other_hooks = []
+ self.other_hooks.append(hook)
+
+ def init_hook(self, module):
+ return module.to("cpu")
+
+ def pre_forward(self, module, *args, **kwargs):
+ if module.device != self.execution_device:
+ if self.other_hooks is not None:
+ hooks_to_offload = [hook for hook in self.other_hooks if hook.model.device == self.execution_device]
+ # offload all other hooks
+ start_time = time.perf_counter()
+ if self.offload_strategy is not None:
+ hooks_to_offload = self.offload_strategy(
+ hooks=hooks_to_offload,
+ model_id=self.model_id,
+ model=module,
+ execution_device=self.execution_device,
+ )
+ end_time = time.perf_counter()
+ logger.info(
+ f" time taken to apply offload strategy for {self.model_id}: {(end_time - start_time):.2f} seconds"
+ )
+
+ for hook in hooks_to_offload:
+ logger.info(
+ f"moving {self.model_id} to {self.execution_device}, offloading {hook.model_id} to cpu"
+ )
+ hook.offload()
+
+ if hooks_to_offload:
+ clear_device_cache()
+ module.to(self.execution_device)
+ return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)
+
+
+class UserCustomOffloadHook:
+ """
+ A simple hook grouping a model and a `CustomOffloadHook`, which provides easy APIs for to call the init method of
+ the hook or remove it entirely.
+ """
+
+ def __init__(self, model_id, model, hook):
+ self.model_id = model_id
+ self.model = model
+ self.hook = hook
+
+ def offload(self):
+ self.hook.init_hook(self.model)
+
+ def attach(self):
+ add_hook_to_module(self.model, self.hook)
+ self.hook.model_id = self.model_id
+
+ def remove(self):
+ remove_hook_from_module(self.model)
+ self.hook.model_id = None
+
+ def add_other_hook(self, hook: "UserCustomOffloadHook"):
+ self.hook.add_other_hook(hook)
+
+
+def custom_offload_with_hook(
+ model_id: str,
+ model: torch.nn.Module,
+ execution_device: Union[str, int, torch.device] = None,
+ offload_strategy: Optional["AutoOffloadStrategy"] = None,
+):
+ hook = CustomOffloadHook(execution_device=execution_device, offload_strategy=offload_strategy)
+ user_hook = UserCustomOffloadHook(model_id=model_id, model=model, hook=hook)
+ user_hook.attach()
+ return user_hook
+
+
+# this is the class that user can customize to implement their own offload strategy
+class AutoOffloadStrategy:
+ """
+ Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based on
+ the available memory on the device.
+ """
+
+ # YiYi TODO: instead of memory_reserve_margin, we should let user set the maximum_total_models_size to keep on device
+ # the actual memory usage would be higher. But it's simpler this way, and can be tested
+ def __init__(self, memory_reserve_margin="3GB"):
+ self.memory_reserve_margin = convert_file_size_to_int(memory_reserve_margin)
+
+ def __call__(self, hooks, model_id, model, execution_device):
+ if len(hooks) == 0:
+ return []
+
+ current_module_size = model.get_memory_footprint()
+
+ device_type = execution_device.type
+ device_module = getattr(torch, device_type, torch.cuda)
+ try:
+ mem_on_device = device_module.mem_get_info(execution_device.index)[0]
+ except AttributeError:
+ raise AttributeError(f"Do not know how to obtain obtain memory info for {str(device_module)}.")
+
+ mem_on_device = mem_on_device - self.memory_reserve_margin
+ if current_module_size < mem_on_device:
+ return []
+
+ min_memory_offload = current_module_size - mem_on_device
+ logger.info(f" search for models to offload in order to free up {min_memory_offload / 1024**3:.2f} GB memory")
+
+ # exlucde models that's not currently loaded on the device
+ module_sizes = dict(
+ sorted(
+ {hook.model_id: hook.model.get_memory_footprint() for hook in hooks}.items(),
+ key=lambda x: x[1],
+ reverse=True,
+ )
+ )
+
+ # YiYi/Dhruv TODO: sort smallest to largest, and offload in that order we would tend to keep the larger models on GPU more often
+ def search_best_candidate(module_sizes, min_memory_offload):
+ """
+ search the optimal combination of models to offload to cpu, given a dictionary of module sizes and a
+ minimum memory offload size. the combination of models should add up to the smallest modulesize that is
+ larger than `min_memory_offload`
+ """
+ model_ids = list(module_sizes.keys())
+ best_candidate = None
+ best_size = float("inf")
+ for r in range(1, len(model_ids) + 1):
+ for candidate_model_ids in combinations(model_ids, r):
+ candidate_size = sum(
+ module_sizes[candidate_model_id] for candidate_model_id in candidate_model_ids
+ )
+ if candidate_size < min_memory_offload:
+ continue
+ else:
+ if best_candidate is None or candidate_size < best_size:
+ best_candidate = candidate_model_ids
+ best_size = candidate_size
+
+ return best_candidate
+
+ best_offload_model_ids = search_best_candidate(module_sizes, min_memory_offload)
+
+ if best_offload_model_ids is None:
+ # if no combination is found, meaning that we cannot meet the memory requirement, offload all models
+ logger.warning("no combination of models to offload to cpu is found, offloading all models")
+ hooks_to_offload = hooks
+ else:
+ hooks_to_offload = [hook for hook in hooks if hook.model_id in best_offload_model_ids]
+
+ return hooks_to_offload
+
+
+# utils for display component info in a readable format
+# TODO: move to a different file
+def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
+ """Summarizes a dictionary by finding common prefixes that share the same value.
+
+ For a dictionary with dot-separated keys like: {
+ 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6],
+ 'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6],
+ 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3],
+ }
+
+ Returns a dictionary where keys are the shortest common prefixes and values are their shared values: {
+ 'down_blocks': [0.6], 'up_blocks': [0.3]
+ }
+ """
+ # First group by values - convert lists to tuples to make them hashable
+ value_to_keys = {}
+ for key, value in d.items():
+ value_tuple = tuple(value) if isinstance(value, list) else value
+ if value_tuple not in value_to_keys:
+ value_to_keys[value_tuple] = []
+ value_to_keys[value_tuple].append(key)
+
+ def find_common_prefix(keys: List[str]) -> str:
+ """Find the shortest common prefix among a list of dot-separated keys."""
+ if not keys:
+ return ""
+ if len(keys) == 1:
+ return keys[0]
+
+ # Split all keys into parts
+ key_parts = [k.split(".") for k in keys]
+
+ # Find how many initial parts are common
+ common_length = 0
+ for parts in zip(*key_parts):
+ if len(set(parts)) == 1: # All parts at this position are the same
+ common_length += 1
+ else:
+ break
+
+ if common_length == 0:
+ return ""
+
+ # Return the common prefix
+ return ".".join(key_parts[0][:common_length])
+
+ # Create summary by finding common prefixes for each value group
+ summary = {}
+ for value_tuple, keys in value_to_keys.items():
+ prefix = find_common_prefix(keys)
+ if prefix: # Only add if we found a common prefix
+ # Convert tuple back to list if it was originally a list
+ value = list(value_tuple) if isinstance(d[keys[0]], list) else value_tuple
+ summary[prefix] = value
+ else:
+ summary[""] = value # Use empty string if no common prefix
+
+ return summary
+
+
+class ComponentsManager:
+ """
+ A central registry and management system for model components across multiple pipelines.
+
+ [`ComponentsManager`] provides a unified way to register, track, and reuse model components (like UNet, VAE, text
+ encoders, etc.) across different modular pipelines. It includes features for duplicate detection, memory
+ management, and component organization.
+
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
+
+ Example:
+ ```python
+ from diffusers import ComponentsManager
+
+ # Create a components manager
+ cm = ComponentsManager()
+
+ # Add components
+ cm.add("unet", unet_model, collection="sdxl")
+ cm.add("vae", vae_model, collection="sdxl")
+
+ # Enable auto offloading
+ cm.enable_auto_cpu_offload()
+
+ # Retrieve components
+ unet = cm.get_one(name="unet", collection="sdxl")
+ ```
+ """
+
+ _available_info_fields = [
+ "model_id",
+ "added_time",
+ "collection",
+ "class_name",
+ "size_gb",
+ "adapters",
+ "has_hook",
+ "execution_device",
+ "ip_adapter",
+ ]
+
+ def __init__(self):
+ self.components = OrderedDict()
+ # YiYi TODO: can remove once confirm we don't need this in mellon
+ self.added_time = OrderedDict() # Store when components were added
+ self.collections = OrderedDict() # collection_name -> set of component_names
+ self.model_hooks = None
+ self._auto_offload_enabled = False
+
+ def _lookup_ids(
+ self,
+ name: Optional[str] = None,
+ collection: Optional[str] = None,
+ load_id: Optional[str] = None,
+ components: Optional[OrderedDict] = None,
+ ):
+ """
+ Lookup component_ids by name, collection, or load_id. Does not support pattern matching. Returns a set of
+ component_ids
+ """
+ if components is None:
+ components = self.components
+
+ if name:
+ ids_by_name = set()
+ for component_id, component in components.items():
+ comp_name = self._id_to_name(component_id)
+ if comp_name == name:
+ ids_by_name.add(component_id)
+ else:
+ ids_by_name = set(components.keys())
+ if collection:
+ ids_by_collection = set()
+ for component_id, component in components.items():
+ if component_id in self.collections[collection]:
+ ids_by_collection.add(component_id)
+ else:
+ ids_by_collection = set(components.keys())
+ if load_id:
+ ids_by_load_id = set()
+ for name, component in components.items():
+ if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id:
+ ids_by_load_id.add(name)
+ else:
+ ids_by_load_id = set(components.keys())
+
+ ids = ids_by_name.intersection(ids_by_collection).intersection(ids_by_load_id)
+ return ids
+
+ @staticmethod
+ def _id_to_name(component_id: str):
+ return "_".join(component_id.split("_")[:-1])
+
+ def add(self, name: str, component: Any, collection: Optional[str] = None):
+ """
+ Add a component to the ComponentsManager.
+
+ Args:
+ name (str): The name of the component
+ component (Any): The component to add
+ collection (Optional[str]): The collection to add the component to
+
+ Returns:
+ str: The unique component ID, which is generated as "{name}_{id(component)}" where
+ id(component) is Python's built-in unique identifier for the object
+ """
+ component_id = f"{name}_{id(component)}"
+ is_new_component = True
+
+ # check for duplicated components
+ for comp_id, comp in self.components.items():
+ if comp == component:
+ comp_name = self._id_to_name(comp_id)
+ if comp_name == name:
+ logger.warning(f"ComponentsManager: component '{name}' already exists as '{comp_id}'")
+ component_id = comp_id
+ is_new_component = False
+ break
+ else:
+ logger.warning(
+ f"ComponentsManager: adding component '{name}' as '{component_id}', but it is duplicate of '{comp_id}'"
+ f"To remove a duplicate, call `components_manager.remove('')`."
+ )
+
+ # check for duplicated load_id and warn (we do not delete for you)
+ if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
+ components_with_same_load_id = self._lookup_ids(load_id=component._diffusers_load_id)
+ components_with_same_load_id = [id for id in components_with_same_load_id if id != component_id]
+
+ if components_with_same_load_id:
+ existing = ", ".join(components_with_same_load_id)
+ logger.warning(
+ f"ComponentsManager: adding component '{component_id}', but it has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. "
+ f"To remove a duplicate, call `components_manager.remove('')`."
+ )
+
+ # add component to components manager
+ self.components[component_id] = component
+ self.added_time[component_id] = time.time()
+
+ if collection:
+ if collection not in self.collections:
+ self.collections[collection] = set()
+ if component_id not in self.collections[collection]:
+ comp_ids_in_collection = self._lookup_ids(name=name, collection=collection)
+ for comp_id in comp_ids_in_collection:
+ logger.warning(
+ f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}"
+ )
+ # remove existing component from this collection (if it is not in any other collection, will be removed from ComponentsManager)
+ self.remove_from_collection(comp_id, collection)
+
+ self.collections[collection].add(component_id)
+ logger.info(
+ f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}"
+ )
+ else:
+ logger.info(f"ComponentsManager: added component '{name}' as '{component_id}'")
+
+ if self._auto_offload_enabled and is_new_component:
+ self.enable_auto_cpu_offload(self._auto_offload_device)
+
+ return component_id
+
+ def remove_from_collection(self, component_id: str, collection: str):
+ """
+ Remove a component from a collection.
+ """
+ if collection not in self.collections:
+ logger.warning(f"Collection '{collection}' not found in ComponentsManager")
+ return
+ if component_id not in self.collections[collection]:
+ logger.warning(f"Component '{component_id}' not found in collection '{collection}'")
+ return
+ # remove from the collection
+ self.collections[collection].remove(component_id)
+ # check if this component is in any other collection
+ comp_colls = [coll for coll, comps in self.collections.items() if component_id in comps]
+ if not comp_colls: # only if no other collection contains this component, remove it
+ logger.warning(f"ComponentsManager: removing component '{component_id}' from ComponentsManager")
+ self.remove(component_id)
+
+ def remove(self, component_id: str = None):
+ """
+ Remove a component from the ComponentsManager.
+
+ Args:
+ component_id (str): The ID of the component to remove
+ """
+ if component_id not in self.components:
+ logger.warning(f"Component '{component_id}' not found in ComponentsManager")
+ return
+
+ component = self.components.pop(component_id)
+ self.added_time.pop(component_id)
+
+ for collection in self.collections:
+ if component_id in self.collections[collection]:
+ self.collections[collection].remove(component_id)
+
+ if self._auto_offload_enabled:
+ self.enable_auto_cpu_offload(self._auto_offload_device)
+ else:
+ if isinstance(component, torch.nn.Module):
+ component.to("cpu")
+ del component
+ import gc
+
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ if torch.xpu.is_available():
+ torch.xpu.empty_cache()
+
+ # YiYi TODO: rename to search_components for now, may remove this method
+ def search_components(
+ self,
+ names: Optional[str] = None,
+ collection: Optional[str] = None,
+ load_id: Optional[str] = None,
+ return_dict_with_names: bool = True,
+ ):
+ """
+ Search components by name with simple pattern matching. Optionally filter by collection or load_id.
+
+ Args:
+ names: Component name(s) or pattern(s)
+ Patterns:
+ - "unet" : match any component with base name "unet" (e.g., unet_123abc)
+ - "!unet" : everything except components with base name "unet"
+ - "unet*" : anything with base name starting with "unet"
+ - "!unet*" : anything with base name NOT starting with "unet"
+ - "*unet*" : anything with base name containing "unet"
+ - "!*unet*" : anything with base name NOT containing "unet"
+ - "refiner|vae|unet" : anything with base name exactly matching "refiner", "vae", or "unet"
+ - "!refiner|vae|unet" : anything with base name NOT exactly matching "refiner", "vae", or "unet"
+ - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae"
+ collection: Optional collection to filter by
+ load_id: Optional load_id to filter by
+ return_dict_with_names:
+ If True, returns a dictionary with component names as keys, throw an error if
+ multiple components with the same name are found If False, returns a dictionary
+ with component IDs as keys
+
+ Returns:
+ Dictionary mapping component names to components if return_dict_with_names=True, or a dictionary mapping
+ component IDs to components if return_dict_with_names=False
+ """
+
+ # select components based on collection and load_id filters
+ selected_ids = self._lookup_ids(collection=collection, load_id=load_id)
+ components = {k: self.components[k] for k in selected_ids}
+
+ def get_return_dict(components, return_dict_with_names):
+ """
+ Create a dictionary mapping component names to components if return_dict_with_names=True, or a dictionary
+ mapping component IDs to components if return_dict_with_names=False, throw an error if duplicate component
+ names are found when return_dict_with_names=True
+ """
+ if return_dict_with_names:
+ dict_to_return = {}
+ for comp_id, comp in components.items():
+ comp_name = self._id_to_name(comp_id)
+ if comp_name in dict_to_return:
+ raise ValueError(
+ f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys"
+ )
+ dict_to_return[comp_name] = comp
+ return dict_to_return
+ else:
+ return components
+
+ # if no names are provided, return the filtered components as it is
+ if names is None:
+ return get_return_dict(components, return_dict_with_names)
+
+ # if names is not a string, raise an error
+ elif not isinstance(names, str):
+ raise ValueError(f"Invalid type for `names: {type(names)}, only support string")
+
+ # Create mapping from component_id to base_name for components to be used for pattern matching
+ base_names = {comp_id: self._id_to_name(comp_id) for comp_id in components.keys()}
+
+ # Helper function to check if a component matches a pattern based on its base name
+ def matches_pattern(component_id, pattern, exact_match=False):
+ """
+ Helper function to check if a component matches a pattern based on its base name.
+
+ Args:
+ component_id: The component ID to check
+ pattern: The pattern to match against
+ exact_match: If True, only exact matches to base_name are considered
+ """
+ base_name = base_names[component_id]
+
+ # Exact match with base name
+ if exact_match:
+ return pattern == base_name
+
+ # Prefix match (ends with *)
+ elif pattern.endswith("*"):
+ prefix = pattern[:-1]
+ return base_name.startswith(prefix)
+
+ # Contains match (starts with *)
+ elif pattern.startswith("*"):
+ search = pattern[1:-1] if pattern.endswith("*") else pattern[1:]
+ return search in base_name
+
+ # Exact match (no wildcards)
+ else:
+ return pattern == base_name
+
+ # Check if this is a "not" pattern
+ is_not_pattern = names.startswith("!")
+ if is_not_pattern:
+ names = names[1:] # Remove the ! prefix
+
+ # Handle OR patterns (containing |)
+ if "|" in names:
+ terms = names.split("|")
+ matches = {}
+
+ for comp_id, comp in components.items():
+ # For OR patterns with exact names (no wildcards), we do exact matching on base names
+ exact_match = all(not (term.startswith("*") or term.endswith("*")) for term in terms)
+
+ # Check if any of the terms match this component
+ should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms)
+
+ # Flip the decision if this is a NOT pattern
+ if is_not_pattern:
+ should_include = not should_include
+
+ if should_include:
+ matches[comp_id] = comp
+
+ log_msg = "NOT " if is_not_pattern else ""
+ match_type = "exactly matching" if exact_match else "matching any of patterns"
+ logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}")
+
+ # Try exact match with a base name
+ elif any(names == base_name for base_name in base_names.values()):
+ # Find all components with this base name
+ matches = {
+ comp_id: comp
+ for comp_id, comp in components.items()
+ if (base_names[comp_id] == names) != is_not_pattern
+ }
+
+ if is_not_pattern:
+ logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}")
+ else:
+ logger.info(f"Getting components with base name '{names}': {list(matches.keys())}")
+
+ # Prefix match (ends with *)
+ elif names.endswith("*"):
+ prefix = names[:-1]
+ matches = {
+ comp_id: comp
+ for comp_id, comp in components.items()
+ if base_names[comp_id].startswith(prefix) != is_not_pattern
+ }
+ if is_not_pattern:
+ logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
+ else:
+ logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}")
+
+ # Contains match (starts with *)
+ elif names.startswith("*"):
+ search = names[1:-1] if names.endswith("*") else names[1:]
+ matches = {
+ comp_id: comp
+ for comp_id, comp in components.items()
+ if (search in base_names[comp_id]) != is_not_pattern
+ }
+ if is_not_pattern:
+ logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
+ else:
+ logger.info(f"Getting components containing '{search}': {list(matches.keys())}")
+
+ # Substring match (no wildcards, but not an exact component name)
+ elif any(names in base_name for base_name in base_names.values()):
+ matches = {
+ comp_id: comp
+ for comp_id, comp in components.items()
+ if (names in base_names[comp_id]) != is_not_pattern
+ }
+ if is_not_pattern:
+ logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}")
+ else:
+ logger.info(f"Getting components containing '{names}': {list(matches.keys())}")
+
+ else:
+ raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager")
+
+ if not matches:
+ raise ValueError(f"No components found matching pattern '{names}'")
+
+ return get_return_dict(matches, return_dict_with_names)
+
+ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = None, memory_reserve_margin="3GB"):
+ """
+ Enable automatic CPU offloading for all components.
+
+ The algorithm works as follows:
+ 1. All models start on CPU by default
+ 2. When a model's forward pass is called, it's moved to the execution device
+ 3. If there's insufficient memory, other models on the device are moved back to CPU
+ 4. The system tries to offload the smallest combination of models that frees enough memory
+ 5. Models stay on the execution device until another model needs memory and forces them off
+
+ Args:
+ device (Union[str, int, torch.device]): The execution device where models are moved for forward passes
+ memory_reserve_margin (str): The memory reserve margin to use, default is 3GB. This is the amount of
+ memory to keep free on the device to avoid running out of memory during model
+ execution (e.g., for intermediate activations, gradients, etc.)
+ """
+ if not is_accelerate_available():
+ raise ImportError("Make sure to install accelerate to use auto_cpu_offload")
+
+ # TODO: add a warning if mem_get_info isn't available on `device`.
+
+ for name, component in self.components.items():
+ if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
+ remove_hook_from_module(component, recurse=True)
+
+ self.disable_auto_cpu_offload()
+ offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
+ if device is None:
+ device = get_device()
+ device = torch.device(device)
+ if device.index is None:
+ device = torch.device(f"{device.type}:{0}")
+ all_hooks = []
+ for name, component in self.components.items():
+ if isinstance(component, torch.nn.Module):
+ hook = custom_offload_with_hook(name, component, device, offload_strategy=offload_strategy)
+ all_hooks.append(hook)
+
+ for hook in all_hooks:
+ other_hooks = [h for h in all_hooks if h is not hook]
+ for other_hook in other_hooks:
+ if other_hook.hook.execution_device == hook.hook.execution_device:
+ hook.add_other_hook(other_hook)
+
+ self.model_hooks = all_hooks
+ self._auto_offload_enabled = True
+ self._auto_offload_device = device
+
+ def disable_auto_cpu_offload(self):
+ """
+ Disable automatic CPU offloading for all components.
+ """
+ if self.model_hooks is None:
+ self._auto_offload_enabled = False
+ return
+
+ for hook in self.model_hooks:
+ hook.offload()
+ hook.remove()
+ if self.model_hooks:
+ clear_device_cache()
+ self.model_hooks = None
+ self._auto_offload_enabled = False
+
+ # YiYi TODO: (1) add quantization info
+ def get_model_info(
+ self,
+ component_id: str,
+ fields: Optional[Union[str, List[str]]] = None,
+ ) -> Optional[Dict[str, Any]]:
+ """Get comprehensive information about a component.
+
+ Args:
+ component_id (str): Name of the component to get info for
+ fields (Optional[Union[str, List[str]]]):
+ Field(s) to return. Can be a string for single field or list of fields. If None, uses the
+ available_info_fields setting.
+
+ Returns:
+ Dictionary containing requested component metadata. If fields is specified, returns only those fields.
+ Otherwise, returns all fields.
+ """
+ if component_id not in self.components:
+ raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
+
+ component = self.components[component_id]
+
+ # Validate fields if specified
+ if fields is not None:
+ if isinstance(fields, str):
+ fields = [fields]
+ for field in fields:
+ if field not in self._available_info_fields:
+ raise ValueError(f"Field '{field}' not found in available_info_fields")
+
+ # Build complete info dict first
+ info = {
+ "model_id": component_id,
+ "added_time": self.added_time[component_id],
+ "collection": ", ".join([coll for coll, comps in self.collections.items() if component_id in comps])
+ or None,
+ }
+
+ # Additional info for torch.nn.Module components
+ if isinstance(component, torch.nn.Module):
+ # Check for hook information
+ has_hook = hasattr(component, "_hf_hook")
+ execution_device = None
+ if has_hook and hasattr(component._hf_hook, "execution_device"):
+ execution_device = component._hf_hook.execution_device
+
+ info.update(
+ {
+ "class_name": component.__class__.__name__,
+ "size_gb": component.get_memory_footprint() / (1024**3),
+ "adapters": None, # Default to None
+ "has_hook": has_hook,
+ "execution_device": execution_device,
+ }
+ )
+
+ # Get adapters if applicable
+ if hasattr(component, "peft_config"):
+ info["adapters"] = list(component.peft_config.keys())
+
+ # Check for IP-Adapter scales
+ if hasattr(component, "_load_ip_adapter_weights") and hasattr(component, "attn_processors"):
+ processors = copy.deepcopy(component.attn_processors)
+ # First check if any processor is an IP-Adapter
+ processor_types = [v.__class__.__name__ for v in processors.values()]
+ if any("IPAdapter" in ptype for ptype in processor_types):
+ # Then get scales only from IP-Adapter processors
+ scales = {
+ k: v.scale
+ for k, v in processors.items()
+ if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__
+ }
+ if scales:
+ info["ip_adapter"] = summarize_dict_by_value_and_parts(scales)
+
+ # If fields specified, filter info
+ if fields is not None:
+ return {k: v for k, v in info.items() if k in fields}
+ else:
+ return info
+
+ # YiYi TODO: (1) add display fields, allow user to set which fields to display in the comnponents table
+ def __repr__(self):
+ # Handle empty components case
+ if not self.components:
+ return "Components:\n" + "=" * 50 + "\nNo components registered.\n" + "=" * 50
+
+ # Extract load_id if available
+ def get_load_id(component):
+ if hasattr(component, "_diffusers_load_id"):
+ return component._diffusers_load_id
+ return "N/A"
+
+ # Format device info compactly
+ def format_device(component, info):
+ if not info["has_hook"]:
+ return str(getattr(component, "device", "N/A"))
+ else:
+ device = str(getattr(component, "device", "N/A"))
+ exec_device = str(info["execution_device"] or "N/A")
+ return f"{device}({exec_device})"
+
+ # Get max length of load_ids for models
+ load_ids = [
+ get_load_id(component)
+ for component in self.components.values()
+ if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id")
+ ]
+ max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15
+
+ # Get all collections for each component
+ component_collections = {}
+ for name in self.components.keys():
+ component_collections[name] = []
+ for coll, comps in self.collections.items():
+ if name in comps:
+ component_collections[name].append(coll)
+ if not component_collections[name]:
+ component_collections[name] = ["N/A"]
+
+ # Find the maximum collection name length
+ all_collections = [coll for colls in component_collections.values() for coll in colls]
+ max_collection_len = max(10, max(len(str(c)) for c in all_collections)) if all_collections else 10
+
+ col_widths = {
+ "id": max(15, max(len(name) for name in self.components.keys())),
+ "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())),
+ "device": 20,
+ "dtype": 15,
+ "size": 10,
+ "load_id": max_load_id_len,
+ "collection": max_collection_len,
+ }
+
+ # Create the header lines
+ sep_line = "=" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"
+ dash_line = "-" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"
+
+ output = "Components:\n" + sep_line
+
+ # Separate components into models and others
+ models = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
+ others = {k: v for k, v in self.components.items() if not isinstance(v, torch.nn.Module)}
+
+ # Models section
+ if models:
+ output += "Models:\n" + dash_line
+ # Column headers
+ output += f"{'Name_ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | "
+ output += f"{'Device: act(exec)':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | "
+ output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n"
+ output += dash_line
+
+ # Model entries
+ for name, component in models.items():
+ info = self.get_model_info(name)
+ device_str = format_device(component, info)
+ dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A"
+ load_id = get_load_id(component)
+
+ # Print first collection on the main line
+ first_collection = component_collections[name][0] if component_collections[name] else "N/A"
+
+ output += f"{name:<{col_widths['id']}} | {info['class_name']:<{col_widths['class']}} | "
+ output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | "
+ output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {first_collection}\n"
+
+ # Print additional collections on separate lines if they exist
+ for i in range(1, len(component_collections[name])):
+ collection = component_collections[name][i]
+ output += f"{'':<{col_widths['id']}} | {'':<{col_widths['class']}} | "
+ output += f"{'':<{col_widths['device']}} | {'':<{col_widths['dtype']}} | "
+ output += f"{'':<{col_widths['size']}} | {'':<{col_widths['load_id']}} | {collection}\n"
+
+ output += dash_line
+
+ # Other components section
+ if others:
+ if models: # Add extra newline if we had models section
+ output += "\n"
+ output += "Other Components:\n" + dash_line
+ # Column headers for other components
+ output += f"{'ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | Collection\n"
+ output += dash_line
+
+ # Other component entries
+ for name, component in others.items():
+ info = self.get_model_info(name)
+
+ # Print first collection on the main line
+ first_collection = component_collections[name][0] if component_collections[name] else "N/A"
+
+ output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}} | {first_collection}\n"
+
+ # Print additional collections on separate lines if they exist
+ for i in range(1, len(component_collections[name])):
+ collection = component_collections[name][i]
+ output += f"{'':<{col_widths['id']}} | {'':<{col_widths['class']}} | {collection}\n"
+
+ output += dash_line
+
+ # Add additional component info
+ output += "\nAdditional Component Info:\n" + "=" * 50 + "\n"
+ for name in self.components:
+ info = self.get_model_info(name)
+ if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")):
+ output += f"\n{name}:\n"
+ if info.get("adapters") is not None:
+ output += f" Adapters: {info['adapters']}\n"
+ if info.get("ip_adapter"):
+ output += " IP-Adapter: Enabled\n"
+
+ return output
+
+ def get_one(
+ self,
+ component_id: Optional[str] = None,
+ name: Optional[str] = None,
+ collection: Optional[str] = None,
+ load_id: Optional[str] = None,
+ ) -> Any:
+ """
+ Get a single component by either:
+ - searching name (pattern matching), collection, or load_id.
+ - passing in a component_id
+ Raises an error if multiple components match or none are found.
+
+ Args:
+ component_id (Optional[str]): Optional component ID to get
+ name (Optional[str]): Component name or pattern
+ collection (Optional[str]): Optional collection to filter by
+ load_id (Optional[str]): Optional load_id to filter by
+
+ Returns:
+ A single component
+
+ Raises:
+ ValueError: If no components match or multiple components match
+ """
+
+ if component_id is not None and (name is not None or collection is not None or load_id is not None):
+ raise ValueError("If searching by component_id, do not pass name, collection, or load_id")
+
+ # search by component_id
+ if component_id is not None:
+ if component_id not in self.components:
+ raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
+ return self.components[component_id]
+ # search with name/collection/load_id
+ results = self.search_components(name, collection, load_id)
+
+ if not results:
+ raise ValueError(f"No components found matching '{name}'")
+
+ if len(results) > 1:
+ raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}")
+
+ return next(iter(results.values()))
+
+ def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str] = None):
+ """
+ Get component IDs by a list of names, optionally filtered by collection.
+
+ Args:
+ names (Union[str, List[str]]): List of component names
+ collection (Optional[str]): Optional collection to filter by
+
+ Returns:
+ List[str]: List of component IDs
+ """
+ ids = set()
+ if not isinstance(names, list):
+ names = [names]
+ for name in names:
+ ids.update(self._lookup_ids(name=name, collection=collection))
+ return list(ids)
+
+ def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional[bool] = True):
+ """
+ Get components by a list of IDs.
+
+ Args:
+ ids (List[str]):
+ List of component IDs
+ return_dict_with_names (Optional[bool]):
+ Whether to return a dictionary with component names as keys:
+
+ Returns:
+ Dict[str, Any]: Dictionary of components.
+ - If return_dict_with_names=True, keys are component names.
+ - If return_dict_with_names=False, keys are component IDs.
+
+ Raises:
+ ValueError: If duplicate component names are found in the search results when return_dict_with_names=True
+ """
+ components = {id: self.components[id] for id in ids}
+
+ if return_dict_with_names:
+ dict_to_return = {}
+ for comp_id, comp in components.items():
+ comp_name = self._id_to_name(comp_id)
+ if comp_name in dict_to_return:
+ raise ValueError(
+ f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys"
+ )
+ dict_to_return[comp_name] = comp
+ return dict_to_return
+ else:
+ return components
+
+ def get_components_by_names(self, names: List[str], collection: Optional[str] = None):
+ """
+ Get components by a list of names, optionally filtered by collection.
+
+ Args:
+ names (List[str]): List of component names
+ collection (Optional[str]): Optional collection to filter by
+
+ Returns:
+ Dict[str, Any]: Dictionary of components with component names as keys
+
+ Raises:
+ ValueError: If duplicate component names are found in the search results
+ """
+ ids = self.get_ids(names, collection)
+ return self.get_components_by_ids(ids)
diff --git a/src/diffusers/modular_pipelines/flux/__init__.py b/src/diffusers/modular_pipelines/flux/__init__.py
new file mode 100644
index 000000000000..ec00986611c8
--- /dev/null
+++ b/src/diffusers/modular_pipelines/flux/__init__.py
@@ -0,0 +1,75 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["encoders"] = ["FluxTextEncoderStep"]
+ _import_structure["modular_blocks"] = [
+ "ALL_BLOCKS",
+ "AUTO_BLOCKS",
+ "AUTO_BLOCKS_KONTEXT",
+ "FLUX_KONTEXT_BLOCKS",
+ "TEXT2IMAGE_BLOCKS",
+ "FluxAutoBeforeDenoiseStep",
+ "FluxAutoBlocks",
+ "FluxAutoDecodeStep",
+ "FluxAutoDenoiseStep",
+ "FluxKontextAutoBlocks",
+ "FluxKontextAutoDenoiseStep",
+ "FluxKontextBeforeDenoiseStep",
+ ]
+ _import_structure["modular_pipeline"] = ["FluxKontextModularPipeline", "FluxModularPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ else:
+ from .encoders import FluxTextEncoderStep
+ from .modular_blocks import (
+ ALL_BLOCKS,
+ AUTO_BLOCKS,
+ AUTO_BLOCKS_KONTEXT,
+ FLUX_KONTEXT_BLOCKS,
+ TEXT2IMAGE_BLOCKS,
+ FluxAutoBeforeDenoiseStep,
+ FluxAutoBlocks,
+ FluxAutoDecodeStep,
+ FluxAutoDenoiseStep,
+ FluxKontextAutoBlocks,
+ FluxKontextAutoDenoiseStep,
+ FluxKontextBeforeDenoiseStep,
+ )
+ from .modular_pipeline import FluxKontextModularPipeline, FluxModularPipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/modular_pipelines/flux/before_denoise.py b/src/diffusers/modular_pipelines/flux/before_denoise.py
new file mode 100644
index 000000000000..daffec986535
--- /dev/null
+++ b/src/diffusers/modular_pipelines/flux/before_denoise.py
@@ -0,0 +1,619 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+
+from ...pipelines import FluxPipeline
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import logging
+from ...utils.torch_utils import randn_tensor
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import FluxModularPipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+def _get_initial_timesteps_and_optionals(
+ transformer,
+ scheduler,
+ batch_size,
+ height,
+ width,
+ vae_scale_factor,
+ num_inference_steps,
+ guidance_scale,
+ sigmas,
+ device,
+):
+ image_seq_len = (int(height) // vae_scale_factor // 2) * (int(width) // vae_scale_factor // 2)
+
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
+ sigmas = None
+ mu = calculate_shift(
+ image_seq_len,
+ scheduler.config.get("base_image_seq_len", 256),
+ scheduler.config.get("max_image_seq_len", 4096),
+ scheduler.config.get("base_shift", 0.5),
+ scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
+ if transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(batch_size)
+ else:
+ guidance = None
+
+ return timesteps, num_inference_steps, sigmas, guidance
+
+
+class FluxSetTimestepsStep(ModularPipelineBlocks):
+ model_name = "flux"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
+
+ @property
+ def description(self) -> str:
+ return "Step that sets the scheduler's timesteps for inference"
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("num_inference_steps", default=50),
+ InputParam("timesteps"),
+ InputParam("sigmas"),
+ InputParam("guidance_scale", default=3.5),
+ InputParam("latents", type_hint=torch.Tensor),
+ InputParam("num_images_per_prompt", default=1),
+ InputParam("height", type_hint=int),
+ InputParam("width", type_hint=int),
+ InputParam(
+ "batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
+ OutputParam(
+ "num_inference_steps",
+ type_hint=int,
+ description="The number of denoising steps to perform at inference time",
+ ),
+ OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ block_state.device = components._execution_device
+
+ scheduler = components.scheduler
+ transformer = components.transformer
+
+ batch_size = block_state.batch_size * block_state.num_images_per_prompt
+ timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals(
+ transformer,
+ scheduler,
+ batch_size,
+ block_state.height,
+ block_state.width,
+ components.vae_scale_factor,
+ block_state.num_inference_steps,
+ block_state.guidance_scale,
+ block_state.sigmas,
+ block_state.device,
+ )
+ block_state.timesteps = timesteps
+ block_state.num_inference_steps = num_inference_steps
+ block_state.sigmas = sigmas
+ block_state.guidance = guidance
+
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
+ components.scheduler.set_begin_index(0)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class FluxImg2ImgSetTimestepsStep(ModularPipelineBlocks):
+ model_name = "flux"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
+
+ @property
+ def description(self) -> str:
+ return "Step that sets the scheduler's timesteps for inference"
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("num_inference_steps", default=50),
+ InputParam("timesteps"),
+ InputParam("sigmas"),
+ InputParam("strength", default=0.6),
+ InputParam("guidance_scale", default=3.5),
+ InputParam("num_images_per_prompt", default=1),
+ InputParam("height", type_hint=int),
+ InputParam("width", type_hint=int),
+ InputParam(
+ "batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
+ OutputParam(
+ "num_inference_steps",
+ type_hint=int,
+ description="The number of denoising steps to perform at inference time",
+ ),
+ OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."),
+ ]
+
+ @staticmethod
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps with self.scheduler->scheduler
+ def get_timesteps(scheduler, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = scheduler.timesteps[t_start * scheduler.order :]
+ if hasattr(scheduler, "set_begin_index"):
+ scheduler.set_begin_index(t_start * scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ block_state.device = components._execution_device
+
+ block_state.height = block_state.height or components.default_height
+ block_state.width = block_state.width or components.default_width
+
+ scheduler = components.scheduler
+ transformer = components.transformer
+ batch_size = block_state.batch_size * block_state.num_images_per_prompt
+ timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals(
+ transformer,
+ scheduler,
+ batch_size,
+ block_state.height,
+ block_state.width,
+ components.vae_scale_factor,
+ block_state.num_inference_steps,
+ block_state.guidance_scale,
+ block_state.sigmas,
+ block_state.device,
+ )
+ timesteps, num_inference_steps = self.get_timesteps(
+ scheduler, num_inference_steps, block_state.strength, block_state.device
+ )
+ block_state.timesteps = timesteps
+ block_state.num_inference_steps = num_inference_steps
+ block_state.sigmas = sigmas
+ block_state.guidance = guidance
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class FluxPrepareLatentsStep(ModularPipelineBlocks):
+ model_name = "flux"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return []
+
+ @property
+ def description(self) -> str:
+ return "Prepare latents step that prepares the latents for the text-to-image generation process"
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("height", type_hint=int),
+ InputParam("width", type_hint=int),
+ InputParam("latents", type_hint=Optional[torch.Tensor]),
+ InputParam("num_images_per_prompt", type_hint=int, default=1),
+ InputParam("generator"),
+ InputParam(
+ "batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
+ ),
+ InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(components, block_state):
+ if (block_state.height is not None and block_state.height % (components.vae_scale_factor * 2) != 0) or (
+ block_state.width is not None and block_state.width % (components.vae_scale_factor * 2) != 0
+ ):
+ logger.warning(
+ f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}."
+ )
+
+ @staticmethod
+ def prepare_latents(
+ comp,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ height = 2 * (int(height) // (comp.vae_scale_factor * 2))
+ width = 2 * (int(width) // (comp.vae_scale_factor * 2))
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ # TODO: move packing latents code to a patchifier similar to Qwen
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = FluxPipeline._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ return latents
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ block_state.height = block_state.height or components.default_height
+ block_state.width = block_state.width or components.default_width
+ block_state.device = components._execution_device
+ block_state.num_channels_latents = components.num_channels_latents
+
+ self.check_inputs(components, block_state)
+ batch_size = block_state.batch_size * block_state.num_images_per_prompt
+ block_state.latents = self.prepare_latents(
+ components,
+ batch_size,
+ block_state.num_channels_latents,
+ block_state.height,
+ block_state.width,
+ block_state.dtype,
+ block_state.device,
+ block_state.generator,
+ block_state.latents,
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class FluxImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
+ model_name = "flux"
+
+ @property
+ def description(self) -> str:
+ return "Step that adds noise to image latents for image-to-image. Should be run after `set_timesteps`,"
+ " `prepare_latents`. Both noise and image latents should already be patchified."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ name="latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial random noised, can be generated in prepare latent step.",
+ ),
+ InputParam(
+ name="image_latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The image latents to use for the denoising process. Can be generated in vae encoder and packed in input step.",
+ ),
+ InputParam(
+ name="timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="initial_noise",
+ type_hint=torch.Tensor,
+ description="The initial random noised used for inpainting denoising.",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(image_latents, latents):
+ if image_latents.shape[0] != latents.shape[0]:
+ raise ValueError(
+ f"`image_latents` must have have same batch size as `latents`, but got {image_latents.shape[0]} and {latents.shape[0]}"
+ )
+
+ if image_latents.ndim != 3:
+ raise ValueError(f"`image_latents` must have 3 dimensions (patchified), but got {image_latents.ndim}")
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(image_latents=block_state.image_latents, latents=block_state.latents)
+
+ # prepare latent timestep
+ latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0])
+
+ # make copy of initial_noise
+ block_state.initial_noise = block_state.latents
+
+ # scale noise
+ block_state.latents = components.scheduler.scale_noise(
+ block_state.image_latents, latent_timestep, block_state.latents
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class FluxRoPEInputsStep(ModularPipelineBlocks):
+ model_name = "flux"
+
+ @property
+ def description(self) -> str:
+ return "Step that prepares the RoPE inputs for the denoising process. Should be placed after text encoder and latent preparation steps."
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="height", required=True),
+ InputParam(name="width", required=True),
+ InputParam(name="prompt_embeds"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="txt_ids",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the prompt embeds, used for RoPE calculation.",
+ ),
+ OutputParam(
+ name="img_ids",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the image latents, used for RoPE calculation.",
+ ),
+ ]
+
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ prompt_embeds = block_state.prompt_embeds
+ device, dtype = prompt_embeds.device, prompt_embeds.dtype
+ block_state.txt_ids = torch.zeros(prompt_embeds.shape[1], 3).to(
+ device=prompt_embeds.device, dtype=prompt_embeds.dtype
+ )
+
+ height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
+ width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
+ block_state.img_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class FluxKontextRoPEInputsStep(ModularPipelineBlocks):
+ model_name = "flux-kontext"
+
+ @property
+ def description(self) -> str:
+ return "Step that prepares the RoPE inputs for the denoising process of Flux Kontext. Should be placed after text encoder and latent preparation steps."
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="image_height"),
+ InputParam(name="image_width"),
+ InputParam(name="height"),
+ InputParam(name="width"),
+ InputParam(name="prompt_embeds"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="txt_ids",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the prompt embeds, used for RoPE calculation.",
+ ),
+ OutputParam(
+ name="img_ids",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the image latents, used for RoPE calculation.",
+ ),
+ ]
+
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ prompt_embeds = block_state.prompt_embeds
+ device, dtype = prompt_embeds.device, prompt_embeds.dtype
+ block_state.txt_ids = torch.zeros(prompt_embeds.shape[1], 3).to(
+ device=prompt_embeds.device, dtype=prompt_embeds.dtype
+ )
+
+ img_ids = None
+ if (
+ getattr(block_state, "image_height", None) is not None
+ and getattr(block_state, "image_width", None) is not None
+ ):
+ image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2))
+ image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2))
+ img_ids = FluxPipeline._prepare_latent_image_ids(
+ None, image_latent_height // 2, image_latent_width // 2, device, dtype
+ )
+ # image ids are the same as latent ids with the first dimension set to 1 instead of 0
+ img_ids[..., 0] = 1
+
+ height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
+ width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
+ latent_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)
+
+ if img_ids is not None:
+ latent_ids = torch.cat([latent_ids, img_ids], dim=0)
+
+ block_state.img_ids = latent_ids
+
+ self.set_block_state(state, block_state)
+
+ return components, state
diff --git a/src/diffusers/modular_pipelines/flux/decoders.py b/src/diffusers/modular_pipelines/flux/decoders.py
new file mode 100644
index 000000000000..846549b1a376
--- /dev/null
+++ b/src/diffusers/modular_pipelines/flux/decoders.py
@@ -0,0 +1,109 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, List, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+
+from ...configuration_utils import FrozenDict
+from ...models import AutoencoderKL
+from ...utils import logging
+from ...video_processor import VaeImageProcessor
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ return latents
+
+
+class FluxDecodeStep(ModularPipelineBlocks):
+ model_name = "flux"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("vae", AutoencoderKL),
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "Step that decodes the denoised latents into images"
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("output_type", default="pil"),
+ InputParam("height", default=1024),
+ InputParam("width", default=1024),
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The denoised latents from the denoising step",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ return [
+ OutputParam(
+ "images",
+ type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
+ description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
+ )
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ vae = components.vae
+
+ if not block_state.output_type == "latent":
+ latents = block_state.latents
+ latents = _unpack_latents(latents, block_state.height, block_state.width, components.vae_scale_factor)
+ latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
+ block_state.images = vae.decode(latents, return_dict=False)[0]
+ block_state.images = components.image_processor.postprocess(
+ block_state.images, output_type=block_state.output_type
+ )
+ else:
+ block_state.images = block_state.latents
+
+ self.set_block_state(state, block_state)
+
+ return components, state
diff --git a/src/diffusers/modular_pipelines/flux/denoise.py b/src/diffusers/modular_pipelines/flux/denoise.py
new file mode 100644
index 000000000000..5a769df1036d
--- /dev/null
+++ b/src/diffusers/modular_pipelines/flux/denoise.py
@@ -0,0 +1,330 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, List, Tuple
+
+import torch
+
+from ...models import FluxTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import logging
+from ..modular_pipeline import (
+ BlockState,
+ LoopSequentialPipelineBlocks,
+ ModularPipelineBlocks,
+ PipelineState,
+)
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import FluxModularPipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class FluxLoopDenoiser(ModularPipelineBlocks):
+ model_name = "flux"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [ComponentSpec("transformer", FluxTransformer2DModel)]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Step within the denoising loop that denoise the latents. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `FluxDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("joint_attention_kwargs"),
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "guidance",
+ required=False,
+ type_hint=torch.Tensor,
+ description="Guidance scale as a tensor",
+ ),
+ InputParam(
+ "prompt_embeds",
+ required=True,
+ type_hint=torch.Tensor,
+ description="Prompt embeddings",
+ ),
+ InputParam(
+ "pooled_prompt_embeds",
+ required=True,
+ type_hint=torch.Tensor,
+ description="Pooled prompt embeddings",
+ ),
+ InputParam(
+ "txt_ids",
+ required=True,
+ type_hint=torch.Tensor,
+ description="IDs computed from text sequence needed for RoPE",
+ ),
+ InputParam(
+ "img_ids",
+ required=True,
+ type_hint=torch.Tensor,
+ description="IDs computed from image sequence needed for RoPE",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(
+ self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
+ ) -> PipelineState:
+ noise_pred = components.transformer(
+ hidden_states=block_state.latents,
+ timestep=t.flatten() / 1000,
+ guidance=block_state.guidance,
+ encoder_hidden_states=block_state.prompt_embeds,
+ pooled_projections=block_state.pooled_prompt_embeds,
+ joint_attention_kwargs=block_state.joint_attention_kwargs,
+ txt_ids=block_state.txt_ids,
+ img_ids=block_state.img_ids,
+ return_dict=False,
+ )[0]
+ block_state.noise_pred = noise_pred
+
+ return components, block_state
+
+
+class FluxKontextLoopDenoiser(ModularPipelineBlocks):
+ model_name = "flux-kontext"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [ComponentSpec("transformer", FluxTransformer2DModel)]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Step within the denoising loop that denoise the latents for Flux Kontext. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `FluxDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("joint_attention_kwargs"),
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "image_latents",
+ type_hint=torch.Tensor,
+ description="Image latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "guidance",
+ required=False,
+ type_hint=torch.Tensor,
+ description="Guidance scale as a tensor",
+ ),
+ InputParam(
+ "prompt_embeds",
+ required=True,
+ type_hint=torch.Tensor,
+ description="Prompt embeddings",
+ ),
+ InputParam(
+ "pooled_prompt_embeds",
+ required=True,
+ type_hint=torch.Tensor,
+ description="Pooled prompt embeddings",
+ ),
+ InputParam(
+ "txt_ids",
+ required=True,
+ type_hint=torch.Tensor,
+ description="IDs computed from text sequence needed for RoPE",
+ ),
+ InputParam(
+ "img_ids",
+ required=True,
+ type_hint=torch.Tensor,
+ description="IDs computed from latent sequence needed for RoPE",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(
+ self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
+ ) -> PipelineState:
+ latents = block_state.latents
+ latent_model_input = latents
+ image_latents = block_state.image_latents
+ if image_latents is not None:
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)
+
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ noise_pred = components.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=block_state.guidance,
+ encoder_hidden_states=block_state.prompt_embeds,
+ pooled_projections=block_state.pooled_prompt_embeds,
+ joint_attention_kwargs=block_state.joint_attention_kwargs,
+ txt_ids=block_state.txt_ids,
+ img_ids=block_state.img_ids,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred[:, : latents.size(1)]
+ block_state.noise_pred = noise_pred
+
+ return components, block_state
+
+
+class FluxLoopAfterDenoiser(ModularPipelineBlocks):
+ model_name = "flux"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that update the latents. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `FluxDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return []
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return [InputParam("generator")]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ # Perform scheduler step using the predicted output
+ latents_dtype = block_state.latents.dtype
+ block_state.latents = components.scheduler.step(
+ block_state.noise_pred,
+ t,
+ block_state.latents,
+ return_dict=False,
+ )[0]
+
+ if block_state.latents.dtype != latents_dtype:
+ block_state.latents = block_state.latents.to(latents_dtype)
+
+ return components, block_state
+
+
+class FluxDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
+ model_name = "flux"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Pipeline block that iteratively denoise the latents over `timesteps`. "
+ "The specific steps with each iteration can be customized with `sub_blocks` attributes"
+ )
+
+ @property
+ def loop_expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
+ ComponentSpec("transformer", FluxTransformer2DModel),
+ ]
+
+ @property
+ def loop_inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.num_warmup_steps = max(
+ len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
+ )
+ with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
+ for i, t in enumerate(block_state.timesteps):
+ components, block_state = self.loop_step(components, block_state, i=i, t=t)
+ if i == len(block_state.timesteps) - 1 or (
+ (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
+ ):
+ progress_bar.update()
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class FluxDenoiseStep(FluxDenoiseLoopWrapper):
+ block_classes = [FluxLoopDenoiser, FluxLoopAfterDenoiser]
+ block_names = ["denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `FluxDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
+ " - `FluxLoopDenoiser`\n"
+ " - `FluxLoopAfterDenoiser`\n"
+ "This block supports both text2image and img2img tasks."
+ )
+
+
+class FluxKontextDenoiseStep(FluxDenoiseLoopWrapper):
+ model_name = "flux-kontext"
+ block_classes = [FluxKontextLoopDenoiser, FluxLoopAfterDenoiser]
+ block_names = ["denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `FluxDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
+ " - `FluxKontextLoopDenoiser`\n"
+ " - `FluxLoopAfterDenoiser`\n"
+ "This block supports both text2image and img2img tasks."
+ )
diff --git a/src/diffusers/modular_pipelines/flux/encoders.py b/src/diffusers/modular_pipelines/flux/encoders.py
new file mode 100644
index 000000000000..f0314d4771b0
--- /dev/null
+++ b/src/diffusers/modular_pipelines/flux/encoders.py
@@ -0,0 +1,483 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from typing import List, Optional, Union
+
+import regex as re
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+
+from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist
+from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL
+from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import FluxModularPipeline
+
+
+if is_ftfy_available():
+ import ftfy
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+def encode_vae_image(vae: AutoencoderKL, image: torch.Tensor, generator: torch.Generator, sample_mode="sample"):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode)
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode)
+
+ image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
+
+ return image_latents
+
+
+class FluxProcessImagesInputStep(ModularPipelineBlocks):
+ model_name = "flux"
+
+ @property
+ def description(self) -> str:
+ return "Image Preprocess step."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [OutputParam(name="processed_image")]
+
+ @staticmethod
+ def check_inputs(height, width, vae_scale_factor):
+ if height is not None and height % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
+
+ if width is not None and width % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ if block_state.resized_image is None and block_state.image is None:
+ raise ValueError("`resized_image` and `image` cannot be None at the same time")
+
+ if block_state.resized_image is None:
+ image = block_state.image
+ self.check_inputs(
+ height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
+ )
+ height = block_state.height or components.default_height
+ width = block_state.width or components.default_width
+ else:
+ width, height = block_state.resized_image[0].size
+ image = block_state.resized_image
+
+ block_state.processed_image = components.image_processor.preprocess(image=image, height=height, width=width)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
+ model_name = "flux-kontext"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Image preprocess step for Flux Kontext. The preprocessed image goes to the VAE.\n"
+ "Kontext works as a T2I model, too, in case no input image is provided."
+ )
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [InputParam("image"), InputParam("_auto_resize", type_hint=bool, default=True)]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [OutputParam(name="processed_image")]
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState):
+ from ...pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS
+
+ block_state = self.get_block_state(state)
+ images = block_state.image
+
+ if images is None:
+ block_state.processed_image = None
+
+ else:
+ multiple_of = components.image_processor.config.vae_scale_factor
+
+ if not is_valid_image_imagelist(images):
+ raise ValueError(f"Images must be image or list of images but are {type(images)}")
+
+ if is_valid_image(images):
+ images = [images]
+
+ img = images[0]
+ image_height, image_width = components.image_processor.get_default_height_width(img)
+ aspect_ratio = image_width / image_height
+ _auto_resize = block_state._auto_resize
+ if _auto_resize:
+ # Kontext is trained on specific resolutions, using one of them is recommended
+ _, image_width, image_height = min(
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+ )
+ image_width = image_width // multiple_of * multiple_of
+ image_height = image_height // multiple_of * multiple_of
+ images = components.image_processor.resize(images, image_height, image_width)
+ block_state.processed_image = components.image_processor.preprocess(images, image_height, image_width)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class FluxVaeEncoderDynamicStep(ModularPipelineBlocks):
+ model_name = "flux"
+
+ def __init__(
+ self, input_name: str = "processed_image", output_name: str = "image_latents", sample_mode: str = "sample"
+ ):
+ """Initialize a VAE encoder step for converting images to latent representations.
+
+ Both the input and output names are configurable so this block can be configured to process to different image
+ inputs (e.g., "processed_image" -> "image_latents", "processed_control_image" -> "control_image_latents").
+
+ Args:
+ input_name (str, optional): Name of the input image tensor. Defaults to "processed_image".
+ Examples: "processed_image" or "processed_control_image"
+ output_name (str, optional): Name of the output latent tensor. Defaults to "image_latents".
+ Examples: "image_latents" or "control_image_latents"
+ sample_mode (str, optional): Sampling mode to be used.
+
+ Examples:
+ # Basic usage with default settings (includes image processor): # FluxImageVaeEncoderDynamicStep()
+
+ # Custom input/output names for control image: # FluxImageVaeEncoderDynamicStep(
+ input_name="processed_control_image", output_name="control_image_latents"
+ )
+ """
+ self._image_input_name = input_name
+ self._image_latents_output_name = output_name
+ self.sample_mode = sample_mode
+ super().__init__()
+
+ @property
+ def description(self) -> str:
+ return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ components = [ComponentSpec("vae", AutoencoderKL)]
+ return components
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ inputs = [InputParam(self._image_input_name), InputParam("generator")]
+ return inputs
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ self._image_latents_output_name,
+ type_hint=torch.Tensor,
+ description="The latents representing the reference image",
+ )
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ image = getattr(block_state, self._image_input_name)
+
+ if image is None:
+ setattr(block_state, self._image_latents_output_name, None)
+ else:
+ device = components._execution_device
+ dtype = components.vae.dtype
+ image = image.to(device=device, dtype=dtype)
+
+ # Encode image into latents
+ image_latents = encode_vae_image(
+ image=image, vae=components.vae, generator=block_state.generator, sample_mode=self.sample_mode
+ )
+ setattr(block_state, self._image_latents_output_name, image_latents)
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class FluxTextEncoderStep(ModularPipelineBlocks):
+ model_name = "flux"
+
+ @property
+ def description(self) -> str:
+ return "Text Encoder step that generate text_embeddings to guide the image generation"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("text_encoder", CLIPTextModel),
+ ComponentSpec("tokenizer", CLIPTokenizer),
+ ComponentSpec("text_encoder_2", T5EncoderModel),
+ ComponentSpec("tokenizer_2", T5TokenizerFast),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("prompt"),
+ InputParam("prompt_2"),
+ InputParam("max_sequence_length", type_hint=int, default=512, required=False),
+ InputParam("joint_attention_kwargs"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "prompt_embeds",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "pooled_prompt_embeds",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="pooled text embeddings used to guide the image generation",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(block_state):
+ for prompt in [block_state.prompt, block_state.prompt_2]:
+ if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` or `prompt_2` has to be of type `str` or `list` but is {type(prompt)}")
+
+ @staticmethod
+ def _get_t5_prompt_embeds(
+ components, prompt: Union[str, List[str]], max_sequence_length: int, device: torch.device
+ ):
+ dtype = components.text_encoder_2.dtype
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if isinstance(components, TextualInversionLoaderMixin):
+ prompt = components.maybe_convert_prompt(prompt, components.tokenizer_2)
+
+ text_inputs = components.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ untruncated_ids = components.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = components.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = components.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ return prompt_embeds
+
+ @staticmethod
+ def _get_clip_prompt_embeds(components, prompt: Union[str, List[str]], device: torch.device):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if isinstance(components, TextualInversionLoaderMixin):
+ prompt = components.maybe_convert_prompt(prompt, components.tokenizer)
+
+ text_inputs = components.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=components.tokenizer.model_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ tokenizer_max_length = components.tokenizer.model_max_length
+ untruncated_ids = components.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = components.tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = components.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=components.text_encoder.dtype, device=device)
+
+ return prompt_embeds
+
+ @staticmethod
+ def encode_prompt(
+ components,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ device = device or components._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(components, FluxLoraLoaderMixin):
+ components._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if components.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(components.text_encoder, lora_scale)
+ if components.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(components.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = FluxTextEncoderStep._get_clip_prompt_embeds(
+ components,
+ prompt=prompt,
+ device=device,
+ )
+ prompt_embeds = FluxTextEncoderStep._get_t5_prompt_embeds(
+ components,
+ prompt=prompt_2,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if components.text_encoder is not None:
+ if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(components.text_encoder, lora_scale)
+
+ if components.text_encoder_2 is not None:
+ if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(components.text_encoder_2, lora_scale)
+
+ return prompt_embeds, pooled_prompt_embeds
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ # Get inputs and intermediates
+ block_state = self.get_block_state(state)
+ self.check_inputs(block_state)
+
+ block_state.device = components._execution_device
+
+ # Encode input prompt
+ block_state.text_encoder_lora_scale = (
+ block_state.joint_attention_kwargs.get("scale", None)
+ if block_state.joint_attention_kwargs is not None
+ else None
+ )
+ block_state.prompt_embeds, block_state.pooled_prompt_embeds = self.encode_prompt(
+ components,
+ prompt=block_state.prompt,
+ prompt_2=None,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ device=block_state.device,
+ max_sequence_length=block_state.max_sequence_length,
+ lora_scale=block_state.text_encoder_lora_scale,
+ )
+
+ # Add outputs
+ self.set_block_state(state, block_state)
+ return components, state
diff --git a/src/diffusers/modular_pipelines/flux/inputs.py b/src/diffusers/modular_pipelines/flux/inputs.py
new file mode 100644
index 000000000000..8309eebfeb37
--- /dev/null
+++ b/src/diffusers/modular_pipelines/flux/inputs.py
@@ -0,0 +1,363 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import torch
+
+from ...pipelines import FluxPipeline
+from ...utils import logging
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import InputParam, OutputParam
+
+# TODO: consider making these common utilities for modular if they are not pipeline-specific.
+from ..qwenimage.inputs import calculate_dimension_from_latents, repeat_tensor_to_batch_size
+from .modular_pipeline import FluxModularPipeline
+
+
+logger = logging.get_logger(__name__)
+
+
+class FluxTextInputStep(ModularPipelineBlocks):
+ model_name = "flux"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Text input processing step that standardizes text embeddings for the pipeline.\n"
+ "This step:\n"
+ " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
+ " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("num_images_per_prompt", default=1),
+ InputParam(
+ "prompt_embeds",
+ required=True,
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="Pre-generated text embeddings. Can be generated from text_encoder step.",
+ ),
+ InputParam(
+ "pooled_prompt_embeds",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.",
+ ),
+ # TODO: support negative embeddings?
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ return [
+ OutputParam(
+ "batch_size",
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
+ ),
+ OutputParam(
+ "dtype",
+ type_hint=torch.dtype,
+ description="Data type of model tensor inputs (determined by `prompt_embeds`)",
+ ),
+ OutputParam(
+ "prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="denoiser_input_fields",
+ description="text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "pooled_prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="denoiser_input_fields",
+ description="pooled text embeddings used to guide the image generation",
+ ),
+ # TODO: support negative embeddings?
+ ]
+
+ def check_inputs(self, components, block_state):
+ if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is not None:
+ if block_state.prompt_embeds.shape[0] != block_state.pooled_prompt_embeds.shape[0]:
+ raise ValueError(
+ "`prompt_embeds` and `pooled_prompt_embeds` must have the same batch size when passed directly, but"
+ f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `pooled_prompt_embeds`"
+ f" {block_state.pooled_prompt_embeds.shape}."
+ )
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ # TODO: consider adding negative embeddings?
+ block_state = self.get_block_state(state)
+ self.check_inputs(components, block_state)
+
+ block_state.batch_size = block_state.prompt_embeds.shape[0]
+ block_state.dtype = block_state.prompt_embeds.dtype
+
+ _, seq_len, _ = block_state.prompt_embeds.shape
+ block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
+ block_state.prompt_embeds = block_state.prompt_embeds.view(
+ block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
+ )
+ pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt)
+ block_state.pooled_prompt_embeds = pooled_prompt_embeds.view(
+ block_state.batch_size * block_state.num_images_per_prompt, -1
+ )
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+# Adapted from `QwenImageInputsDynamicStep`
+class FluxInputsDynamicStep(ModularPipelineBlocks):
+ model_name = "flux"
+
+ def __init__(
+ self,
+ image_latent_inputs: List[str] = ["image_latents"],
+ additional_batch_inputs: List[str] = [],
+ ):
+ if not isinstance(image_latent_inputs, list):
+ image_latent_inputs = [image_latent_inputs]
+ if not isinstance(additional_batch_inputs, list):
+ additional_batch_inputs = [additional_batch_inputs]
+
+ self._image_latent_inputs = image_latent_inputs
+ self._additional_batch_inputs = additional_batch_inputs
+ super().__init__()
+
+ @property
+ def description(self) -> str:
+ # Functionality section
+ summary_section = (
+ "Input processing step that:\n"
+ " 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n"
+ " 2. For additional batch inputs: Expands batch dimensions to match final batch size"
+ )
+
+ # Inputs info
+ inputs_info = ""
+ if self._image_latent_inputs or self._additional_batch_inputs:
+ inputs_info = "\n\nConfigured inputs:"
+ if self._image_latent_inputs:
+ inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
+ if self._additional_batch_inputs:
+ inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
+
+ # Placement guidance
+ placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
+
+ return summary_section + inputs_info + placement_section
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ inputs = [
+ InputParam(name="num_images_per_prompt", default=1),
+ InputParam(name="batch_size", required=True),
+ InputParam(name="height"),
+ InputParam(name="width"),
+ ]
+
+ # Add image latent inputs
+ for image_latent_input_name in self._image_latent_inputs:
+ inputs.append(InputParam(name=image_latent_input_name))
+
+ # Add additional batch inputs
+ for input_name in self._additional_batch_inputs:
+ inputs.append(InputParam(name=input_name))
+
+ return inputs
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(name="image_height", type_hint=int, description="The height of the image latents"),
+ OutputParam(name="image_width", type_hint=int, description="The width of the image latents"),
+ ]
+
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ # Process image latent inputs (height/width calculation, patchify, and batch expansion)
+ for image_latent_input_name in self._image_latent_inputs:
+ image_latent_tensor = getattr(block_state, image_latent_input_name)
+ if image_latent_tensor is None:
+ continue
+
+ # 1. Calculate height/width from latents
+ height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
+ block_state.height = block_state.height or height
+ block_state.width = block_state.width or width
+
+ if not hasattr(block_state, "image_height"):
+ block_state.image_height = height
+ if not hasattr(block_state, "image_width"):
+ block_state.image_width = width
+
+ # 2. Patchify the image latent tensor
+ # TODO: Implement patchifier for Flux.
+ latent_height, latent_width = image_latent_tensor.shape[2:]
+ image_latent_tensor = FluxPipeline._pack_latents(
+ image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width
+ )
+
+ # 3. Expand batch size
+ image_latent_tensor = repeat_tensor_to_batch_size(
+ input_name=image_latent_input_name,
+ input_tensor=image_latent_tensor,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ setattr(block_state, image_latent_input_name, image_latent_tensor)
+
+ # Process additional batch inputs (only batch expansion)
+ for input_name in self._additional_batch_inputs:
+ input_tensor = getattr(block_state, input_name)
+ if input_tensor is None:
+ continue
+
+ # Only expand batch size
+ input_tensor = repeat_tensor_to_batch_size(
+ input_name=input_name,
+ input_tensor=input_tensor,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ setattr(block_state, input_name, input_tensor)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class FluxKontextInputsDynamicStep(FluxInputsDynamicStep):
+ model_name = "flux-kontext"
+
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ # Process image latent inputs (height/width calculation, patchify, and batch expansion)
+ for image_latent_input_name in self._image_latent_inputs:
+ image_latent_tensor = getattr(block_state, image_latent_input_name)
+ if image_latent_tensor is None:
+ continue
+
+ # 1. Calculate height/width from latents
+ # Unlike the `FluxInputsDynamicStep`, we don't overwrite the `block.height` and `block.width`
+ height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
+ if not hasattr(block_state, "image_height"):
+ block_state.image_height = height
+ if not hasattr(block_state, "image_width"):
+ block_state.image_width = width
+
+ # 2. Patchify the image latent tensor
+ # TODO: Implement patchifier for Flux.
+ latent_height, latent_width = image_latent_tensor.shape[2:]
+ image_latent_tensor = FluxPipeline._pack_latents(
+ image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width
+ )
+
+ # 3. Expand batch size
+ image_latent_tensor = repeat_tensor_to_batch_size(
+ input_name=image_latent_input_name,
+ input_tensor=image_latent_tensor,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ setattr(block_state, image_latent_input_name, image_latent_tensor)
+
+ # Process additional batch inputs (only batch expansion)
+ for input_name in self._additional_batch_inputs:
+ input_tensor = getattr(block_state, input_name)
+ if input_tensor is None:
+ continue
+
+ # Only expand batch size
+ input_tensor = repeat_tensor_to_batch_size(
+ input_name=input_name,
+ input_tensor=input_tensor,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ setattr(block_state, input_name, input_tensor)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class FluxKontextSetResolutionStep(ModularPipelineBlocks):
+ model_name = "flux-kontext"
+
+ def description(self):
+ return (
+ "Determines the height and width to be used during the subsequent computations.\n"
+ "It should always be placed _before_ the latent preparation step."
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ inputs = [
+ InputParam(name="height"),
+ InputParam(name="width"),
+ InputParam(name="max_area", type_hint=int, default=1024**2),
+ ]
+ return inputs
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(name="height", type_hint=int, description="The height of the initial noisy latents"),
+ OutputParam(name="width", type_hint=int, description="The width of the initial noisy latents"),
+ ]
+
+ @staticmethod
+ def check_inputs(height, width, vae_scale_factor):
+ if height is not None and height % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
+
+ if width is not None and width % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
+
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ height = block_state.height or components.default_height
+ width = block_state.width or components.default_width
+ self.check_inputs(height, width, components.vae_scale_factor)
+
+ original_height, original_width = height, width
+ max_area = block_state.max_area
+ aspect_ratio = width / height
+ width = round((max_area * aspect_ratio) ** 0.5)
+ height = round((max_area / aspect_ratio) ** 0.5)
+
+ multiple_of = components.vae_scale_factor * 2
+ width = width // multiple_of * multiple_of
+ height = height // multiple_of * multiple_of
+
+ if height != original_height or width != original_width:
+ logger.warning(
+ f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
+ )
+
+ block_state.height = height
+ block_state.width = width
+
+ self.set_block_state(state, block_state)
+ return components, state
diff --git a/src/diffusers/modular_pipelines/flux/modular_blocks.py b/src/diffusers/modular_pipelines/flux/modular_blocks.py
new file mode 100644
index 000000000000..a80bc2a5f7a9
--- /dev/null
+++ b/src/diffusers/modular_pipelines/flux/modular_blocks.py
@@ -0,0 +1,446 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...utils import logging
+from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
+from ..modular_pipeline_utils import InsertableDict
+from .before_denoise import (
+ FluxImg2ImgPrepareLatentsStep,
+ FluxImg2ImgSetTimestepsStep,
+ FluxKontextRoPEInputsStep,
+ FluxPrepareLatentsStep,
+ FluxRoPEInputsStep,
+ FluxSetTimestepsStep,
+)
+from .decoders import FluxDecodeStep
+from .denoise import FluxDenoiseStep, FluxKontextDenoiseStep
+from .encoders import (
+ FluxKontextProcessImagesInputStep,
+ FluxProcessImagesInputStep,
+ FluxTextEncoderStep,
+ FluxVaeEncoderDynamicStep,
+)
+from .inputs import (
+ FluxInputsDynamicStep,
+ FluxKontextInputsDynamicStep,
+ FluxKontextSetResolutionStep,
+ FluxTextInputStep,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# vae encoder (run before before_denoise)
+FluxImg2ImgVaeEncoderBlocks = InsertableDict(
+ [("preprocess", FluxProcessImagesInputStep()), ("encode", FluxVaeEncoderDynamicStep())]
+)
+
+
+class FluxImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "flux"
+
+ block_classes = FluxImg2ImgVaeEncoderBlocks.values()
+ block_names = FluxImg2ImgVaeEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+ return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
+
+
+class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
+ block_classes = [FluxImg2ImgVaeEncoderStep]
+ block_names = ["img2img"]
+ block_trigger_inputs = ["image"]
+
+ @property
+ def description(self):
+ return (
+ "Vae encoder step that encode the image inputs into their latent representations.\n"
+ + "This is an auto pipeline block that works for img2img tasks.\n"
+ + " - `FluxImg2ImgVaeEncoderStep` (img2img) is used when only `image` is provided."
+ + " - if `image` is not provided, step will be skipped."
+ )
+
+
+# Flux Kontext vae encoder (run before before_denoise)
+
+FluxKontextVaeEncoderBlocks = InsertableDict(
+ [("preprocess", FluxKontextProcessImagesInputStep()), ("encode", FluxVaeEncoderDynamicStep(sample_mode="argmax"))]
+)
+
+
+class FluxKontextVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "flux-kontext"
+
+ block_classes = FluxKontextVaeEncoderBlocks.values()
+ block_names = FluxKontextVaeEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+ return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
+
+
+class FluxKontextAutoVaeEncoderStep(AutoPipelineBlocks):
+ block_classes = [FluxKontextVaeEncoderStep]
+ block_names = ["img2img"]
+ block_trigger_inputs = ["image"]
+
+ @property
+ def description(self):
+ return (
+ "Vae encoder step that encode the image inputs into their latent representations.\n"
+ + "This is an auto pipeline block that works for img2img tasks.\n"
+ + " - `FluxKontextVaeEncoderStep` (img2img) is used when only `image` is provided."
+ + " - if `image` is not provided, step will be skipped."
+ )
+
+
+# before_denoise: text2img
+FluxBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", FluxPrepareLatentsStep()),
+ ("set_timesteps", FluxSetTimestepsStep()),
+ ("prepare_rope_inputs", FluxRoPEInputsStep()),
+ ]
+)
+
+
+class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
+ block_classes = FluxBeforeDenoiseBlocks.values()
+ block_names = FluxBeforeDenoiseBlocks.keys()
+
+ @property
+ def description(self):
+ return "Before denoise step that prepares the inputs for the denoise step in text-to-image generation."
+
+
+# before_denoise: img2img
+FluxImg2ImgBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", FluxPrepareLatentsStep()),
+ ("set_timesteps", FluxImg2ImgSetTimestepsStep()),
+ ("prepare_img2img_latents", FluxImg2ImgPrepareLatentsStep()),
+ ("prepare_rope_inputs", FluxRoPEInputsStep()),
+ ]
+)
+
+
+class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
+ block_classes = FluxImg2ImgBeforeDenoiseBlocks.values()
+ block_names = FluxImg2ImgBeforeDenoiseBlocks.keys()
+
+ @property
+ def description(self):
+ return "Before denoise step that prepare the inputs for the denoise step for img2img task."
+
+
+# before_denoise: all task (text2img, img2img)
+class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
+ model_name = "flux-kontext"
+ block_classes = [FluxImg2ImgBeforeDenoiseStep, FluxBeforeDenoiseStep]
+ block_names = ["img2img", "text2image"]
+ block_trigger_inputs = ["image_latents", None]
+
+ @property
+ def description(self):
+ return (
+ "Before denoise step that prepare the inputs for the denoise step.\n"
+ + "This is an auto pipeline block that works for text2image.\n"
+ + " - `FluxBeforeDenoiseStep` (text2image) is used.\n"
+ + " - `FluxImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
+ )
+
+
+# before_denoise: FluxKontext
+
+FluxKontextBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", FluxPrepareLatentsStep()),
+ ("set_timesteps", FluxSetTimestepsStep()),
+ ("prepare_rope_inputs", FluxKontextRoPEInputsStep()),
+ ]
+)
+
+
+class FluxKontextBeforeDenoiseStep(SequentialPipelineBlocks):
+ block_classes = FluxKontextBeforeDenoiseBlocks.values()
+ block_names = FluxKontextBeforeDenoiseBlocks.keys()
+
+ @property
+ def description(self):
+ return (
+ "Before denoise step that prepare the inputs for the denoise step\n"
+ "for img2img/text2img task for Flux Kontext."
+ )
+
+
+class FluxKontextAutoBeforeDenoiseStep(AutoPipelineBlocks):
+ block_classes = [FluxKontextBeforeDenoiseStep, FluxBeforeDenoiseStep]
+ block_names = ["img2img", "text2image"]
+ block_trigger_inputs = ["image_latents", None]
+
+ @property
+ def description(self):
+ return (
+ "Before denoise step that prepare the inputs for the denoise step.\n"
+ + "This is an auto pipeline block that works for text2image.\n"
+ + " - `FluxBeforeDenoiseStep` (text2image) is used.\n"
+ + " - `FluxKontextBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
+ )
+
+
+# denoise: text2image
+class FluxAutoDenoiseStep(AutoPipelineBlocks):
+ block_classes = [FluxDenoiseStep]
+ block_names = ["denoise"]
+ block_trigger_inputs = [None]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. "
+ "This is a auto pipeline block that works for text2image and img2img tasks."
+ " - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
+ )
+
+
+# denoise: Flux Kontext
+
+
+class FluxKontextAutoDenoiseStep(AutoPipelineBlocks):
+ block_classes = [FluxKontextDenoiseStep]
+ block_names = ["denoise"]
+ block_trigger_inputs = [None]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents for Flux Kontext. "
+ "This is a auto pipeline block that works for text2image and img2img tasks."
+ " - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
+ )
+
+
+# decode: all task (text2img, img2img)
+class FluxAutoDecodeStep(AutoPipelineBlocks):
+ block_classes = [FluxDecodeStep]
+ block_names = ["non-inpaint"]
+ block_trigger_inputs = [None]
+
+ @property
+ def description(self):
+ return "Decode step that decode the denoised latents into image outputs.\n - `FluxDecodeStep`"
+
+
+# inputs: text2image/img2img
+FluxImg2ImgBlocks = InsertableDict(
+ [("text_inputs", FluxTextInputStep()), ("additional_inputs", FluxInputsDynamicStep())]
+)
+
+
+class FluxImg2ImgInputStep(SequentialPipelineBlocks):
+ model_name = "flux"
+ block_classes = FluxImg2ImgBlocks.values()
+ block_names = FluxImg2ImgBlocks.keys()
+
+ @property
+ def description(self):
+ return "Input step that prepares the inputs for the img2img denoising step. It:\n"
+ " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
+ " - update height/width based `image_latents`, patchify `image_latents`."
+
+
+class FluxAutoInputStep(AutoPipelineBlocks):
+ block_classes = [FluxImg2ImgInputStep, FluxTextInputStep]
+ block_names = ["img2img", "text2image"]
+ block_trigger_inputs = ["image_latents", None]
+
+ @property
+ def description(self):
+ return (
+ "Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n"
+ " This is an auto pipeline block that works for text2image/img2img tasks.\n"
+ + " - `FluxImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n"
+ + " - `FluxTextInputStep` (text2image) is used when `image_latents` are not provided.\n"
+ )
+
+
+# inputs: Flux Kontext
+
+FluxKontextBlocks = InsertableDict(
+ [
+ ("set_resolution", FluxKontextSetResolutionStep()),
+ ("text_inputs", FluxTextInputStep()),
+ ("additional_inputs", FluxKontextInputsDynamicStep()),
+ ]
+)
+
+
+class FluxKontextInputStep(SequentialPipelineBlocks):
+ model_name = "flux-kontext"
+ block_classes = FluxKontextBlocks.values()
+ block_names = FluxKontextBlocks.keys()
+
+ @property
+ def description(self):
+ return (
+ "Input step that prepares the inputs for the both text2img and img2img denoising step. It:\n"
+ " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
+ " - update height/width based `image_latents`, patchify `image_latents`."
+ )
+
+
+class FluxKontextAutoInputStep(AutoPipelineBlocks):
+ block_classes = [FluxKontextInputStep, FluxTextInputStep]
+ # block_classes = [FluxKontextInputStep]
+ block_names = ["img2img", "text2img"]
+ # block_names = ["img2img"]
+ block_trigger_inputs = ["image_latents", None]
+ # block_trigger_inputs = ["image_latents"]
+
+ @property
+ def description(self):
+ return (
+ "Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n"
+ " This is an auto pipeline block that works for text2image/img2img tasks.\n"
+ + " - `FluxKontextInputStep` (img2img) is used when `image_latents` is provided.\n"
+ + " - `FluxKontextInputStep` is also capable of handling text2image task when `image_latent` isn't present."
+ )
+
+
+class FluxCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "flux"
+ block_classes = [FluxAutoInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
+ block_names = ["input", "before_denoise", "denoise"]
+
+ @property
+ def description(self):
+ return (
+ "Core step that performs the denoising process. \n"
+ + " - `FluxAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ + " - `FluxAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ + " - `FluxAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ + "This step supports text-to-image and image-to-image tasks for Flux:\n"
+ + " - for image-to-image generation, you need to provide `image_latents`\n"
+ + " - for text-to-image generation, all you need to provide is prompt embeddings."
+ )
+
+
+class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "flux-kontext"
+ block_classes = [FluxKontextAutoInputStep, FluxKontextAutoBeforeDenoiseStep, FluxKontextAutoDenoiseStep]
+ block_names = ["input", "before_denoise", "denoise"]
+
+ @property
+ def description(self):
+ return (
+ "Core step that performs the denoising process. \n"
+ + " - `FluxKontextAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ + " - `FluxKontextAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ + " - `FluxKontextAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ + "This step supports text-to-image and image-to-image tasks for Flux:\n"
+ + " - for image-to-image generation, you need to provide `image_latents`\n"
+ + " - for text-to-image generation, all you need to provide is prompt embeddings."
+ )
+
+
+# Auto blocks (text2image and img2img)
+AUTO_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", FluxTextEncoderStep()),
+ ("image_encoder", FluxAutoVaeEncoderStep()),
+ ("denoise", FluxCoreDenoiseStep()),
+ ("decode", FluxDecodeStep()),
+ ]
+)
+
+AUTO_BLOCKS_KONTEXT = InsertableDict(
+ [
+ ("text_encoder", FluxTextEncoderStep()),
+ ("image_encoder", FluxKontextAutoVaeEncoderStep()),
+ ("denoise", FluxKontextCoreDenoiseStep()),
+ ("decode", FluxDecodeStep()),
+ ]
+)
+
+
+class FluxAutoBlocks(SequentialPipelineBlocks):
+ model_name = "flux"
+
+ block_classes = AUTO_BLOCKS.values()
+ block_names = AUTO_BLOCKS.keys()
+
+ @property
+ def description(self):
+ return (
+ "Auto Modular pipeline for text-to-image and image-to-image using Flux.\n"
+ + "- for text-to-image generation, all you need to provide is `prompt`\n"
+ + "- for image-to-image generation, you need to provide either `image` or `image_latents`"
+ )
+
+
+class FluxKontextAutoBlocks(FluxAutoBlocks):
+ model_name = "flux-kontext"
+
+ block_classes = AUTO_BLOCKS_KONTEXT.values()
+ block_names = AUTO_BLOCKS_KONTEXT.keys()
+
+
+TEXT2IMAGE_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", FluxTextEncoderStep()),
+ ("input", FluxTextInputStep()),
+ ("prepare_latents", FluxPrepareLatentsStep()),
+ ("set_timesteps", FluxSetTimestepsStep()),
+ ("prepare_rope_inputs", FluxRoPEInputsStep()),
+ ("denoise", FluxDenoiseStep()),
+ ("decode", FluxDecodeStep()),
+ ]
+)
+
+IMAGE2IMAGE_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", FluxTextEncoderStep()),
+ ("vae_encoder", FluxVaeEncoderDynamicStep()),
+ ("input", FluxImg2ImgInputStep()),
+ ("prepare_latents", FluxPrepareLatentsStep()),
+ ("set_timesteps", FluxImg2ImgSetTimestepsStep()),
+ ("prepare_img2img_latents", FluxImg2ImgPrepareLatentsStep()),
+ ("prepare_rope_inputs", FluxRoPEInputsStep()),
+ ("denoise", FluxDenoiseStep()),
+ ("decode", FluxDecodeStep()),
+ ]
+)
+
+FLUX_KONTEXT_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", FluxTextEncoderStep()),
+ ("vae_encoder", FluxVaeEncoderDynamicStep(sample_mode="argmax")),
+ ("input", FluxKontextInputStep()),
+ ("prepare_latents", FluxPrepareLatentsStep()),
+ ("set_timesteps", FluxSetTimestepsStep()),
+ ("prepare_rope_inputs", FluxKontextRoPEInputsStep()),
+ ("denoise", FluxKontextDenoiseStep()),
+ ("decode", FluxDecodeStep()),
+ ]
+)
+
+ALL_BLOCKS = {
+ "text2image": TEXT2IMAGE_BLOCKS,
+ "img2img": IMAGE2IMAGE_BLOCKS,
+ "auto": AUTO_BLOCKS,
+ "auto_kontext": AUTO_BLOCKS_KONTEXT,
+ "kontext": FLUX_KONTEXT_BLOCKS,
+}
diff --git a/src/diffusers/modular_pipelines/flux/modular_pipeline.py b/src/diffusers/modular_pipelines/flux/modular_pipeline.py
new file mode 100644
index 000000000000..d8158f5d4fd6
--- /dev/null
+++ b/src/diffusers/modular_pipelines/flux/modular_pipeline.py
@@ -0,0 +1,67 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
+from ...utils import logging
+from ..modular_pipeline import ModularPipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin, TextualInversionLoaderMixin):
+ """
+ A ModularPipeline for Flux.
+
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
+ """
+
+ default_blocks_name = "FluxAutoBlocks"
+
+ @property
+ def default_height(self):
+ return self.default_sample_size * self.vae_scale_factor
+
+ @property
+ def default_width(self):
+ return self.default_sample_size * self.vae_scale_factor
+
+ @property
+ def default_sample_size(self):
+ return 128
+
+ @property
+ def vae_scale_factor(self):
+ vae_scale_factor = 8
+ if getattr(self, "vae", None) is not None:
+ vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ return vae_scale_factor
+
+ @property
+ def num_channels_latents(self):
+ num_channels_latents = 16
+ if getattr(self, "transformer", None):
+ num_channels_latents = self.transformer.config.in_channels // 4
+ return num_channels_latents
+
+
+class FluxKontextModularPipeline(FluxModularPipeline):
+ """
+ A ModularPipeline for Flux Kontext.
+
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
+ """
+
+ default_blocks_name = "FluxKontextAutoBlocks"
diff --git a/src/diffusers/modular_pipelines/mellon_node_utils.py b/src/diffusers/modular_pipelines/mellon_node_utils.py
new file mode 100644
index 000000000000..a405aebee221
--- /dev/null
+++ b/src/diffusers/modular_pipelines/mellon_node_utils.py
@@ -0,0 +1,763 @@
+import json
+import logging
+import os
+
+# Simple typed wrapper for parameter overrides
+from dataclasses import asdict, dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from huggingface_hub import create_repo, hf_hub_download
+from huggingface_hub.utils import (
+ EntryNotFoundError,
+ HfHubHTTPError,
+ RepositoryNotFoundError,
+ RevisionNotFoundError,
+ validate_hf_hub_args,
+)
+
+from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, PushToHubMixin, extract_commit_hash
+from .modular_pipeline import ModularPipelineBlocks
+
+
+logger = logging.getLogger(__name__)
+
+
+SUPPORTED_NODE_TYPES = {"controlnet", "vae_encoder", "denoise", "text_encoder", "decoder"}
+
+
+# Mellon Input Parameters (runtime parameters, not models)
+MELLON_INPUT_PARAMS = {
+ # controlnet
+ "control_image": {
+ "label": "Control Image",
+ "type": "image",
+ "display": "input",
+ },
+ "controlnet_conditioning_scale": {
+ "label": "Scale",
+ "type": "float",
+ "default": 0.5,
+ "min": 0,
+ "max": 1,
+ },
+ "control_guidance_end": {
+ "label": "End",
+ "type": "float",
+ "default": 1.0,
+ "min": 0,
+ "max": 1,
+ },
+ "control_guidance_start": {
+ "label": "Start",
+ "type": "float",
+ "default": 0.0,
+ "min": 0,
+ "max": 1,
+ },
+ "controlnet": {
+ "label": "Controlnet",
+ "type": "custom_controlnet",
+ "display": "input",
+ },
+ "embeddings": {
+ "label": "Text Embeddings",
+ "display": "input",
+ "type": "embeddings",
+ },
+ "image": {
+ "label": "Image",
+ "type": "image",
+ "display": "input",
+ },
+ "negative_prompt": {
+ "label": "Negative Prompt",
+ "type": "string",
+ "default": "",
+ "display": "textarea",
+ },
+ "prompt": {
+ "label": "Prompt",
+ "type": "string",
+ "default": "",
+ "display": "textarea",
+ },
+ "guidance_scale": {
+ "label": "Guidance Scale",
+ "type": "float",
+ "display": "slider",
+ "default": 5,
+ "min": 1.0,
+ "max": 30.0,
+ "step": 0.1,
+ },
+ "height": {
+ "label": "Height",
+ "type": "int",
+ "default": 1024,
+ "min": 64,
+ "step": 8,
+ },
+ "image_latents": {
+ "label": "Image Latents",
+ "type": "latents",
+ "display": "input",
+ "onChange": {False: ["height", "width"], True: ["strength"]},
+ },
+ "latents": {
+ "label": "Latents",
+ "type": "latents",
+ "display": "input",
+ },
+ "num_inference_steps": {
+ "label": "Steps",
+ "type": "int",
+ "display": "slider",
+ "default": 25,
+ "min": 1,
+ "max": 100,
+ },
+ "seed": {
+ "label": "Seed",
+ "type": "int",
+ "display": "random",
+ "default": 0,
+ "min": 0,
+ "max": 4294967295,
+ },
+ "strength": {
+ "label": "Strength",
+ "type": "float",
+ "default": 0.5,
+ "min": 0.0,
+ "max": 1.0,
+ "step": 0.01,
+ },
+ "width": {
+ "label": "Width",
+ "type": "int",
+ "default": 1024,
+ "min": 64,
+ "step": 8,
+ },
+ "ip_adapter": {
+ "label": "IP Adapter",
+ "type": "custom_ip_adapter",
+ "display": "input",
+ },
+}
+
+# Mellon Model Parameters (diffusers_auto_model types)
+MELLON_MODEL_PARAMS = {
+ "scheduler": {
+ "label": "Scheduler",
+ "display": "input",
+ "type": "diffusers_auto_model",
+ },
+ "text_encoders": {
+ "label": "Text Encoders",
+ "type": "diffusers_auto_models",
+ "display": "input",
+ },
+ "unet": {
+ "label": "Unet",
+ "display": "input",
+ "type": "diffusers_auto_model",
+ "onSignal": {
+ "action": "signal",
+ "target": "guider",
+ },
+ },
+ "guider": {
+ "label": "Guider",
+ "display": "input",
+ "type": "custom_guider",
+ "onChange": {False: ["guidance_scale"], True: []},
+ },
+ "vae": {
+ "label": "VAE",
+ "display": "input",
+ "type": "diffusers_auto_model",
+ },
+ "controlnet": {
+ "label": "Controlnet Model",
+ "type": "diffusers_auto_model",
+ "display": "input",
+ },
+}
+
+# Mellon Output Parameters (display = "output")
+MELLON_OUTPUT_PARAMS = {
+ "embeddings": {
+ "label": "Text Embeddings",
+ "display": "output",
+ "type": "embeddings",
+ },
+ "images": {
+ "label": "Images",
+ "type": "image",
+ "display": "output",
+ },
+ "image_latents": {
+ "label": "Image Latents",
+ "type": "latents",
+ "display": "output",
+ },
+ "latents": {
+ "label": "Latents",
+ "type": "latents",
+ "display": "output",
+ },
+ "latents_preview": {
+ "label": "Latents Preview",
+ "display": "output",
+ "type": "latent",
+ },
+ "controlnet_out": {
+ "label": "Controlnet",
+ "display": "output",
+ "type": "controlnet",
+ },
+}
+
+
+# Default param selections per supported node_type
+# from MELLON_INPUT_PARAMS / MELLON_MODEL_PARAMS / MELLON_OUTPUT_PARAMS.
+NODE_TYPE_PARAMS_MAP = {
+ "controlnet": {
+ "inputs": [
+ "control_image",
+ "controlnet_conditioning_scale",
+ "control_guidance_start",
+ "control_guidance_end",
+ "height",
+ "width",
+ ],
+ "model_inputs": [
+ "controlnet",
+ "vae",
+ ],
+ "outputs": [
+ "controlnet",
+ ],
+ "block_names": ["controlnet_vae_encoder"],
+ },
+ "denoise": {
+ "inputs": [
+ "embeddings",
+ "width",
+ "height",
+ "seed",
+ "num_inference_steps",
+ "guidance_scale",
+ "image_latents",
+ "strength",
+ # custom adapters coming in as inputs
+ "controlnet",
+ # ip_adapter is optional and custom; include if available
+ "ip_adapter",
+ ],
+ "model_inputs": [
+ "unet",
+ "guider",
+ "scheduler",
+ ],
+ "outputs": [
+ "latents",
+ "latents_preview",
+ ],
+ "block_names": ["denoise"],
+ },
+ "vae_encoder": {
+ "inputs": [
+ "image",
+ "width",
+ "height",
+ ],
+ "model_inputs": [
+ "vae",
+ ],
+ "outputs": [
+ "image_latents",
+ ],
+ "block_names": ["vae_encoder"],
+ },
+ "text_encoder": {
+ "inputs": [
+ "prompt",
+ "negative_prompt",
+ # optional image prompt input supported in embeddings node
+ "image",
+ ],
+ "model_inputs": [
+ "text_encoders",
+ ],
+ "outputs": [
+ "embeddings",
+ ],
+ "block_names": ["text_encoder"],
+ },
+ "decoder": {
+ "inputs": [
+ "latents",
+ ],
+ "model_inputs": [
+ "vae",
+ ],
+ "outputs": [
+ "images",
+ ],
+ "block_names": ["decode"],
+ },
+}
+
+
+@dataclass(frozen=True)
+class MellonParam:
+ name: str
+ label: str
+ type: str
+ display: Optional[str] = None
+ default: Any = None
+ min: Optional[float] = None
+ max: Optional[float] = None
+ step: Optional[float] = None
+ options: Any = None
+ value: Any = None
+ fieldOptions: Optional[Dict[str, Any]] = None
+ onChange: Any = None
+ onSignal: Any = None
+ _map_to_input: Any = None # the block input name this parameter maps to
+
+ def to_dict(self) -> Dict[str, Any]:
+ data = asdict(self)
+ return {k: v for k, v in data.items() if not k.startswith("_") and v is not None}
+
+
+@dataclass
+class MellonNodeConfig(PushToHubMixin):
+ """
+ A MellonNodeConfig is a base class to build Mellon nodes UI with modular diffusers.
+
+
+
+ This is an experimental feature and is likely to change in the future.
+
+
+ """
+
+ inputs: List[Union[str, MellonParam]]
+ model_inputs: List[Union[str, MellonParam]]
+ outputs: List[Union[str, MellonParam]]
+ blocks_names: list[str]
+ node_type: str
+ config_name = "mellon_config.json"
+
+ def __post_init__(self):
+ if isinstance(self.inputs, list):
+ self.inputs = self._resolve_params_list(self.inputs, MELLON_INPUT_PARAMS)
+ if isinstance(self.model_inputs, list):
+ self.model_inputs = self._resolve_params_list(self.model_inputs, MELLON_MODEL_PARAMS)
+ if isinstance(self.outputs, list):
+ self.outputs = self._resolve_params_list(self.outputs, MELLON_OUTPUT_PARAMS)
+
+ @staticmethod
+ def _resolve_params_list(
+ params: List[Union[str, MellonParam]], default_map: Dict[str, Dict[str, Any]]
+ ) -> Dict[str, Dict[str, Any]]:
+ def _resolve_param(
+ param: Union[str, MellonParam], default_params_map: Dict[str, Dict[str, Any]]
+ ) -> Tuple[str, Dict[str, Any]]:
+ if isinstance(param, str):
+ if param not in default_params_map:
+ raise ValueError(f"Unknown param '{param}', please define a `MellonParam` object instead")
+ return param, default_params_map[param].copy()
+ elif isinstance(param, MellonParam):
+ param_dict = param.to_dict()
+ param_name = param_dict.pop("name")
+ return param_name, param_dict
+ else:
+ raise ValueError(
+ f"Unknown param type '{type(param)}', please use a string or a `MellonParam` object instead"
+ )
+
+ resolved = {}
+ for p in params:
+ logger.info(f" Resolving param: {p}")
+ name, cfg = _resolve_param(p, default_map)
+ if name in resolved:
+ raise ValueError(f"Duplicate param '{name}'")
+ resolved[name] = cfg
+ return resolved
+
+ @classmethod
+ @validate_hf_hub_args
+ def load_mellon_config(
+ cls,
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ return_unused_kwargs=False,
+ return_commit_hash=False,
+ **kwargs,
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ r"""
+ Load a model or scheduler configuration.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
+ [`~ConfigMixin.save_config`].
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False):
+ Whether unused keyword arguments of the config are returned.
+ return_commit_hash (`bool`, *optional*, defaults to `False):
+ Whether the `commit_hash` of the loaded configuration are returned.
+
+ Returns:
+ `dict`:
+ A dictionary of all the parameters stored in a JSON configuration file.
+
+ """
+ cache_dir = kwargs.pop("cache_dir", None)
+ local_dir = kwargs.pop("local_dir", None)
+ local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto")
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ token = kwargs.pop("token", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ revision = kwargs.pop("revision", None)
+
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
+ if cls.config_name is None:
+ raise ValueError(
+ "`self.config_name` is not defined. Note that one should not load a config from "
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
+ )
+ if os.path.isfile(pretrained_model_name_or_path):
+ config_file = pretrained_model_name_or_path
+ elif os.path.isdir(pretrained_model_name_or_path):
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
+ # Load from a PyTorch checkpoint
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+ else:
+ raise EnvironmentError(
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
+ )
+ else:
+ try:
+ # Load from URL or cache if already cached
+ config_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ filename=cls.config_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ local_dir=local_dir,
+ local_dir_use_symlinks=local_dir_use_symlinks,
+ )
+ except RepositoryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
+ " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
+ " token having permission to this repo with `token` or log in with `hf auth login`."
+ )
+ except RevisionNotFoundError:
+ raise EnvironmentError(
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
+ " this model name. Check the model page at"
+ f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+ )
+ except EntryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
+ )
+ except HfHubHTTPError as err:
+ raise EnvironmentError(
+ "There was a specific connection error when trying to load"
+ f" {pretrained_model_name_or_path}:\n{err}"
+ )
+ except ValueError:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+ f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
+ " run the library in offline mode at"
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+ )
+ except EnvironmentError:
+ raise EnvironmentError(
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+ f"containing a {cls.config_name} file"
+ )
+ try:
+ with open(config_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ config_dict = json.loads(text)
+
+ commit_hash = extract_commit_hash(config_file)
+ except (json.JSONDecodeError, UnicodeDecodeError):
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
+
+ if not (return_unused_kwargs or return_commit_hash):
+ return config_dict
+
+ outputs = (config_dict,)
+
+ if return_unused_kwargs:
+ outputs += (kwargs,)
+
+ if return_commit_hash:
+ outputs += (commit_hash,)
+
+ return outputs
+
+ def save_mellon_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save the Mellon node definition to a JSON file.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the configuration JSON file is saved (will be created if it does not exist).
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+ namespace).
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+ """
+ if os.path.isfile(save_directory):
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ # If we save using the predefined names, we can load using `from_config`
+ output_config_file = os.path.join(save_directory, self.config_name)
+
+ self.to_json_file(output_config_file)
+ logger.info(f"Mellon node definition saved in {output_config_file}")
+
+ if push_to_hub:
+ commit_message = kwargs.pop("commit_message", None)
+ private = kwargs.pop("private", None)
+ create_pr = kwargs.pop("create_pr", False)
+ token = kwargs.pop("token", None)
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+ repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
+ subfolder = kwargs.pop("subfolder", None)
+
+ self._upload_folder(
+ save_directory,
+ repo_id,
+ token=token,
+ commit_message=commit_message,
+ create_pr=create_pr,
+ subfolder=subfolder,
+ )
+
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+ """
+ Save the Mellon schema dictionary to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file to save a configuration instance's parameters.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string())
+
+ def to_json_string(self) -> str:
+ """
+ Serializes this instance to a JSON string of the Mellon schema dict.
+
+ Args:
+ Returns:
+ `str`: String containing all the attributes that make up this configuration instance in JSON format.
+ """
+
+ mellon_dict = self.to_mellon_dict()
+ return json.dumps(mellon_dict, indent=2, sort_keys=True) + "\n"
+
+ def to_mellon_dict(self) -> Dict[str, Any]:
+ """Return a JSON-serializable dict focusing on the Mellon schema fields only.
+
+ params is a single flat dict composed as: {**inputs, **model_inputs, **outputs}.
+ """
+ # inputs/model_inputs/outputs are already normalized dicts
+ merged_params = {}
+ merged_params.update(self.inputs or {})
+ merged_params.update(self.model_inputs or {})
+ merged_params.update(self.outputs or {})
+
+ return {
+ "node_type": self.node_type,
+ "blocks_names": self.blocks_names,
+ "params": merged_params,
+ }
+
+ @classmethod
+ def from_mellon_dict(cls, mellon_dict: Dict[str, Any]) -> "MellonNodeConfig":
+ """Create a config from a Mellon schema dict produced by to_mellon_dict().
+
+ Splits the flat params dict back into inputs/model_inputs/outputs using the known key spaces from
+ MELLON_INPUT_PARAMS / MELLON_MODEL_PARAMS / MELLON_OUTPUT_PARAMS. Unknown keys are treated as inputs by
+ default.
+ """
+ flat_params = mellon_dict.get("params", {})
+
+ inputs: Dict[str, Any] = {}
+ model_inputs: Dict[str, Any] = {}
+ outputs: Dict[str, Any] = {}
+
+ for param_name, param_dict in flat_params.items():
+ if param_dict.get("display", "") == "output":
+ outputs[param_name] = param_dict
+ elif param_dict.get("type", "") in ("diffusers_auto_model", "diffusers_auto_models"):
+ model_inputs[param_name] = param_dict
+ else:
+ inputs[param_name] = param_dict
+
+ return cls(
+ inputs=inputs,
+ model_inputs=model_inputs,
+ outputs=outputs,
+ blocks_names=mellon_dict.get("blocks_names", []),
+ node_type=mellon_dict.get("node_type"),
+ )
+
+ # YiYi Notes: not used yet
+ @classmethod
+ def from_blocks(cls, blocks: ModularPipelineBlocks, node_type: str) -> "MellonNodeConfig":
+ """
+ Create an instance from a ModularPipeline object. If a preset exists in NODE_TYPE_PARAMS_MAP for the node_type,
+ use it; otherwise fall back to deriving lists from the pipeline's expected inputs/components/outputs.
+ """
+ if node_type not in NODE_TYPE_PARAMS_MAP:
+ raise ValueError(f"Node type {node_type} not supported")
+
+ blocks_names = list(blocks.sub_blocks.keys())
+
+ default_node_config = NODE_TYPE_PARAMS_MAP[node_type]
+ inputs_list: List[Union[str, MellonParam]] = default_node_config.get("inputs", [])
+ model_inputs_list: List[Union[str, MellonParam]] = default_node_config.get("model_inputs", [])
+ outputs_list: List[Union[str, MellonParam]] = default_node_config.get("outputs", [])
+
+ for required_input_name in blocks.required_inputs:
+ if required_input_name not in inputs_list:
+ inputs_list.append(
+ MellonParam(
+ name=required_input_name, label=required_input_name, type=required_input_name, display="input"
+ )
+ )
+
+ for component_spec in blocks.expected_components:
+ if component_spec.name not in model_inputs_list:
+ model_inputs_list.append(
+ MellonParam(
+ name=component_spec.name,
+ label=component_spec.name,
+ type="diffusers_auto_model",
+ display="input",
+ )
+ )
+
+ return cls(
+ inputs=inputs_list,
+ model_inputs=model_inputs_list,
+ outputs=outputs_list,
+ blocks_names=blocks_names,
+ node_type=node_type,
+ )
+
+
+# Minimal modular registry for Mellon node configs
+class ModularMellonNodeRegistry:
+ """Registry mapping (pipeline class, blocks_name) -> list of MellonNodeConfig."""
+
+ def __init__(self):
+ self._registry = {}
+ self._initialized = False
+
+ def register(self, pipeline_cls: type, node_params: Dict[str, MellonNodeConfig]):
+ if not self._initialized:
+ _initialize_registry(self)
+ self._registry[pipeline_cls] = node_params
+
+ def get(self, pipeline_cls: type) -> MellonNodeConfig:
+ if not self._initialized:
+ _initialize_registry(self)
+ return self._registry.get(pipeline_cls, None)
+
+ def get_all(self) -> Dict[type, Dict[str, MellonNodeConfig]]:
+ if not self._initialized:
+ _initialize_registry(self)
+ return self._registry
+
+
+def _register_preset_node_types(
+ pipeline_cls, params_map: Dict[str, Dict[str, Any]], registry: ModularMellonNodeRegistry
+):
+ """Register all node-type presets for a given pipeline class from a params map."""
+ node_configs = {}
+ for node_type, spec in params_map.items():
+ node_config = MellonNodeConfig(
+ inputs=spec.get("inputs", []),
+ model_inputs=spec.get("model_inputs", []),
+ outputs=spec.get("outputs", []),
+ blocks_names=spec.get("block_names", []),
+ node_type=node_type,
+ )
+ node_configs[node_type] = node_config
+ registry.register(pipeline_cls, node_configs)
+
+
+def _initialize_registry(registry: ModularMellonNodeRegistry):
+ """Initialize the registry and register all available pipeline configs."""
+ print("Initializing registry")
+
+ registry._initialized = True
+
+ try:
+ from .qwenimage.modular_pipeline import QwenImageModularPipeline
+ from .qwenimage.node_utils import QwenImage_NODE_TYPES_PARAMS_MAP
+
+ _register_preset_node_types(QwenImageModularPipeline, QwenImage_NODE_TYPES_PARAMS_MAP, registry)
+ except Exception:
+ raise Exception("Failed to register QwenImageModularPipeline")
+
+ try:
+ from .stable_diffusion_xl.modular_pipeline import StableDiffusionXLModularPipeline
+ from .stable_diffusion_xl.node_utils import SDXL_NODE_TYPES_PARAMS_MAP
+
+ _register_preset_node_types(StableDiffusionXLModularPipeline, SDXL_NODE_TYPES_PARAMS_MAP, registry)
+ except Exception:
+ raise Exception("Failed to register StableDiffusionXLModularPipeline")
diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py
new file mode 100644
index 000000000000..a6336de71a52
--- /dev/null
+++ b/src/diffusers/modular_pipelines/modular_pipeline.py
@@ -0,0 +1,2576 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+import inspect
+import os
+import traceback
+import warnings
+from collections import OrderedDict
+from copy import deepcopy
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from huggingface_hub import create_repo
+from huggingface_hub.utils import validate_hf_hub_args
+from tqdm.auto import tqdm
+from typing_extensions import Self
+
+from ..configuration_utils import ConfigMixin, FrozenDict
+from ..pipelines.pipeline_loading_utils import _fetch_class_library_tuple, simple_get_class_obj
+from ..utils import PushToHubMixin, is_accelerate_available, logging
+from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
+from ..utils.hub_utils import load_or_create_model_card, populate_model_card
+from .components_manager import ComponentsManager
+from .modular_pipeline_utils import (
+ ComponentSpec,
+ ConfigSpec,
+ InputParam,
+ InsertableDict,
+ OutputParam,
+ format_components,
+ format_configs,
+ make_doc_string,
+)
+
+
+if is_accelerate_available():
+ import accelerate
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# map regular pipeline to modular pipeline class name
+MODULAR_PIPELINE_MAPPING = OrderedDict(
+ [
+ ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
+ ("wan", "WanModularPipeline"),
+ ("flux", "FluxModularPipeline"),
+ ("flux-kontext", "FluxKontextModularPipeline"),
+ ("qwenimage", "QwenImageModularPipeline"),
+ ("qwenimage-edit", "QwenImageEditModularPipeline"),
+ ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
+ ]
+)
+
+
+@dataclass
+class PipelineState:
+ """
+ [`PipelineState`] stores the state of a pipeline. It is used to pass data between pipeline blocks.
+ """
+
+ values: Dict[str, Any] = field(default_factory=dict)
+ kwargs_mapping: Dict[str, List[str]] = field(default_factory=dict)
+
+ def set(self, key: str, value: Any, kwargs_type: str = None):
+ """
+ Add a value to the pipeline state.
+
+ Args:
+ key (str): The key for the value
+ value (Any): The value to store
+ kwargs_type (str): The kwargs_type with which the value is associated
+ """
+ self.values[key] = value
+
+ if kwargs_type is not None:
+ if kwargs_type not in self.kwargs_mapping:
+ self.kwargs_mapping[kwargs_type] = [key]
+ else:
+ self.kwargs_mapping[kwargs_type].append(key)
+
+ def get(self, keys: Union[str, List[str]], default: Any = None) -> Union[Any, Dict[str, Any]]:
+ """
+ Get one or multiple values from the pipeline state.
+
+ Args:
+ keys (Union[str, List[str]]): Key or list of keys for the values
+ default (Any): The default value to return if not found
+
+ Returns:
+ Union[Any, Dict[str, Any]]: Single value if keys is str, dictionary of values if keys is list
+ """
+ if isinstance(keys, str):
+ return self.values.get(keys, default)
+ return {key: self.values.get(key, default) for key in keys}
+
+ def get_by_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
+ """
+ Get all values with matching kwargs_type.
+
+ Args:
+ kwargs_type (str): The kwargs_type to filter by
+
+ Returns:
+ Dict[str, Any]: Dictionary of values with matching kwargs_type
+ """
+ value_names = self.kwargs_mapping.get(kwargs_type, [])
+ return self.get(value_names)
+
+ def to_dict(self) -> Dict[str, Any]:
+ """
+ Convert PipelineState to a dictionary.
+ """
+ return {**self.__dict__}
+
+ def __getattr__(self, name):
+ """
+ Allow attribute access to intermediate values. If an attribute is not found in the object, look for it in the
+ intermediates dict.
+ """
+ # Use object.__getattribute__ to avoid infinite recursion during deepcopy
+ try:
+ values = object.__getattribute__(self, "values")
+ except AttributeError:
+ raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
+
+ if name in values:
+ return values[name]
+ raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
+
+ def __repr__(self):
+ def format_value(v):
+ if hasattr(v, "shape") and hasattr(v, "dtype"):
+ return f"Tensor(dtype={v.dtype}, shape={v.shape})"
+ elif isinstance(v, list) and len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"):
+ return f"[Tensor(dtype={v[0].dtype}, shape={v[0].shape}), ...]"
+ else:
+ return repr(v)
+
+ values_str = "\n".join(f" {k}: {format_value(v)}" for k, v in self.values.items())
+ kwargs_mapping_str = "\n".join(f" {k}: {v}" for k, v in self.kwargs_mapping.items())
+
+ return f"PipelineState(\n values={{\n{values_str}\n }},\n kwargs_mapping={{\n{kwargs_mapping_str}\n }}\n)"
+
+
+@dataclass
+class BlockState:
+ """
+ Container for block state data with attribute access and formatted representation.
+ """
+
+ def __init__(self, **kwargs):
+ for key, value in kwargs.items():
+ setattr(self, key, value)
+
+ def __getitem__(self, key: str):
+ # allows block_state["foo"]
+ return getattr(self, key, None)
+
+ def __setitem__(self, key: str, value: Any):
+ # allows block_state["foo"] = "bar"
+ setattr(self, key, value)
+
+ def as_dict(self):
+ """
+ Convert BlockState to a dictionary.
+
+ Returns:
+ Dict[str, Any]: Dictionary containing all attributes of the BlockState
+ """
+ return dict(self.__dict__.items())
+
+ def __repr__(self):
+ def format_value(v):
+ # Handle tensors directly
+ if hasattr(v, "shape") and hasattr(v, "dtype"):
+ return f"Tensor(dtype={v.dtype}, shape={v.shape})"
+
+ # Handle lists of tensors
+ elif isinstance(v, list):
+ if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"):
+ shapes = [t.shape for t in v]
+ return f"List[{len(v)}] of Tensors with shapes {shapes}"
+ return repr(v)
+
+ # Handle tuples of tensors
+ elif isinstance(v, tuple):
+ if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"):
+ shapes = [t.shape for t in v]
+ return f"Tuple[{len(v)}] of Tensors with shapes {shapes}"
+ return repr(v)
+
+ # Handle dicts with tensor values
+ elif isinstance(v, dict):
+ formatted_dict = {}
+ for k, val in v.items():
+ if hasattr(val, "shape") and hasattr(val, "dtype"):
+ formatted_dict[k] = f"Tensor(shape={val.shape}, dtype={val.dtype})"
+ elif (
+ isinstance(val, list)
+ and len(val) > 0
+ and hasattr(val[0], "shape")
+ and hasattr(val[0], "dtype")
+ ):
+ shapes = [t.shape for t in val]
+ formatted_dict[k] = f"List[{len(val)}] of Tensors with shapes {shapes}"
+ else:
+ formatted_dict[k] = repr(val)
+ return formatted_dict
+
+ # Default case
+ return repr(v)
+
+ attributes = "\n".join(f" {k}: {format_value(v)}" for k, v in self.__dict__.items())
+ return f"BlockState(\n{attributes}\n)"
+
+
+class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
+ """
+ Base class for all Pipeline Blocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks,
+ LoopSequentialPipelineBlocks
+
+ [`ModularPipelineBlocks`] provides method to load and save the definition of pipeline blocks.
+
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
+ """
+
+ config_name = "modular_config.json"
+ model_name = None
+
+ @classmethod
+ def _get_signature_keys(cls, obj):
+ parameters = inspect.signature(obj.__init__).parameters
+ required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
+ optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
+ expected_modules = set(required_parameters.keys()) - {"self"}
+
+ return expected_modules, optional_parameters
+
+ def __init__(self):
+ self.sub_blocks = InsertableDict()
+
+ @property
+ def description(self) -> str:
+ """Description of the block. Must be implemented by subclasses."""
+ return ""
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return []
+
+ @property
+ def expected_configs(self) -> List[ConfigSpec]:
+ return []
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ """List of input parameters. Must be implemented by subclasses."""
+ return []
+
+ def _get_required_inputs(self):
+ input_names = []
+ for input_param in self.inputs:
+ if input_param.required:
+ input_names.append(input_param.name)
+
+ return input_names
+
+ @property
+ def required_inputs(self) -> List[InputParam]:
+ return self._get_required_inputs()
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ """List of intermediate output parameters. Must be implemented by subclasses."""
+ return []
+
+ def _get_outputs(self):
+ return self.intermediate_outputs
+
+ @property
+ def outputs(self) -> List[OutputParam]:
+ return self._get_outputs()
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: str,
+ trust_remote_code: bool = False,
+ **kwargs,
+ ):
+ hub_kwargs_names = [
+ "cache_dir",
+ "force_download",
+ "local_files_only",
+ "local_dir",
+ "proxies",
+ "revision",
+ "subfolder",
+ "token",
+ ]
+ hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
+
+ config = cls.load_config(pretrained_model_name_or_path, **hub_kwargs)
+ has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
+ trust_remote_code = resolve_trust_remote_code(
+ trust_remote_code, pretrained_model_name_or_path, has_remote_code
+ )
+ if not has_remote_code and trust_remote_code:
+ raise ValueError(
+ "Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
+ )
+
+ class_ref = config["auto_map"][cls.__name__]
+ module_file, class_name = class_ref.split(".")
+ module_file = module_file + ".py"
+ block_cls = get_class_from_dynamic_module(
+ pretrained_model_name_or_path,
+ module_file=module_file,
+ class_name=class_name,
+ **hub_kwargs,
+ )
+ expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls)
+ block_kwargs = {
+ name: kwargs.get(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
+ }
+
+ return block_cls(**block_kwargs)
+
+ def save_pretrained(self, save_directory, push_to_hub=False, **kwargs):
+ # TODO: factor out this logic.
+ cls_name = self.__class__.__name__
+
+ full_mod = type(self).__module__
+ module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
+ parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
+ auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
+
+ self.register_to_config(auto_map=auto_map)
+ self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
+ config = dict(self.config)
+ self._internal_dict = FrozenDict(config)
+
+ def init_pipeline(
+ self,
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+ components_manager: Optional[ComponentsManager] = None,
+ collection: Optional[str] = None,
+ ) -> "ModularPipeline":
+ """
+ create a ModularPipeline, optionally accept pretrained_model_name_or_path to load from hub.
+ """
+ pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__)
+ diffusers_module = importlib.import_module("diffusers")
+ pipeline_class = getattr(diffusers_module, pipeline_class_name)
+
+ modular_pipeline = pipeline_class(
+ blocks=deepcopy(self),
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ components_manager=components_manager,
+ collection=collection,
+ )
+ return modular_pipeline
+
+ def get_block_state(self, state: PipelineState) -> dict:
+ """Get all inputs and intermediates in one dictionary"""
+ data = {}
+ state_inputs = self.inputs
+
+ # Check inputs
+ for input_param in state_inputs:
+ if input_param.name:
+ value = state.get(input_param.name)
+ if input_param.required and value is None:
+ raise ValueError(f"Required input '{input_param.name}' is missing")
+ elif value is not None or (value is None and input_param.name not in data):
+ data[input_param.name] = value
+
+ elif input_param.kwargs_type:
+ # if kwargs_type is provided, get all inputs with matching kwargs_type
+ if input_param.kwargs_type not in data:
+ data[input_param.kwargs_type] = {}
+ inputs_kwargs = state.get_by_kwargs(input_param.kwargs_type)
+ if inputs_kwargs:
+ for k, v in inputs_kwargs.items():
+ if v is not None:
+ data[k] = v
+ data[input_param.kwargs_type][k] = v
+
+ return BlockState(**data)
+
+ def set_block_state(self, state: PipelineState, block_state: BlockState):
+ for output_param in self.intermediate_outputs:
+ if not hasattr(block_state, output_param.name):
+ raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
+ param = getattr(block_state, output_param.name)
+ state.set(output_param.name, param, output_param.kwargs_type)
+
+ for input_param in self.inputs:
+ if input_param.name and hasattr(block_state, input_param.name):
+ param = getattr(block_state, input_param.name)
+ # Only add if the value is different from what's in the state
+ current_value = state.get(input_param.name)
+ if current_value is not param: # Using identity comparison to check if object was modified
+ state.set(input_param.name, param, input_param.kwargs_type)
+
+ elif input_param.kwargs_type:
+ # if it is a kwargs type, e.g. "denoiser_input_fields", it is likely to be a list of parameters
+ # we need to first find out which inputs are and loop through them.
+ intermediate_kwargs = state.get_by_kwargs(input_param.kwargs_type)
+ for param_name, current_value in intermediate_kwargs.items():
+ if param_name is None:
+ continue
+
+ if not hasattr(block_state, param_name):
+ continue
+
+ param = getattr(block_state, param_name)
+ if current_value is not param: # Using identity comparison to check if object was modified
+ state.set(param_name, param, input_param.kwargs_type)
+
+ @staticmethod
+ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
+ """
+ Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if
+ current default value is None and new default value is not None. Warns if multiple non-None default values
+ exist for the same input.
+
+ Args:
+ named_input_lists: List of tuples containing (block_name, input_param_list) pairs
+
+ Returns:
+ List[InputParam]: Combined list of unique InputParam objects
+ """
+ combined_dict = {} # name -> InputParam
+ value_sources = {} # name -> block_name
+
+ for block_name, inputs in named_input_lists:
+ for input_param in inputs:
+ if input_param.name is None and input_param.kwargs_type is not None:
+ input_name = "*_" + input_param.kwargs_type
+ else:
+ input_name = input_param.name
+ if input_name in combined_dict:
+ current_param = combined_dict[input_name]
+ if (
+ current_param.default is not None
+ and input_param.default is not None
+ and current_param.default != input_param.default
+ ):
+ warnings.warn(
+ f"Multiple different default values found for input '{input_name}': "
+ f"{current_param.default} (from block '{value_sources[input_name]}') and "
+ f"{input_param.default} (from block '{block_name}'). Using {current_param.default}."
+ )
+ if current_param.default is None and input_param.default is not None:
+ combined_dict[input_name] = input_param
+ value_sources[input_name] = block_name
+ else:
+ combined_dict[input_name] = input_param
+ value_sources[input_name] = block_name
+
+ return list(combined_dict.values())
+
+ @staticmethod
+ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]:
+ """
+ Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, keeps the first
+ occurrence of each output name.
+
+ Args:
+ named_output_lists: List of tuples containing (block_name, output_param_list) pairs
+
+ Returns:
+ List[OutputParam]: Combined list of unique OutputParam objects
+ """
+ combined_dict = {} # name -> OutputParam
+
+ for block_name, outputs in named_output_lists:
+ for output_param in outputs:
+ if (output_param.name not in combined_dict) or (
+ combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None
+ ):
+ combined_dict[output_param.name] = output_param
+
+ return list(combined_dict.values())
+
+ @property
+ def input_names(self) -> List[str]:
+ return [input_param.name for input_param in self.inputs]
+
+ @property
+ def intermediate_output_names(self) -> List[str]:
+ return [output_param.name for output_param in self.intermediate_outputs]
+
+ @property
+ def output_names(self) -> List[str]:
+ return [output_param.name for output_param in self.outputs]
+
+ @property
+ def doc(self):
+ return make_doc_string(
+ self.inputs,
+ self.outputs,
+ self.description,
+ class_name=self.__class__.__name__,
+ expected_components=self.expected_components,
+ expected_configs=self.expected_configs,
+ )
+
+
+class AutoPipelineBlocks(ModularPipelineBlocks):
+ """
+ A Pipeline Blocks that automatically selects a block to run based on the inputs.
+
+ This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipeline blocks (such as loading or saving etc.)
+
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
+
+ Attributes:
+ block_classes: List of block classes to be used
+ block_names: List of prefixes for each block
+ block_trigger_inputs: List of input names that trigger specific blocks, with None for default
+ """
+
+ block_classes = []
+ block_names = []
+ block_trigger_inputs = []
+
+ def __init__(self):
+ sub_blocks = InsertableDict()
+ for block_name, block in zip(self.block_names, self.block_classes):
+ if inspect.isclass(block):
+ sub_blocks[block_name] = block()
+ else:
+ sub_blocks[block_name] = block
+ self.sub_blocks = sub_blocks
+ if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
+ raise ValueError(
+ f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
+ )
+ default_blocks = [t for t in self.block_trigger_inputs if t is None]
+ # can only have 1 or 0 default block, and has to put in the last
+ # the order of blocks matters here because the first block with matching trigger will be dispatched
+ # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"]
+ # as long as mask is provided, it is inpaint; if only image is provided, it is img2img
+ if len(default_blocks) > 1 or (len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None):
+ raise ValueError(
+ f"In {self.__class__.__name__}, exactly one None must be specified as the last element "
+ "in block_trigger_inputs."
+ )
+
+ # Map trigger inputs to block objects
+ self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.values()))
+ self.trigger_to_block_name_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.keys()))
+ self.block_to_trigger_map = dict(zip(self.sub_blocks.keys(), self.block_trigger_inputs))
+
+ @property
+ def model_name(self):
+ return next(iter(self.sub_blocks.values())).model_name
+
+ @property
+ def description(self):
+ return ""
+
+ @property
+ def expected_components(self):
+ expected_components = []
+ for block in self.sub_blocks.values():
+ for component in block.expected_components:
+ if component not in expected_components:
+ expected_components.append(component)
+ return expected_components
+
+ @property
+ def expected_configs(self):
+ expected_configs = []
+ for block in self.sub_blocks.values():
+ for config in block.expected_configs:
+ if config not in expected_configs:
+ expected_configs.append(config)
+ return expected_configs
+
+ @property
+ def required_inputs(self) -> List[str]:
+ if None not in self.block_trigger_inputs:
+ return []
+ first_block = next(iter(self.sub_blocks.values()))
+ required_by_all = set(getattr(first_block, "required_inputs", set()))
+
+ # Intersect with required inputs from all other blocks
+ for block in list(self.sub_blocks.values())[1:]:
+ block_required = set(getattr(block, "required_inputs", set()))
+ required_by_all.intersection_update(block_required)
+
+ return list(required_by_all)
+
+ # YiYi TODO: add test for this
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
+ combined_inputs = self.combine_inputs(*named_inputs)
+ # mark Required inputs only if that input is required by all the blocks
+ for input_param in combined_inputs:
+ if input_param.name in self.required_inputs:
+ input_param.required = True
+ else:
+ input_param.required = False
+ return combined_inputs
+
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()]
+ combined_outputs = self.combine_outputs(*named_outputs)
+ return combined_outputs
+
+ @property
+ def outputs(self) -> List[str]:
+ named_outputs = [(name, block.outputs) for name, block in self.sub_blocks.items()]
+ combined_outputs = self.combine_outputs(*named_outputs)
+ return combined_outputs
+
+ @torch.no_grad()
+ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
+ # Find default block first (if any)
+
+ block = self.trigger_to_block_map.get(None)
+ for input_name in self.block_trigger_inputs:
+ if input_name is not None and state.get(input_name) is not None:
+ block = self.trigger_to_block_map[input_name]
+ break
+
+ if block is None:
+ logger.info(f"skipping auto block: {self.__class__.__name__}")
+ return pipeline, state
+
+ try:
+ logger.info(f"Running block: {block.__class__.__name__}, trigger: {input_name}")
+ return block(pipeline, state)
+ except Exception as e:
+ error_msg = (
+ f"\nError in block: {block.__class__.__name__}\n"
+ f"Error details: {str(e)}\n"
+ f"Traceback:\n{traceback.format_exc()}"
+ )
+ logger.error(error_msg)
+ raise
+
+ def _get_trigger_inputs(self):
+ """
+ Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique
+ block_trigger_inputs values
+ """
+
+ def fn_recursive_get_trigger(blocks):
+ trigger_values = set()
+
+ if blocks is not None:
+ for name, block in blocks.items():
+ # Check if current block has trigger inputs(i.e. auto block)
+ if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
+ # Add all non-None values from the trigger inputs list
+ trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
+
+ # If block has sub_blocks, recursively check them
+ if block.sub_blocks:
+ nested_triggers = fn_recursive_get_trigger(block.sub_blocks)
+ trigger_values.update(nested_triggers)
+
+ return trigger_values
+
+ trigger_inputs = set(self.block_trigger_inputs)
+ trigger_inputs.update(fn_recursive_get_trigger(self.sub_blocks))
+
+ return trigger_inputs
+
+ @property
+ def trigger_inputs(self):
+ return self._get_trigger_inputs()
+
+ def __repr__(self):
+ class_name = self.__class__.__name__
+ base_class = self.__class__.__bases__[0].__name__
+ header = (
+ f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
+ )
+
+ if self.trigger_inputs:
+ header += "\n"
+ header += " " + "=" * 100 + "\n"
+ header += " This pipeline contains blocks that are selected at runtime based on inputs.\n"
+ header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
+ header += " " + "=" * 100 + "\n\n"
+
+ # Format description with proper indentation
+ desc_lines = self.description.split("\n")
+ desc = []
+ # First line with "Description:" label
+ desc.append(f" Description: {desc_lines[0]}")
+ # Subsequent lines with proper indentation
+ if len(desc_lines) > 1:
+ desc.extend(f" {line}" for line in desc_lines[1:])
+ desc = "\n".join(desc) + "\n"
+
+ # Components section - focus only on expected components
+ expected_components = getattr(self, "expected_components", [])
+ components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
+
+ # Configs section - use format_configs with add_empty_lines=False
+ expected_configs = getattr(self, "expected_configs", [])
+ configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
+
+ # Blocks section - moved to the end with simplified format
+ blocks_str = " Sub-Blocks:\n"
+ for i, (name, block) in enumerate(self.sub_blocks.items()):
+ # Get trigger input for this block
+ trigger = None
+ if hasattr(self, "block_to_trigger_map"):
+ trigger = self.block_to_trigger_map.get(name)
+ # Format the trigger info
+ if trigger is None:
+ trigger_str = "[default]"
+ elif isinstance(trigger, (list, tuple)):
+ trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]"
+ else:
+ trigger_str = f"[trigger: {trigger}]"
+ # For AutoPipelineBlocks, add bullet points
+ blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n"
+ else:
+ # For SequentialPipelineBlocks, show execution order
+ blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
+
+ # Add block description
+ desc_lines = block.description.split("\n")
+ indented_desc = desc_lines[0]
+ if len(desc_lines) > 1:
+ indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:])
+ blocks_str += f" Description: {indented_desc}\n\n"
+
+ # Build the representation with conditional sections
+ result = f"{header}\n{desc}"
+
+ # Only add components section if it has content
+ if components_str.strip():
+ result += f"\n\n{components_str}"
+
+ # Only add configs section if it has content
+ if configs_str.strip():
+ result += f"\n\n{configs_str}"
+
+ # Always add blocks section
+ result += f"\n\n{blocks_str})"
+
+ return result
+
+ @property
+ def doc(self):
+ return make_doc_string(
+ self.inputs,
+ self.outputs,
+ self.description,
+ class_name=self.__class__.__name__,
+ expected_components=self.expected_components,
+ expected_configs=self.expected_configs,
+ )
+
+
+class SequentialPipelineBlocks(ModularPipelineBlocks):
+ """
+ A Pipeline Blocks that combines multiple pipeline block classes into one. When called, it will call each block in
+ sequence.
+
+ This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipeline blocks (such as loading or saving etc.)
+
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
+
+ Attributes:
+ block_classes: List of block classes to be used
+ block_names: List of prefixes for each block
+ """
+
+ block_classes = []
+ block_names = []
+
+ @property
+ def description(self):
+ return ""
+
+ @property
+ def model_name(self):
+ return next((block.model_name for block in self.sub_blocks.values() if block.model_name is not None), None)
+
+ @property
+ def expected_components(self):
+ expected_components = []
+ for block in self.sub_blocks.values():
+ for component in block.expected_components:
+ if component not in expected_components:
+ expected_components.append(component)
+ return expected_components
+
+ @property
+ def expected_configs(self):
+ expected_configs = []
+ for block in self.sub_blocks.values():
+ for config in block.expected_configs:
+ if config not in expected_configs:
+ expected_configs.append(config)
+ return expected_configs
+
+ @classmethod
+ def from_blocks_dict(
+ cls, blocks_dict: Dict[str, Any], description: Optional[str] = None
+ ) -> "SequentialPipelineBlocks":
+ """Creates a SequentialPipelineBlocks instance from a dictionary of blocks.
+
+ Args:
+ blocks_dict: Dictionary mapping block names to block classes or instances
+
+ Returns:
+ A new SequentialPipelineBlocks instance
+ """
+ instance = cls()
+
+ # Create instances if classes are provided
+ sub_blocks = InsertableDict()
+ for name, block in blocks_dict.items():
+ if inspect.isclass(block):
+ sub_blocks[name] = block()
+ else:
+ sub_blocks[name] = block
+
+ instance.block_classes = [block.__class__ for block in sub_blocks.values()]
+ instance.block_names = list(sub_blocks.keys())
+ instance.sub_blocks = sub_blocks
+
+ if description is not None:
+ instance.description = description
+
+ return instance
+
+ def __init__(self):
+ sub_blocks = InsertableDict()
+ for block_name, block in zip(self.block_names, self.block_classes):
+ if inspect.isclass(block):
+ sub_blocks[block_name] = block()
+ else:
+ sub_blocks[block_name] = block
+ self.sub_blocks = sub_blocks
+ if not len(self.block_names) == len(self.block_classes):
+ raise ValueError(
+ f"In {self.__class__.__name__}, the number of block_names and block_classes must be the same."
+ )
+
+ def _get_inputs(self):
+ inputs = []
+ outputs = set()
+
+ # Go through all blocks in order
+ for block in self.sub_blocks.values():
+ # Add inputs that aren't in outputs yet
+ for inp in block.inputs:
+ if inp.name not in outputs and inp.name not in {input.name for input in inputs}:
+ inputs.append(inp)
+
+ # Only add outputs if the block cannot be skipped
+ should_add_outputs = True
+ if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
+ should_add_outputs = False
+
+ if should_add_outputs:
+ # Add this block's outputs
+ block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
+ outputs.update(block_intermediate_outputs)
+
+ return inputs
+
+ # YiYi TODO: add test for this
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return self._get_inputs()
+
+ @property
+ def required_inputs(self) -> List[str]:
+ # Get the first block from the dictionary
+ first_block = next(iter(self.sub_blocks.values()))
+ required_by_any = set(getattr(first_block, "required_inputs", set()))
+
+ # Union with required inputs from all other blocks
+ for block in list(self.sub_blocks.values())[1:]:
+ block_required = set(getattr(block, "required_inputs", set()))
+ required_by_any.update(block_required)
+
+ return list(required_by_any)
+
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ named_outputs = []
+ for name, block in self.sub_blocks.items():
+ inp_names = {inp.name for inp in block.inputs}
+ # so we only need to list new variables as intermediate_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce)
+ # filter out them here so they do not end up as intermediate_outputs
+ if name not in inp_names:
+ named_outputs.append((name, block.intermediate_outputs))
+ combined_outputs = self.combine_outputs(*named_outputs)
+ return combined_outputs
+
+ # YiYi TODO: I think we can remove the outputs property
+ @property
+ def outputs(self) -> List[str]:
+ # return next(reversed(self.sub_blocks.values())).intermediate_outputs
+ return self.intermediate_outputs
+
+ @torch.no_grad()
+ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
+ for block_name, block in self.sub_blocks.items():
+ try:
+ pipeline, state = block(pipeline, state)
+ except Exception as e:
+ error_msg = (
+ f"\nError in block: ({block_name}, {block.__class__.__name__})\n"
+ f"Error details: {str(e)}\n"
+ f"Traceback:\n{traceback.format_exc()}"
+ )
+ logger.error(error_msg)
+ raise
+ return pipeline, state
+
+ def _get_trigger_inputs(self):
+ """
+ Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique
+ block_trigger_inputs values
+ """
+
+ def fn_recursive_get_trigger(blocks):
+ trigger_values = set()
+
+ if blocks is not None:
+ for name, block in blocks.items():
+ # Check if current block has trigger inputs(i.e. auto block)
+ if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
+ # Add all non-None values from the trigger inputs list
+ trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
+
+ # If block has sub_blocks, recursively check them
+ if block.sub_blocks:
+ nested_triggers = fn_recursive_get_trigger(block.sub_blocks)
+ trigger_values.update(nested_triggers)
+
+ return trigger_values
+
+ return fn_recursive_get_trigger(self.sub_blocks)
+
+ @property
+ def trigger_inputs(self):
+ return self._get_trigger_inputs()
+
+ def _traverse_trigger_blocks(self, trigger_inputs):
+ # Convert trigger_inputs to a set for easier manipulation
+ active_triggers = set(trigger_inputs)
+
+ def fn_recursive_traverse(block, block_name, active_triggers):
+ result_blocks = OrderedDict()
+
+ # sequential(include loopsequential) or PipelineBlock
+ if not hasattr(block, "block_trigger_inputs"):
+ if block.sub_blocks:
+ # sequential or LoopSequentialPipelineBlocks (keep traversing)
+ for sub_block_name, sub_block in block.sub_blocks.items():
+ blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers)
+ blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers)
+ blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()}
+ result_blocks.update(blocks_to_update)
+ else:
+ # PipelineBlock
+ result_blocks[block_name] = block
+ # Add this block's output names to active triggers if defined
+ if hasattr(block, "outputs"):
+ active_triggers.update(out.name for out in block.outputs)
+ return result_blocks
+
+ # auto
+ else:
+ # Find first block_trigger_input that matches any value in our active_triggers
+ this_block = None
+ for trigger_input in block.block_trigger_inputs:
+ if trigger_input is not None and trigger_input in active_triggers:
+ this_block = block.trigger_to_block_map[trigger_input]
+ break
+
+ # If no matches found, try to get the default (None) block
+ if this_block is None and None in block.block_trigger_inputs:
+ this_block = block.trigger_to_block_map[None]
+
+ if this_block is not None:
+ # sequential/auto (keep traversing)
+ if this_block.sub_blocks:
+ result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers))
+ else:
+ # PipelineBlock
+ result_blocks[block_name] = this_block
+ # Add this block's output names to active triggers if defined
+ # YiYi TODO: do we need outputs here? can it just be intermediate_outputs? can we get rid of outputs attribute?
+ if hasattr(this_block, "outputs"):
+ active_triggers.update(out.name for out in this_block.outputs)
+
+ return result_blocks
+
+ all_blocks = OrderedDict()
+ for block_name, block in self.sub_blocks.items():
+ blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers)
+ all_blocks.update(blocks_to_update)
+ return all_blocks
+
+ def get_execution_blocks(self, *trigger_inputs):
+ trigger_inputs_all = self.trigger_inputs
+
+ if trigger_inputs is not None:
+ if not isinstance(trigger_inputs, (list, tuple, set)):
+ trigger_inputs = [trigger_inputs]
+ invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all]
+ if invalid_inputs:
+ logger.warning(
+ f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}"
+ )
+ trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all]
+
+ if trigger_inputs is None:
+ if None in trigger_inputs_all:
+ trigger_inputs = [None]
+ else:
+ trigger_inputs = [trigger_inputs_all[0]]
+ blocks_triggered = self._traverse_trigger_blocks(trigger_inputs)
+ return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered)
+
+ def __repr__(self):
+ class_name = self.__class__.__name__
+ base_class = self.__class__.__bases__[0].__name__
+ header = (
+ f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
+ )
+
+ if self.trigger_inputs:
+ header += "\n"
+ header += " " + "=" * 100 + "\n"
+ header += " This pipeline contains blocks that are selected at runtime based on inputs.\n"
+ header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
+ # Get first trigger input as example
+ example_input = next(t for t in self.trigger_inputs if t is not None)
+ header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n"
+ header += " " + "=" * 100 + "\n\n"
+
+ # Format description with proper indentation
+ desc_lines = self.description.split("\n")
+ desc = []
+ # First line with "Description:" label
+ desc.append(f" Description: {desc_lines[0]}")
+ # Subsequent lines with proper indentation
+ if len(desc_lines) > 1:
+ desc.extend(f" {line}" for line in desc_lines[1:])
+ desc = "\n".join(desc) + "\n"
+
+ # Components section - focus only on expected components
+ expected_components = getattr(self, "expected_components", [])
+ components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
+
+ # Configs section - use format_configs with add_empty_lines=False
+ expected_configs = getattr(self, "expected_configs", [])
+ configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
+
+ # Blocks section - moved to the end with simplified format
+ blocks_str = " Sub-Blocks:\n"
+ for i, (name, block) in enumerate(self.sub_blocks.items()):
+ # Get trigger input for this block
+ trigger = None
+ if hasattr(self, "block_to_trigger_map"):
+ trigger = self.block_to_trigger_map.get(name)
+ # Format the trigger info
+ if trigger is None:
+ trigger_str = "[default]"
+ elif isinstance(trigger, (list, tuple)):
+ trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]"
+ else:
+ trigger_str = f"[trigger: {trigger}]"
+ # For AutoPipelineBlocks, add bullet points
+ blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n"
+ else:
+ # For SequentialPipelineBlocks, show execution order
+ blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
+
+ # Add block description
+ desc_lines = block.description.split("\n")
+ indented_desc = desc_lines[0]
+ if len(desc_lines) > 1:
+ indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:])
+ blocks_str += f" Description: {indented_desc}\n\n"
+
+ # Build the representation with conditional sections
+ result = f"{header}\n{desc}"
+
+ # Only add components section if it has content
+ if components_str.strip():
+ result += f"\n\n{components_str}"
+
+ # Only add configs section if it has content
+ if configs_str.strip():
+ result += f"\n\n{configs_str}"
+
+ # Always add blocks section
+ result += f"\n\n{blocks_str})"
+
+ return result
+
+ @property
+ def doc(self):
+ return make_doc_string(
+ self.inputs,
+ self.outputs,
+ self.description,
+ class_name=self.__class__.__name__,
+ expected_components=self.expected_components,
+ expected_configs=self.expected_configs,
+ )
+
+
+class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
+ """
+ A Pipeline blocks that combines multiple pipeline block classes into a For Loop. When called, it will call each
+ block in sequence.
+
+ This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipeline blocks (such as loading or saving etc.)
+
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
+
+ Attributes:
+ block_classes: List of block classes to be used
+ block_names: List of prefixes for each block
+ """
+
+ model_name = None
+ block_classes = []
+ block_names = []
+
+ @property
+ def description(self) -> str:
+ """Description of the block. Must be implemented by subclasses."""
+ raise NotImplementedError("description method must be implemented in subclasses")
+
+ @property
+ def loop_expected_components(self) -> List[ComponentSpec]:
+ return []
+
+ @property
+ def loop_expected_configs(self) -> List[ConfigSpec]:
+ return []
+
+ @property
+ def loop_inputs(self) -> List[InputParam]:
+ """List of input parameters. Must be implemented by subclasses."""
+ return []
+
+ @property
+ def loop_required_inputs(self) -> List[str]:
+ input_names = []
+ for input_param in self.loop_inputs:
+ if input_param.required:
+ input_names.append(input_param.name)
+ return input_names
+
+ @property
+ def loop_intermediate_outputs(self) -> List[OutputParam]:
+ """List of intermediate output parameters. Must be implemented by subclasses."""
+ return []
+
+ # modified from SequentialPipelineBlocks to include loop_expected_components
+ @property
+ def expected_components(self):
+ expected_components = []
+ for block in self.sub_blocks.values():
+ for component in block.expected_components:
+ if component not in expected_components:
+ expected_components.append(component)
+ for component in self.loop_expected_components:
+ if component not in expected_components:
+ expected_components.append(component)
+ return expected_components
+
+ # modified from SequentialPipelineBlocks to include loop_expected_configs
+ @property
+ def expected_configs(self):
+ expected_configs = []
+ for block in self.sub_blocks.values():
+ for config in block.expected_configs:
+ if config not in expected_configs:
+ expected_configs.append(config)
+ for config in self.loop_expected_configs:
+ if config not in expected_configs:
+ expected_configs.append(config)
+ return expected_configs
+
+ def _get_inputs(self):
+ inputs = []
+ inputs.extend(self.loop_inputs)
+ outputs = set()
+
+ for name, block in self.sub_blocks.items():
+ # Add inputs that aren't in outputs yet
+ for inp in block.inputs:
+ if inp.name not in outputs and inp not in inputs:
+ inputs.append(inp)
+
+ # Only add outputs if the block cannot be skipped
+ should_add_outputs = True
+ if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
+ should_add_outputs = False
+
+ if should_add_outputs:
+ # Add this block's outputs
+ block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
+ outputs.update(block_intermediate_outputs)
+
+ for input_param in inputs:
+ if input_param.name in self.required_inputs:
+ input_param.required = True
+ else:
+ input_param.required = False
+
+ return inputs
+
+ @property
+ # Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks.inputs
+ def inputs(self):
+ return self._get_inputs()
+
+ # modified from SequentialPipelineBlocks, if any additionan input required by the loop is required by the block
+ @property
+ def required_inputs(self) -> List[str]:
+ # Get the first block from the dictionary
+ first_block = next(iter(self.sub_blocks.values()))
+ required_by_any = set(getattr(first_block, "required_inputs", set()))
+
+ required_by_loop = set(getattr(self, "loop_required_inputs", set()))
+ required_by_any.update(required_by_loop)
+
+ # Union with required inputs from all other blocks
+ for block in list(self.sub_blocks.values())[1:]:
+ block_required = set(getattr(block, "required_inputs", set()))
+ required_by_any.update(block_required)
+
+ return list(required_by_any)
+
+ # YiYi TODO: this need to be thought about more
+ # modified from SequentialPipelineBlocks to include loop_intermediate_outputs
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()]
+ combined_outputs = self.combine_outputs(*named_outputs)
+ for output in self.loop_intermediate_outputs:
+ if output.name not in {output.name for output in combined_outputs}:
+ combined_outputs.append(output)
+ return combined_outputs
+
+ # YiYi TODO: this need to be thought about more
+ @property
+ def outputs(self) -> List[str]:
+ return next(reversed(self.sub_blocks.values())).intermediate_outputs
+
+ def __init__(self):
+ sub_blocks = InsertableDict()
+ for block_name, block in zip(self.block_names, self.block_classes):
+ if inspect.isclass(block):
+ sub_blocks[block_name] = block()
+ else:
+ sub_blocks[block_name] = block
+ self.sub_blocks = sub_blocks
+
+ @classmethod
+ def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelineBlocks":
+ """
+ Creates a LoopSequentialPipelineBlocks instance from a dictionary of blocks.
+
+ Args:
+ blocks_dict: Dictionary mapping block names to block instances
+
+ Returns:
+ A new LoopSequentialPipelineBlocks instance
+ """
+ instance = cls()
+
+ # Create instances if classes are provided
+ sub_blocks = InsertableDict()
+ for name, block in blocks_dict.items():
+ if inspect.isclass(block):
+ sub_blocks[name] = block()
+ else:
+ sub_blocks[name] = block
+
+ instance.block_classes = [block.__class__ for block in blocks_dict.values()]
+ instance.block_names = list(blocks_dict.keys())
+ instance.sub_blocks = blocks_dict
+ return instance
+
+ def loop_step(self, components, state: PipelineState, **kwargs):
+ for block_name, block in self.sub_blocks.items():
+ try:
+ components, state = block(components, state, **kwargs)
+ except Exception as e:
+ error_msg = (
+ f"\nError in block: ({block_name}, {block.__class__.__name__})\n"
+ f"Error details: {str(e)}\n"
+ f"Traceback:\n{traceback.format_exc()}"
+ )
+ logger.error(error_msg)
+ raise
+ return components, state
+
+ def __call__(self, components, state: PipelineState) -> PipelineState:
+ raise NotImplementedError("`__call__` method needs to be implemented by the subclass")
+
+ @property
+ def doc(self):
+ return make_doc_string(
+ self.inputs,
+ self.outputs,
+ self.description,
+ class_name=self.__class__.__name__,
+ expected_components=self.expected_components,
+ expected_configs=self.expected_configs,
+ )
+
+ # modified from SequentialPipelineBlocks,
+ # (does not need trigger_inputs related part so removed them,
+ # do not need to support auto block for loop blocks)
+ def __repr__(self):
+ class_name = self.__class__.__name__
+ base_class = self.__class__.__bases__[0].__name__
+ header = (
+ f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
+ )
+
+ # Format description with proper indentation
+ desc_lines = self.description.split("\n")
+ desc = []
+ # First line with "Description:" label
+ desc.append(f" Description: {desc_lines[0]}")
+ # Subsequent lines with proper indentation
+ if len(desc_lines) > 1:
+ desc.extend(f" {line}" for line in desc_lines[1:])
+ desc = "\n".join(desc) + "\n"
+
+ # Components section - focus only on expected components
+ expected_components = getattr(self, "expected_components", [])
+ components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
+
+ # Configs section - use format_configs with add_empty_lines=False
+ expected_configs = getattr(self, "expected_configs", [])
+ configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
+
+ # Blocks section - moved to the end with simplified format
+ blocks_str = " Sub-Blocks:\n"
+ for i, (name, block) in enumerate(self.sub_blocks.items()):
+ # For SequentialPipelineBlocks, show execution order
+ blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
+
+ # Add block description
+ desc_lines = block.description.split("\n")
+ indented_desc = desc_lines[0]
+ if len(desc_lines) > 1:
+ indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:])
+ blocks_str += f" Description: {indented_desc}\n\n"
+
+ # Build the representation with conditional sections
+ result = f"{header}\n{desc}"
+
+ # Only add components section if it has content
+ if components_str.strip():
+ result += f"\n\n{components_str}"
+
+ # Only add configs section if it has content
+ if configs_str.strip():
+ result += f"\n\n{configs_str}"
+
+ # Always add blocks section
+ result += f"\n\n{blocks_str})"
+
+ return result
+
+ @torch.compiler.disable
+ def progress_bar(self, iterable=None, total=None):
+ if not hasattr(self, "_progress_bar_config"):
+ self._progress_bar_config = {}
+ elif not isinstance(self._progress_bar_config, dict):
+ raise ValueError(
+ f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
+ )
+
+ if iterable is not None:
+ return tqdm(iterable, **self._progress_bar_config)
+ elif total is not None:
+ return tqdm(total=total, **self._progress_bar_config)
+ else:
+ raise ValueError("Either `total` or `iterable` has to be defined.")
+
+ def set_progress_bar_config(self, **kwargs):
+ self._progress_bar_config = kwargs
+
+
+# YiYi TODO:
+# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)
+# 2. do we need ConfigSpec? the are basically just key/val kwargs
+# 3. imnprove docstring and potentially add validator for methods where we accept kwargs to be passed to from_pretrained/save_pretrained/load_components()
+class ModularPipeline(ConfigMixin, PushToHubMixin):
+ """
+ Base class for all Modular pipelines.
+
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
+
+ Args:
+ blocks: ModularPipelineBlocks, the blocks to be used in the pipeline
+ """
+
+ config_name = "modular_model_index.json"
+ hf_device_map = None
+ default_blocks_name = None
+
+ # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name
+ def __init__(
+ self,
+ blocks: Optional[ModularPipelineBlocks] = None,
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+ components_manager: Optional[ComponentsManager] = None,
+ collection: Optional[str] = None,
+ modular_config_dict: Optional[Dict[str, Any]] = None,
+ config_dict: Optional[Dict[str, Any]] = None,
+ **kwargs,
+ ):
+ """
+ Initialize a ModularPipeline instance.
+
+ This method sets up the pipeline by:
+ - creating default pipeline blocks if not provided
+ - gather component and config specifications based on the pipeline blocks's requirement (e.g.
+ expected_components, expected_configs)
+ - update the loading specs of from_pretrained components based on the modular_model_index.json file from
+ huggingface hub if `pretrained_model_name_or_path` is provided
+ - create defaultfrom_config components and register everything
+
+ Args:
+ blocks: `ModularPipelineBlocks` instance. If None, will attempt to load
+ default blocks based on the pipeline class name.
+ pretrained_model_name_or_path: Path to a pretrained pipeline configuration. Can be None if the pipeline
+ does not require any additional loading config. If provided, will first try to load component specs
+ (only for from_pretrained components) and config values from `modular_model_index.json`, then
+ fallback to `model_index.json` for compatibility with standard non-modular repositories.
+ components_manager:
+ Optional ComponentsManager for managing multiple component cross different pipelines and apply
+ offloading strategies.
+ collection: Optional collection name for organizing components in the ComponentsManager.
+ **kwargs: Additional arguments passed to `load_config()` when loading pretrained configuration.
+
+ Examples:
+ ```python
+ # Initialize with custom blocks
+ pipeline = ModularPipeline(blocks=my_custom_blocks)
+
+ # Initialize from pretrained configuration
+ pipeline = ModularPipeline(blocks=my_blocks, pretrained_model_name_or_path="my-repo/modular-pipeline")
+
+ # Initialize with components manager
+ pipeline = ModularPipeline(
+ blocks=my_blocks, components_manager=ComponentsManager(), collection="my_collection"
+ )
+ ```
+
+ Notes:
+ - If blocks is None, the method will try to find default blocks based on the pipeline class name
+ - Components with default_creation_method="from_config" are created immediately, its specs are not included
+ in config dict and will not be saved in `modular_model_index.json`
+ - Components with default_creation_method="from_pretrained" are set to None and can be loaded later with
+ `load_components()` (with or without specific component names)
+ - The pipeline's config dict is populated with component specs (only for from_pretrained components) and
+ config values, which will be saved as `modular_model_index.json` during `save_pretrained`
+ - The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as
+ `_blocks_class_name` in the config dict
+ """
+
+ if modular_config_dict is None and config_dict is None and pretrained_model_name_or_path is not None:
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ token = kwargs.pop("token", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ revision = kwargs.pop("revision", None)
+
+ load_config_kwargs = {
+ "cache_dir": cache_dir,
+ "force_download": force_download,
+ "proxies": proxies,
+ "token": token,
+ "local_files_only": local_files_only,
+ "revision": revision,
+ }
+
+ modular_config_dict, config_dict = self._load_pipeline_config(
+ pretrained_model_name_or_path, **load_config_kwargs
+ )
+
+ if blocks is None:
+ if modular_config_dict is not None:
+ blocks_class_name = modular_config_dict.get("_blocks_class_name")
+ elif config_dict is not None:
+ blocks_class_name = self.get_default_blocks_name(config_dict)
+ else:
+ blocks_class_name = None
+ if blocks_class_name is not None:
+ diffusers_module = importlib.import_module("diffusers")
+ blocks_class = getattr(diffusers_module, blocks_class_name)
+ blocks = blocks_class()
+ else:
+ logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
+
+ self.blocks = blocks
+ self._components_manager = components_manager
+ self._collection = collection
+ self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components}
+ self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs}
+
+ # update component_specs and config_specs based on modular_model_index.json
+ if modular_config_dict is not None:
+ for name, value in modular_config_dict.items():
+ # all the components in modular_model_index.json are from_pretrained components
+ if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 3:
+ library, class_name, component_spec_dict = value
+ component_spec = self._dict_to_component_spec(name, component_spec_dict)
+ component_spec.default_creation_method = "from_pretrained"
+ self._component_specs[name] = component_spec
+
+ elif name in self._config_specs:
+ self._config_specs[name].default = value
+
+ # if `modular_config_dict` is None (i.e. `modular_model_index.json` is not found), update based on `config_dict` (i.e. `model_index.json`)
+ elif config_dict is not None:
+ for name, value in config_dict.items():
+ if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 2:
+ library, class_name = value
+ component_spec_dict = {
+ "repo": pretrained_model_name_or_path,
+ "subfolder": name,
+ "type_hint": (library, class_name),
+ }
+ component_spec = self._dict_to_component_spec(name, component_spec_dict)
+ component_spec.default_creation_method = "from_pretrained"
+ self._component_specs[name] = component_spec
+ elif name in self._config_specs:
+ self._config_specs[name].default = value
+
+ if len(kwargs) > 0:
+ logger.warning(f"Unexpected input '{kwargs.keys()}' provided. This input will be ignored.")
+
+ register_components_dict = {}
+ for name, component_spec in self._component_specs.items():
+ if component_spec.default_creation_method == "from_config":
+ component = component_spec.create()
+ else:
+ component = None
+ register_components_dict[name] = component
+ self.register_components(**register_components_dict)
+
+ default_configs = {}
+ for name, config_spec in self._config_specs.items():
+ default_configs[name] = config_spec.default
+ self.register_to_config(**default_configs)
+
+ self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None)
+
+ @property
+ def default_call_parameters(self) -> Dict[str, Any]:
+ """
+ Returns:
+ - Dictionary mapping input names to their default values
+ """
+ params = {}
+ for input_param in self.blocks.inputs:
+ params[input_param.name] = input_param.default
+ return params
+
+ def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
+ return self.default_blocks_name
+
+ @classmethod
+ def _load_pipeline_config(
+ cls,
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+ **load_config_kwargs,
+ ):
+ try:
+ # try to load modular_model_index.json
+ modular_config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs)
+ return modular_config_dict, None
+
+ except EnvironmentError as e:
+ logger.debug(f" modular_model_index.json not found in the repo: {e}")
+
+ try:
+ logger.debug(" try to load model_index.json")
+ from diffusers import DiffusionPipeline
+
+ config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs)
+ return None, config_dict
+
+ except EnvironmentError as e:
+ logger.debug(f" model_index.json not found in the repo: {e}")
+
+ return None, None
+
+ @classmethod
+ @validate_hf_hub_args
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+ trust_remote_code: Optional[bool] = None,
+ components_manager: Optional[ComponentsManager] = None,
+ collection: Optional[str] = None,
+ **kwargs,
+ ):
+ """
+ Load a ModularPipeline from a huggingface hub repo.
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, optional):
+ Path to a pretrained pipeline configuration. It will first try to load config from
+ `modular_model_index.json`, then fallback to `model_index.json` for compatibility with standard
+ non-modular repositories. If the pretrained_model_name_or_path does not contain any pipeline config, it
+ will be set to None during initialization.
+ trust_remote_code (`bool`, optional):
+ Whether to trust remote code when loading the pipeline, need to be set to True if you want to create
+ pipeline blocks based on the custom code in `pretrained_model_name_or_path`
+ components_manager (`ComponentsManager`, optional):
+ ComponentsManager instance for managing multiple component cross different pipelines and apply
+ offloading strategies.
+ collection (`str`, optional):`
+ Collection name for organizing components in the ComponentsManager.
+ """
+ from ..pipelines.pipeline_loading_utils import _get_pipeline_class
+
+ try:
+ blocks = ModularPipelineBlocks.from_pretrained(
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+ )
+ except EnvironmentError as e:
+ logger.debug(f"EnvironmentError: {e}")
+ blocks = None
+
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ token = kwargs.pop("token", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ revision = kwargs.pop("revision", None)
+
+ load_config_kwargs = {
+ "cache_dir": cache_dir,
+ "force_download": force_download,
+ "proxies": proxies,
+ "token": token,
+ "local_files_only": local_files_only,
+ "revision": revision,
+ }
+
+ modular_config_dict, config_dict = cls._load_pipeline_config(
+ pretrained_model_name_or_path, **load_config_kwargs
+ )
+
+ if modular_config_dict is not None:
+ pipeline_class = _get_pipeline_class(cls, config=modular_config_dict)
+ elif config_dict is not None:
+ from diffusers.pipelines.auto_pipeline import _get_model
+
+ logger.debug(" try to determine the modular pipeline class from model_index.json")
+ standard_pipeline_class = _get_pipeline_class(cls, config=config_dict)
+ model_name = _get_model(standard_pipeline_class.__name__)
+ pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__)
+ diffusers_module = importlib.import_module("diffusers")
+ pipeline_class = getattr(diffusers_module, pipeline_class_name)
+ else:
+ # there is no config for modular pipeline, assuming that the pipeline block does not need any from_pretrained components
+ pipeline_class = cls
+ pretrained_model_name_or_path = None
+
+ pipeline = pipeline_class(
+ blocks=blocks,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ components_manager=components_manager,
+ collection=collection,
+ modular_config_dict=modular_config_dict,
+ config_dict=config_dict,
+ **kwargs,
+ )
+ return pipeline
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save the pipeline to a directory. It does not save components, you need to save them separately.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Path to the directory where the pipeline will be saved.
+ push_to_hub (`bool`, optional):
+ Whether to push the pipeline to the huggingface hub.
+ **kwargs: Additional arguments passed to `save_config()` method
+ """
+ if push_to_hub:
+ commit_message = kwargs.pop("commit_message", None)
+ private = kwargs.pop("private", None)
+ create_pr = kwargs.pop("create_pr", False)
+ token = kwargs.pop("token", None)
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+ repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
+
+ # Create a new empty model card and eventually tag it
+ model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
+ model_card = populate_model_card(model_card)
+ model_card.save(os.path.join(save_directory, "README.md"))
+
+ # YiYi TODO: maybe order the json file to make it more readable: configs first, then components
+ self.save_config(save_directory=save_directory)
+
+ if push_to_hub:
+ self._upload_folder(
+ save_directory,
+ repo_id,
+ token=token,
+ commit_message=commit_message,
+ create_pr=create_pr,
+ )
+
+ @property
+ def doc(self):
+ """
+ Returns:
+ - The docstring of the pipeline blocks
+ """
+ return self.blocks.doc
+
+ def register_components(self, **kwargs):
+ """
+ Register components with their corresponding specifications.
+
+ This method is responsible for:
+ 1. Sets component objects as attributes on the loader (e.g., self.unet = unet)
+ 2. Updates the config dict, which will be saved as `modular_model_index.json` during `save_pretrained` (only
+ for from_pretrained components)
+ 3. Adds components to the component manager if one is attached (only for from_pretrained components)
+
+ This method is called when:
+ - Components are first initialized in __init__:
+ - from_pretrained components not loaded during __init__ so they are registered as None;
+ - non from_pretrained components are created during __init__ and registered as the object itself
+ - Components are updated with the `update_components()` method: e.g. loader.update_components(unet=unet) or
+ loader.update_components(guider=guider_spec)
+ - (from_pretrained) Components are loaded with the `load_components()` method: e.g.
+ loader.load_components(names=["unet"]) or loader.load_components() to load all default components
+
+ Args:
+ **kwargs: Keyword arguments where keys are component names and values are component objects.
+ E.g., register_components(unet=unet_model, text_encoder=encoder_model)
+
+ Notes:
+ - When registering None for a component, it sets attribute to None but still syncs specs with the config
+ dict, which will be saved as `modular_model_index.json` during `save_pretrained`
+ - component_specs are updated to match the new component outside of this method, e.g. in
+ `update_components()` method
+ """
+ for name, module in kwargs.items():
+ # current component spec
+ component_spec = self._component_specs.get(name)
+ if component_spec is None:
+ logger.warning(f"ModularPipeline.register_components: skipping unknown component '{name}'")
+ continue
+
+ # check if it is the first time registration, i.e. calling from __init__
+ is_registered = hasattr(self, name)
+ is_from_pretrained = component_spec.default_creation_method == "from_pretrained"
+
+ if module is not None:
+ # actual library and class name of the module
+ library, class_name = _fetch_class_library_tuple(module) # e.g. ("diffusers", "UNet2DConditionModel")
+ else:
+ # if module is None, e.g. self.register_components(unet=None) during __init__
+ # we do not update the spec,
+ # but we still need to update the modular_model_index.json config based on component spec
+ library, class_name = None, None
+
+ # extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config
+ # e.g. {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1",
+ # "type_hint": ("diffusers", "UNet2DConditionModel"),
+ # "subfolder": "unet",
+ # "variant": None,
+ # "revision": None}
+ component_spec_dict = self._component_spec_to_dict(component_spec)
+
+ register_dict = {name: (library, class_name, component_spec_dict)}
+
+ # set the component as attribute
+ # if it is not set yet, just set it and skip the process to check and warn below
+ if not is_registered:
+ if is_from_pretrained:
+ self.register_to_config(**register_dict)
+ setattr(self, name, module)
+ if module is not None and is_from_pretrained and self._components_manager is not None:
+ self._components_manager.add(name, module, self._collection)
+ continue
+
+ current_module = getattr(self, name, None)
+ # skip if the component is already registered with the same object
+ if current_module is module:
+ logger.info(
+ f"ModularPipeline.register_components: {name} is already registered with same object, skipping"
+ )
+ continue
+
+ # warn if unregister
+ if current_module is not None and module is None:
+ logger.info(
+ f"ModularPipeline.register_components: setting '{name}' to None "
+ f"(was {current_module.__class__.__name__})"
+ )
+ # same type, new instance → replace but send debug log
+ elif (
+ current_module is not None
+ and module is not None
+ and isinstance(module, current_module.__class__)
+ and current_module != module
+ ):
+ logger.debug(
+ f"ModularPipeline.register_components: replacing existing '{name}' "
+ f"(same type {type(current_module).__name__}, new instance)"
+ )
+
+ # update modular_model_index.json config
+ if is_from_pretrained:
+ self.register_to_config(**register_dict)
+ # finally set models
+ setattr(self, name, module)
+ # add to component manager if one is attached
+ if module is not None and is_from_pretrained and self._components_manager is not None:
+ self._components_manager.add(name, module, self._collection)
+
+ @property
+ def device(self) -> torch.device:
+ r"""
+ Returns:
+ `torch.device`: The torch device on which the pipeline is located.
+ """
+ modules = self.components.values()
+ modules = [m for m in modules if isinstance(m, torch.nn.Module)]
+
+ for module in modules:
+ return module.device
+
+ return torch.device("cpu")
+
+ @property
+ # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline._execution_device
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
+ Accelerate's module hooks.
+ """
+ for name, model in self.components.items():
+ if not isinstance(model, torch.nn.Module):
+ continue
+
+ if not hasattr(model, "_hf_hook"):
+ return self.device
+ for module in model.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ @property
+ def dtype(self) -> torch.dtype:
+ r"""
+ Returns:
+ `torch.dtype`: The torch dtype on which the pipeline is located.
+ """
+ modules = self.components.values()
+ modules = [m for m in modules if isinstance(m, torch.nn.Module)]
+
+ for module in modules:
+ return module.dtype
+
+ return torch.float32
+
+ @property
+ def null_component_names(self) -> List[str]:
+ """
+ Returns:
+ - List of names for components that needs to be loaded
+ """
+ return [name for name in self._component_specs.keys() if hasattr(self, name) and getattr(self, name) is None]
+
+ @property
+ def component_names(self) -> List[str]:
+ """
+ Returns:
+ - List of names for all components
+ """
+ return list(self.components.keys())
+
+ @property
+ def pretrained_component_names(self) -> List[str]:
+ """
+ Returns:
+ - List of names for from_pretrained components
+ """
+ return [
+ name
+ for name in self._component_specs.keys()
+ if self._component_specs[name].default_creation_method == "from_pretrained"
+ ]
+
+ @property
+ def config_component_names(self) -> List[str]:
+ """
+ Returns:
+ - List of names for from_config components
+ """
+ return [
+ name
+ for name in self._component_specs.keys()
+ if self._component_specs[name].default_creation_method == "from_config"
+ ]
+
+ @property
+ def components(self) -> Dict[str, Any]:
+ """
+ Returns:
+ - Dictionary mapping component names to their objects (include both from_pretrained and from_config
+ components)
+ """
+ # return only components we've actually set as attributes on self
+ return {name: getattr(self, name) for name in self._component_specs.keys() if hasattr(self, name)}
+
+ def get_component_spec(self, name: str) -> ComponentSpec:
+ """
+ Returns:
+ - a copy of the ComponentSpec object for the given component name
+ """
+ return deepcopy(self._component_specs[name])
+
+ def update_components(self, **kwargs):
+ """
+ Update components and configuration values and specs after the pipeline has been instantiated.
+
+ This method allows you to:
+ 1. Replace existing components with new ones (e.g., updating `self.unet` or `self.text_encoder`)
+ 2. Update configuration values (e.g., changing `self.requires_safety_checker` flag)
+
+ In addition to updating the components and configuration values as pipeline attributes, the method also
+ updates:
+ - the corresponding specs in `_component_specs` and `_config_specs`
+ - the `config` dict, which will be saved as `modular_model_index.json` during `save_pretrained`
+
+ Args:
+ **kwargs: Component objects, ComponentSpec objects, or configuration values to update:
+ - Component objects: Only supports components we can extract specs using
+ `ComponentSpec.from_component()` method i.e. components created with ComponentSpec.load() or
+ ConfigMixin subclasses that aren't nn.Modules (e.g., `unet=new_unet, text_encoder=new_encoder`)
+ - ComponentSpec objects: Only supports default_creation_method == "from_config", will call create()
+ method to create a new component (e.g., `guider=ComponentSpec(name="guider",
+ type_hint=ClassifierFreeGuidance, config={...}, default_creation_method="from_config")`)
+ - Configuration values: Simple values to update configuration settings (e.g.,
+ `requires_safety_checker=False`)
+
+ Raises:
+ ValueError: If a component object is not supported in ComponentSpec.from_component() method:
+ - nn.Module components without a valid `_diffusers_load_id` attribute
+ - Non-ConfigMixin components without a valid `_diffusers_load_id` attribute
+
+ Examples:
+ ```python
+ # Update multiple components at once
+ pipeline.update_components(unet=new_unet_model, text_encoder=new_text_encoder)
+
+ # Update configuration values
+ pipeline.update_components(requires_safety_checker=False)
+
+ # Update both components and configs together
+ pipeline.update_components(unet=new_unet_model, requires_safety_checker=False)
+
+ # Update with ComponentSpec objects (from_config only)
+ pipeline.update_components(
+ guider=ComponentSpec(
+ name="guider",
+ type_hint=ClassifierFreeGuidance,
+ config={"guidance_scale": 5.0},
+ default_creation_method="from_config",
+ )
+ )
+ ```
+
+ Notes:
+ - Components with trained weights must be created using ComponentSpec.load(). If the component has not been
+ shared in huggingface hub and you don't have loading specs, you can upload it using `push_to_hub()`
+ - ConfigMixin objects without weights (e.g., schedulers, guiders) can be passed directly
+ - ComponentSpec objects with default_creation_method="from_pretrained" are not supported in
+ update_components()
+ """
+
+ # extract component_specs_updates & config_specs_updates from `specs`
+ passed_component_specs = {
+ k: kwargs.pop(k) for k in self._component_specs if k in kwargs and isinstance(kwargs[k], ComponentSpec)
+ }
+ passed_components = {
+ k: kwargs.pop(k) for k in self._component_specs if k in kwargs and not isinstance(kwargs[k], ComponentSpec)
+ }
+ passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs}
+
+ for name, component in passed_components.items():
+ current_component_spec = self._component_specs[name]
+
+ # log if type changed
+ if current_component_spec.type_hint is not None and not isinstance(
+ component, current_component_spec.type_hint
+ ):
+ logger.info(
+ f"ModularPipeline.update_components: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}"
+ )
+ # update _component_specs based on the new component
+ if component is None:
+ new_component_spec = current_component_spec
+ if hasattr(self, name) and getattr(self, name) is not None:
+ logger.warning(f"ModularPipeline.update_components: setting {name} to None (spec unchanged)")
+ elif current_component_spec.default_creation_method == "from_pretrained" and not (
+ hasattr(component, "_diffusers_load_id") and component._diffusers_load_id is not None
+ ):
+ logger.warning(
+ f"ModularPipeline.update_components: {name} has no valid _diffusers_load_id. "
+ f"This will result in empty loading spec, use ComponentSpec.load() for proper specs"
+ )
+ new_component_spec = ComponentSpec(name=name, type_hint=type(component))
+ else:
+ new_component_spec = ComponentSpec.from_component(name, component)
+
+ if new_component_spec.default_creation_method != current_component_spec.default_creation_method:
+ logger.info(
+ f"ModularPipeline.update_components: changing the default_creation_method of {name} from {current_component_spec.default_creation_method} to {new_component_spec.default_creation_method}."
+ )
+
+ self._component_specs[name] = new_component_spec
+
+ if len(kwargs) > 0:
+ logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}")
+
+ created_components = {}
+ for name, component_spec in passed_component_specs.items():
+ if component_spec.default_creation_method == "from_pretrained":
+ raise ValueError(
+ "ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update_components() method"
+ )
+ created_components[name] = component_spec.create()
+ current_component_spec = self._component_specs[name]
+ # warn if type changed
+ if current_component_spec.type_hint is not None and not isinstance(
+ created_components[name], current_component_spec.type_hint
+ ):
+ logger.info(
+ f"ModularPipeline.update_components: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}"
+ )
+ # update _component_specs based on the user passed component_spec
+ self._component_specs[name] = component_spec
+ self.register_components(**passed_components, **created_components)
+
+ config_to_register = {}
+ for name, new_value in passed_config_values.items():
+ # e.g. requires_aesthetics_score = False
+ self._config_specs[name].default = new_value
+ config_to_register[name] = new_value
+ self.register_to_config(**config_to_register)
+
+ # YiYi TODO: support map for additional from_pretrained kwargs
+ def load_components(self, names: Optional[Union[List[str], str]] = None, **kwargs):
+ """
+ Load selected components from specs.
+
+ Args:
+ names: List of component names to load. If None, will load all components with
+ default_creation_method == "from_pretrained". If provided as a list or string, will load only the
+ specified components.
+ **kwargs: additional kwargs to be passed to `from_pretrained()`.Can be:
+ - a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16
+ - a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32}
+ - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g.
+ `pretrained_model_name_or_path`, `variant`, `revision`, etc.
+ - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g.
+ `pretrained_model_name_or_path`, `variant`, `revision`, etc.
+ """
+
+ if names is None:
+ names = [
+ name
+ for name in self._component_specs.keys()
+ if self._component_specs[name].default_creation_method == "from_pretrained"
+ ]
+ elif isinstance(names, str):
+ names = [names]
+ elif not isinstance(names, list):
+ raise ValueError(f"Invalid type for names: {type(names)}")
+
+ components_to_load = {name for name in names if name in self._component_specs}
+ unknown_names = {name for name in names if name not in self._component_specs}
+ if len(unknown_names) > 0:
+ logger.warning(f"Unknown components will be ignored: {unknown_names}")
+
+ components_to_register = {}
+ for name in components_to_load:
+ spec = self._component_specs[name]
+ component_load_kwargs = {}
+ for key, value in kwargs.items():
+ if not isinstance(value, dict):
+ # if the value is a single value, apply it to all components
+ component_load_kwargs[key] = value
+ else:
+ if name in value:
+ # if it is a dict, check if the component name is in the dict
+ component_load_kwargs[key] = value[name]
+ elif "default" in value:
+ # check if the default is specified
+ component_load_kwargs[key] = value["default"]
+ try:
+ components_to_register[name] = spec.load(**component_load_kwargs)
+ except Exception:
+ logger.warning(
+ f"\nFailed to create component {name}:\n"
+ f"- Component spec: {spec}\n"
+ f"- load() called with kwargs: {component_load_kwargs}\n"
+ "If this component is not required for your workflow you can safely ignore this message.\n\n"
+ "Traceback:\n"
+ f"{traceback.format_exc()}"
+ )
+
+ # Register all components at once
+ self.register_components(**components_to_register)
+
+ # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline._maybe_raise_error_if_group_offload_active
+ def _maybe_raise_error_if_group_offload_active(
+ self, raise_error: bool = False, module: Optional[torch.nn.Module] = None
+ ) -> bool:
+ from ..hooks.group_offloading import _is_group_offload_enabled
+
+ components = self.components.values() if module is None else [module]
+ components = [component for component in components if isinstance(component, torch.nn.Module)]
+ for component in components:
+ if _is_group_offload_enabled(component):
+ if raise_error:
+ raise ValueError(
+ "You are trying to apply model/sequential CPU offloading to a pipeline that contains components "
+ "with group offloading enabled. This is not supported. Please disable group offloading for "
+ "components of the pipeline to use other offloading methods."
+ )
+ return True
+ return False
+
+ # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to
+ def to(self, *args, **kwargs) -> Self:
+ r"""
+ Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
+ arguments of `self.to(*args, **kwargs).`
+
+ > [!TIP] > If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is.
+ Otherwise, > the returned pipeline is a copy of self with the desired torch.dtype and torch.device.
+
+
+ Here are the ways to call `to`:
+
+ - `to(dtype, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified
+ [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
+ - `to(device, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified
+ [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)
+ - `to(device=None, dtype=None, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the
+ specified [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) and
+ [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
+
+ Arguments:
+ dtype (`torch.dtype`, *optional*):
+ Returns a pipeline with the specified
+ [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
+ device (`torch.Device`, *optional*):
+ Returns a pipeline with the specified
+ [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)
+ silence_dtype_warnings (`str`, *optional*, defaults to `False`):
+ Whether to omit warnings if the target `dtype` is not compatible with the target `device`.
+
+ Returns:
+ [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`.
+ """
+ from ..pipelines.pipeline_utils import _check_bnb_status
+ from ..utils import is_accelerate_available, is_accelerate_version, is_hpu_available, is_transformers_version
+
+ dtype = kwargs.pop("dtype", None)
+ device = kwargs.pop("device", None)
+ silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False)
+
+ dtype_arg = None
+ device_arg = None
+ if len(args) == 1:
+ if isinstance(args[0], torch.dtype):
+ dtype_arg = args[0]
+ else:
+ device_arg = torch.device(args[0]) if args[0] is not None else None
+ elif len(args) == 2:
+ if isinstance(args[0], torch.dtype):
+ raise ValueError(
+ "When passing two arguments, make sure the first corresponds to `device` and the second to `dtype`."
+ )
+ device_arg = torch.device(args[0]) if args[0] is not None else None
+ dtype_arg = args[1]
+ elif len(args) > 2:
+ raise ValueError("Please make sure to pass at most two arguments (`device` and `dtype`) `.to(...)`")
+
+ if dtype is not None and dtype_arg is not None:
+ raise ValueError(
+ "You have passed `dtype` both as an argument and as a keyword argument. Please only pass one of the two."
+ )
+
+ dtype = dtype or dtype_arg
+
+ if device is not None and device_arg is not None:
+ raise ValueError(
+ "You have passed `device` both as an argument and as a keyword argument. Please only pass one of the two."
+ )
+
+ device = device or device_arg
+ device_type = torch.device(device).type if device is not None else None
+ pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
+
+ # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
+ def module_is_sequentially_offloaded(module):
+ if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
+ return False
+
+ _, _, is_loaded_in_8bit_bnb = _check_bnb_status(module)
+
+ if is_loaded_in_8bit_bnb:
+ return False
+
+ return hasattr(module, "_hf_hook") and (
+ isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
+ or hasattr(module._hf_hook, "hooks")
+ and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook)
+ )
+
+ def module_is_offloaded(module):
+ if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
+ return False
+
+ return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
+
+ # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
+ pipeline_is_sequentially_offloaded = any(
+ module_is_sequentially_offloaded(module) for _, module in self.components.items()
+ )
+
+ is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
+ if is_pipeline_device_mapped:
+ raise ValueError(
+ "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
+ )
+
+ if device_type in ["cuda", "xpu"]:
+ if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
+ raise ValueError(
+ "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
+ )
+ # PR: https://github.com/huggingface/accelerate/pull/3223/
+ elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
+ raise ValueError(
+ "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
+ )
+
+ # Display a warning in this case (the operation succeeds but the benefits are lost)
+ pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
+ if pipeline_is_offloaded and device_type in ["cuda", "xpu"]:
+ logger.warning(
+ f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
+ )
+
+ # Enable generic support for Intel Gaudi accelerator using GPU/HPU migration
+ if device_type == "hpu" and kwargs.pop("hpu_migration", True) and is_hpu_available():
+ os.environ["PT_HPU_GPU_MIGRATION"] = "1"
+ logger.debug("Environment variable set: PT_HPU_GPU_MIGRATION=1")
+
+ import habana_frameworks.torch # noqa: F401
+
+ # HPU hardware check
+ if not (hasattr(torch, "hpu") and torch.hpu.is_available()):
+ raise ValueError("You are trying to call `.to('hpu')` but HPU device is unavailable.")
+
+ os.environ["PT_HPU_MAX_COMPOUND_OP_SIZE"] = "1"
+ logger.debug("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1")
+
+ modules = self.components.values()
+ modules = [m for m in modules if isinstance(m, torch.nn.Module)]
+
+ is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
+ for module in modules:
+ _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module)
+ is_group_offloaded = self._maybe_raise_error_if_group_offload_active(module=module)
+
+ if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None:
+ logger.warning(
+ f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision."
+ )
+
+ if is_loaded_in_8bit_bnb and device is not None:
+ logger.warning(
+ f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
+ )
+
+ # Note: we also handle this at the ModelMixin level. The reason for doing it here too is that modeling
+ # components can be from outside diffusers too, but still have group offloading enabled.
+ if (
+ self._maybe_raise_error_if_group_offload_active(raise_error=False, module=module)
+ and device is not None
+ ):
+ logger.warning(
+ f"The module '{module.__class__.__name__}' is group offloaded and moving it to {device} via `.to()` is not supported."
+ )
+
+ # This can happen for `transformer` models. CPU placement was added in
+ # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
+ if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
+ module.to(device=device)
+ elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb and not is_group_offloaded:
+ module.to(device, dtype)
+
+ if (
+ module.dtype == torch.float16
+ and str(device) in ["cpu"]
+ and not silence_dtype_warnings
+ and not is_offloaded
+ ):
+ logger.warning(
+ "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It"
+ " is not recommended to move them to `cpu` as running them will fail. Please make"
+ " sure to use an accelerator to run the pipeline in inference, due to the lack of"
+ " support for`float16` operations on this device in PyTorch. Please, remove the"
+ " `torch_dtype=torch.float16` argument, or use another device for inference."
+ )
+ return self
+
+ @staticmethod
+ def _component_spec_to_dict(component_spec: ComponentSpec) -> Any:
+ """
+ Convert a ComponentSpec into a JSON‐serializable dict for saving as an entry in `modular_model_index.json`. If
+ the `default_creation_method` is not `from_pretrained`, return None.
+
+ This dict contains:
+ - "type_hint": Tuple[str, str]
+ Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel"))
+ - All loading fields defined by `component_spec.loading_fields()`, typically:
+ - "pretrained_model_name_or_path": Optional[str]
+ The model pretrained_model_name_or_pathsitory (e.g., "stabilityai/stable-diffusion-xl").
+ - "subfolder": Optional[str]
+ A subfolder within the pretrained_model_name_or_path where this component lives.
+ - "variant": Optional[str]
+ An optional variant identifier for the model.
+ - "revision": Optional[str]
+ A specific git revision (commit hash, tag, or branch).
+ - ... any other loading fields defined on the spec.
+
+ Args:
+ component_spec (ComponentSpec):
+ The spec object describing one pipeline component.
+
+ Returns:
+ Dict[str, Any]: A mapping suitable for JSON serialization.
+
+ Example:
+ >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec >>> from diffusers import
+ UNet2DConditionModel >>> spec = ComponentSpec(
+ ... name="unet", ... type_hint=UNet2DConditionModel, ... config=None, ...
+ pretrained_model_name_or_path="path/to/pretrained_model_name_or_path", ... subfolder="subfolder", ...
+ variant=None, ... revision=None, ... default_creation_method="from_pretrained",
+ ... ) >>> ModularPipeline._component_spec_to_dict(spec) {
+ "type_hint": ("diffusers", "UNet2DConditionModel"), "pretrained_model_name_or_path": "path/to/repo",
+ "subfolder": "subfolder", "variant": None, "revision": None, "type_hint": ("diffusers",
+ "UNet2DConditionModel"), "pretrained_model_name_or_path": "path/to/repo", "subfolder": "subfolder",
+ "variant": None, "revision": None,
+ }
+ """
+ if component_spec.default_creation_method != "from_pretrained":
+ return None
+
+ if component_spec.type_hint is not None:
+ lib_name, cls_name = _fetch_class_library_tuple(component_spec.type_hint)
+ else:
+ lib_name = None
+ cls_name = None
+ load_spec_dict = {k: getattr(component_spec, k) for k in component_spec.loading_fields()}
+ return {
+ "type_hint": (lib_name, cls_name),
+ **load_spec_dict,
+ }
+
+ @staticmethod
+ def _dict_to_component_spec(
+ name: str,
+ spec_dict: Dict[str, Any],
+ ) -> ComponentSpec:
+ """
+ Reconstruct a ComponentSpec from a loading specdict.
+
+ This method converts a dictionary representation back into a ComponentSpec object. The dict should contain:
+ - "type_hint": Tuple[str, str]
+ Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel"))
+ - All loading fields defined by `component_spec.loading_fields()`, typically:
+ - "pretrained_model_name_or_path": Optional[str]
+ The model repository (e.g., "stabilityai/stable-diffusion-xl").
+ - "subfolder": Optional[str]
+ A subfolder within the pretrained_model_name_or_path where this component lives.
+ - "variant": Optional[str]
+ An optional variant identifier for the model.
+ - "revision": Optional[str]
+ A specific git revision (commit hash, tag, or branch).
+ - ... any other loading fields defined on the spec.
+
+ Args:
+ name (str):
+ The name of the component.
+ specdict (Dict[str, Any]):
+ A dictionary containing the component specification data.
+
+ Returns:
+ ComponentSpec: A reconstructed ComponentSpec object.
+
+ Example:
+ >>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ...
+ "pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant":
+ None, ... "revision": None, ... } >>> ModularPipeline._dict_to_component_spec("unet", spec_dict)
+ ComponentSpec(
+ name="unet", type_hint=UNet2DConditionModel, config=None,
+ pretrained_model_name_or_path="stabilityai/stable-diffusion-xl", subfolder="unet", variant=None,
+ revision=None, default_creation_method="from_pretrained"
+ >>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ...
+ "pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant":
+ None, ... "revision": None, ... } >>> ModularPipeline._dict_to_component_spec("unet", spec_dict)
+ ComponentSpec(
+ name="unet", type_hint=UNet2DConditionModel, config=None,
+ pretrained_model_name_or_path="stabilityai/stable-diffusion-xl", subfolder="unet", variant=None,
+ revision=None, default_creation_method="from_pretrained"
+ )
+ """
+ # make a shallow copy so we can pop() safely
+ spec_dict = spec_dict.copy()
+ # pull out and resolve the stored type_hint
+ lib_name, cls_name = spec_dict.pop("type_hint")
+ if lib_name is not None and cls_name is not None:
+ type_hint = simple_get_class_obj(lib_name, cls_name)
+ else:
+ type_hint = None
+
+ # re‐assemble the ComponentSpec
+ return ComponentSpec(
+ name=name,
+ type_hint=type_hint,
+ **spec_dict,
+ )
+
+ def set_progress_bar_config(self, **kwargs):
+ for sub_block_name, sub_block in self.blocks.sub_blocks.items():
+ if hasattr(sub_block, "set_progress_bar_config"):
+ sub_block.set_progress_bar_config(**kwargs)
+
+ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
+ """
+ Execute the pipeline by running the pipeline blocks with the given inputs.
+
+ Args:
+ state (`PipelineState`, optional):
+ PipelineState instance contains inputs and intermediate values. If None, a new `PipelineState` will be
+ created based on the user inputs and the pipeline blocks's requirement.
+ output (`str` or `List[str]`, optional):
+ Optional specification of what to return:
+ - None: Returns the complete `PipelineState` with all inputs and intermediates (default)
+ - str: Returns a specific intermediate value from the state (e.g. `output="image"`)
+ - List[str]: Returns a dictionary of specific intermediate values (e.g. `output=["image",
+ "latents"]`)
+
+
+ Examples:
+ ```python
+ # Get complete pipeline state
+ state = pipeline(prompt="A beautiful sunset", num_inference_steps=20)
+ print(state.intermediates) # All intermediate outputs
+
+ # Get specific output
+ image = pipeline(prompt="A beautiful sunset", output="image")
+
+ # Get multiple specific outputs
+ results = pipeline(prompt="A beautiful sunset", output=["image", "latents"])
+ image, latents = results["image"], results["latents"]
+
+ # Continue from previous state
+ state = pipeline(prompt="A beautiful sunset")
+ new_state = pipeline(state=state, output="image") # Continue processing
+ ```
+
+ Returns:
+ - If `output` is None: Complete `PipelineState` containing all inputs and intermediates
+ - If `output` is str: The specific intermediate value from the state (e.g. `output="image"`)
+ - If `output` is List[str]: Dictionary mapping output names to their values from the state (e.g.
+ `output=["image", "latents"]`)
+ """
+ if state is None:
+ state = PipelineState()
+ else:
+ state = deepcopy(state)
+
+ # Make a copy of the input kwargs
+ passed_kwargs = kwargs.copy()
+
+ # Add inputs to state, using defaults if not provided in the kwargs or the state
+ # if same input already in the state, will override it if provided in the kwargs
+ for expected_input_param in self.blocks.inputs:
+ name = expected_input_param.name
+ default = expected_input_param.default
+ kwargs_type = expected_input_param.kwargs_type
+ if name in passed_kwargs:
+ state.set(name, passed_kwargs.pop(name), kwargs_type)
+ elif name not in state.values:
+ state.set(name, default, kwargs_type)
+
+ # Warn about unexpected inputs
+ if len(passed_kwargs) > 0:
+ warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
+ # Run the pipeline
+ with torch.no_grad():
+ try:
+ _, state = self.blocks(self, state)
+ except Exception:
+ error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
+ logger.error(error_msg)
+ raise
+
+ if output is None:
+ return state
+
+ if isinstance(output, str):
+ return state.get(output)
+
+ elif isinstance(output, (list, tuple)):
+ return state.get(output)
+ else:
+ raise ValueError(f"Output '{output}' is not a valid output type")
diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py
new file mode 100644
index 000000000000..aa421a53727b
--- /dev/null
+++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py
@@ -0,0 +1,692 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import re
+from collections import OrderedDict
+from dataclasses import dataclass, field, fields
+from typing import Any, Dict, List, Literal, Optional, Type, Union
+
+import torch
+
+from ..configuration_utils import ConfigMixin, FrozenDict
+from ..loaders.single_file_utils import _is_single_file_path_or_url
+from ..utils import is_torch_available, logging
+
+
+if is_torch_available():
+ pass
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class InsertableDict(OrderedDict):
+ def insert(self, key, value, index):
+ items = list(self.items())
+
+ # Remove key if it already exists to avoid duplicates
+ items = [(k, v) for k, v in items if k != key]
+
+ # Insert at the specified index
+ items.insert(index, (key, value))
+
+ # Clear and update self
+ self.clear()
+ self.update(items)
+
+ # Return self for method chaining
+ return self
+
+ def __repr__(self):
+ if not self:
+ return "InsertableDict()"
+
+ items = []
+ for i, (key, value) in enumerate(self.items()):
+ if isinstance(value, type):
+ # For classes, show class name and
+ obj_repr = f""
+ else:
+ # For objects (instances) and other types, show class name and module
+ obj_repr = f""
+ items.append(f"{i}: ({repr(key)}, {obj_repr})")
+
+ return "InsertableDict([\n " + ",\n ".join(items) + "\n])"
+
+
+# YiYi TODO:
+# 1. validate the dataclass fields
+# 2. improve the docstring and potentially add a validator for load methods, make sure they are valid inputs to pass to from_pretrained()
+@dataclass
+class ComponentSpec:
+ """Specification for a pipeline component.
+
+ A component can be created in two ways:
+ 1. From scratch using __init__ with a config dict
+ 2. using `from_pretrained`
+
+ Attributes:
+ name: Name of the component
+ type_hint: Type of the component (e.g. UNet2DConditionModel)
+ description: Optional description of the component
+ config: Optional config dict for __init__ creation
+ pretrained_model_name_or_path: Optional pretrained_model_name_or_path path for from_pretrained creation
+ subfolder: Optional subfolder in pretrained_model_name_or_path
+ variant: Optional variant in pretrained_model_name_or_path
+ revision: Optional revision in pretrained_model_name_or_path
+ default_creation_method: Preferred creation method - "from_config" or "from_pretrained"
+ """
+
+ name: Optional[str] = None
+ type_hint: Optional[Type] = None
+ description: Optional[str] = None
+ config: Optional[FrozenDict] = None
+ pretrained_model_name_or_path: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
+ subfolder: Optional[str] = field(default="", metadata={"loading": True})
+ variant: Optional[str] = field(default=None, metadata={"loading": True})
+ revision: Optional[str] = field(default=None, metadata={"loading": True})
+ default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
+
+ # Deprecated
+ repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": False})
+
+ def __post_init__(self):
+ repo_value = self.repo
+ if repo_value is not None and self.pretrained_model_name_or_path is None:
+ object.__setattr__(self, "pretrained_model_name_or_path", repo_value)
+
+ def __hash__(self):
+ """Make ComponentSpec hashable, using load_id as the hash value."""
+ return hash((self.name, self.load_id, self.default_creation_method))
+
+ def __eq__(self, other):
+ """Compare ComponentSpec objects based on name and load_id."""
+ if not isinstance(other, ComponentSpec):
+ return False
+ return (
+ self.name == other.name
+ and self.load_id == other.load_id
+ and self.default_creation_method == other.default_creation_method
+ )
+
+ @classmethod
+ def from_component(cls, name: str, component: Any) -> Any:
+ """Create a ComponentSpec from a Component.
+
+ Currently supports:
+ - Components created with `ComponentSpec.load()` method
+ - Components that are ConfigMixin subclasses but not nn.Modules (e.g. schedulers, guiders)
+
+ Args:
+ name: Name of the component
+ component: Component object to create spec from
+
+ Returns:
+ ComponentSpec object
+
+ Raises:
+ ValueError: If component is not supported (e.g. nn.Module without load_id, non-ConfigMixin)
+ """
+
+ # Check if component was created with ComponentSpec.load()
+ if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
+ # component has a usable load_id -> from_pretrained, no warning needed
+ default_creation_method = "from_pretrained"
+ else:
+ # Component doesn't have a usable load_id, check if it's a nn.Module
+ if isinstance(component, torch.nn.Module):
+ raise ValueError(
+ "Cannot create ComponentSpec from a nn.Module that was not created with `ComponentSpec.load()` method."
+ )
+ # ConfigMixin objects without weights (e.g. scheduler & guider) can be recreated with from_config
+ elif isinstance(component, ConfigMixin):
+ # warn if component was not created with `ComponentSpec`
+ if not hasattr(component, "_diffusers_load_id"):
+ logger.warning(
+ "Component was not created using `ComponentSpec`, defaulting to `from_config` creation method"
+ )
+ default_creation_method = "from_config"
+ else:
+ # Not a ConfigMixin and not created with `ComponentSpec.load()` method -> throw error
+ raise ValueError(
+ f"Cannot create ComponentSpec from {name}({component.__class__.__name__}). Currently ComponentSpec.from_component() only supports: "
+ f" - components created with `ComponentSpec.load()` method"
+ f" - components that are a subclass of ConfigMixin but not a nn.Module (e.g. guider, scheduler)."
+ )
+
+ type_hint = component.__class__
+
+ if isinstance(component, ConfigMixin) and default_creation_method == "from_config":
+ config = component.config
+ else:
+ config = None
+ if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
+ load_spec = cls.decode_load_id(component._diffusers_load_id)
+ else:
+ load_spec = {}
+
+ return cls(
+ name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec
+ )
+
+ @classmethod
+ def loading_fields(cls) -> List[str]:
+ """
+ Return the names of all loading‐related fields (i.e. those whose field.metadata["loading"] is True).
+ """
+ return [f.name for f in fields(cls) if f.metadata.get("loading", False)]
+
+ @property
+ def load_id(self) -> str:
+ """
+ Unique identifier for this spec's pretrained load, composed of
+ pretrained_model_name_or_path|subfolder|variant|revision (no empty segments).
+ """
+ if self.default_creation_method == "from_config":
+ return "null"
+ parts = [getattr(self, k) for k in self.loading_fields()]
+ parts = ["null" if p is None else p for p in parts]
+ return "|".join(p for p in parts if p)
+
+ @classmethod
+ def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
+ """
+ Decode a load_id string back into a dictionary of loading fields and values.
+
+ Args:
+ load_id: The load_id string to decode, format: "pretrained_model_name_or_path|subfolder|variant|revision"
+ where None values are represented as "null"
+
+ Returns:
+ Dict mapping loading field names to their values. e.g. {
+ "pretrained_model_name_or_path": "path/to/repo", "subfolder": "subfolder", "variant": "variant",
+ "revision": "revision"
+ } If a segment value is "null", it's replaced with None. Returns None if load_id is "null" (indicating
+ component not created with `load` method).
+ """
+
+ # Get all loading fields in order
+ loading_fields = cls.loading_fields()
+ result = dict.fromkeys(loading_fields)
+
+ if load_id == "null":
+ return result
+
+ # Split the load_id
+ parts = load_id.split("|")
+
+ # Map parts to loading fields by position
+ for i, part in enumerate(parts):
+ if i < len(loading_fields):
+ # Convert "null" string back to None
+ result[loading_fields[i]] = None if part == "null" else part
+
+ return result
+
+ # YiYi TODO: I think we should only support ConfigMixin for this method (after we make guider and image_processors config mixin)
+ # otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component)
+ # the config info is lost in the process
+ # remove error check in from_component spec and ModularPipeline.update_components() if we remove support for non configmixin in `create()` method
+ def create(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any:
+ """Create component using from_config with config."""
+
+ if self.type_hint is None or not isinstance(self.type_hint, type):
+ raise ValueError("`type_hint` is required when using from_config creation method.")
+
+ config = config or self.config or {}
+
+ if issubclass(self.type_hint, ConfigMixin):
+ component = self.type_hint.from_config(config, **kwargs)
+ else:
+ signature_params = inspect.signature(self.type_hint.__init__).parameters
+ init_kwargs = {}
+ for k, v in config.items():
+ if k in signature_params:
+ init_kwargs[k] = v
+ for k, v in kwargs.items():
+ if k in signature_params:
+ init_kwargs[k] = v
+ component = self.type_hint(**init_kwargs)
+
+ component._diffusers_load_id = "null"
+ if hasattr(component, "config"):
+ self.config = component.config
+
+ return component
+
+ # YiYi TODO: add guard for type of model, if it is supported by from_pretrained
+ def load(self, **kwargs) -> Any:
+ """Load component using from_pretrained."""
+ # select loading fields from kwargs passed from user: e.g. pretrained_model_name_or_path, subfolder, variant, revision, note the list could change
+ passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
+ # merge loading field value in the spec with user passed values to create load_kwargs
+ load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
+
+ pretrained_model_name_or_path = load_kwargs.pop("pretrained_model_name_or_path", None)
+ if pretrained_model_name_or_path is None:
+ raise ValueError(
+ "`pretrained_model_name_or_path` info is required when using `load` method (you can directly set it in `pretrained_model_name_or_path` field of the ComponentSpec or pass it as an argument)"
+ )
+ is_single_file = _is_single_file_path_or_url(pretrained_model_name_or_path)
+ if is_single_file and self.type_hint is None:
+ raise ValueError(
+ f"`type_hint` is required when loading a single file model but is missing for component: {self.name}"
+ )
+
+ if self.type_hint is None:
+ try:
+ from diffusers import AutoModel
+
+ component = AutoModel.from_pretrained(pretrained_model_name_or_path, **load_kwargs, **kwargs)
+ except Exception as e:
+ raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}")
+ # update type_hint if AutoModel load successfully
+ self.type_hint = component.__class__
+ else:
+ # determine load method
+ load_method = (
+ getattr(self.type_hint, "from_single_file")
+ if is_single_file
+ else getattr(self.type_hint, "from_pretrained")
+ )
+
+ try:
+ component = load_method(pretrained_model_name_or_path, **load_kwargs, **kwargs)
+ except Exception as e:
+ raise ValueError(f"Unable to load {self.name} using load method: {e}")
+
+ self.pretrained_model_name_or_path = pretrained_model_name_or_path
+ for k, v in load_kwargs.items():
+ setattr(self, k, v)
+ component._diffusers_load_id = self.load_id
+
+ return component
+
+
+@dataclass
+class ConfigSpec:
+ """Specification for a pipeline configuration parameter."""
+
+ name: str
+ default: Any
+ description: Optional[str] = None
+
+
+# YiYi Notes: both inputs and intermediate_inputs are InputParam objects
+# however some fields are not relevant for intermediate_inputs
+# e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed
+# default is not used for intermediate_inputs, we only use default from inputs, so it is ignored if it is set for intermediate_inputs
+# -> should we use different class for inputs and intermediate_inputs?
+@dataclass
+class InputParam:
+ """Specification for an input parameter."""
+
+ name: str = None
+ type_hint: Any = None
+ default: Any = None
+ required: bool = False
+ description: str = ""
+ kwargs_type: str = None # YiYi Notes: remove this feature (maybe)
+
+ def __repr__(self):
+ return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
+
+
+@dataclass
+class OutputParam:
+ """Specification for an output parameter."""
+
+ name: str
+ type_hint: Any = None
+ description: str = ""
+ kwargs_type: str = None # YiYi notes: remove this feature (maybe)
+
+ def __repr__(self):
+ return (
+ f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
+ )
+
+
+def format_inputs_short(inputs):
+ """
+ Format input parameters into a string representation, with required params first followed by optional ones.
+
+ Args:
+ inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params
+
+ Returns:
+ str: Formatted string of input parameters
+
+ Example:
+ >>> inputs = [ ... InputParam(name="prompt", required=True), ... InputParam(name="image", required=True), ...
+ InputParam(name="guidance_scale", required=False, default=7.5), ... InputParam(name="num_inference_steps",
+ required=False, default=50) ... ] >>> format_inputs_short(inputs) 'prompt, image, guidance_scale=7.5,
+ num_inference_steps=50'
+ """
+ required_inputs = [param for param in inputs if param.required]
+ optional_inputs = [param for param in inputs if not param.required]
+
+ required_str = ", ".join(param.name for param in required_inputs)
+ optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs)
+
+ inputs_str = required_str
+ if optional_str:
+ inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str
+
+ return inputs_str
+
+
+def format_intermediates_short(intermediate_inputs, required_intermediate_inputs, intermediate_outputs):
+ """
+ Formats intermediate inputs and outputs of a block into a string representation.
+
+ Args:
+ intermediate_inputs: List of intermediate input parameters
+ required_intermediate_inputs: List of required intermediate input names
+ intermediate_outputs: List of intermediate output parameters
+
+ Returns:
+ str: Formatted string like:
+ Intermediates:
+ - inputs: Required(latents), dtype
+ - modified: latents # variables that appear in both inputs and outputs
+ - outputs: images # new outputs only
+ """
+ # Handle inputs
+ input_parts = []
+ for inp in intermediate_inputs:
+ if inp.name in required_intermediate_inputs:
+ input_parts.append(f"Required({inp.name})")
+ else:
+ if inp.name is None and inp.kwargs_type is not None:
+ inp_name = "*_" + inp.kwargs_type
+ else:
+ inp_name = inp.name
+ input_parts.append(inp_name)
+
+ # Handle modified variables (appear in both inputs and outputs)
+ inputs_set = {inp.name for inp in intermediate_inputs}
+ modified_parts = []
+ new_output_parts = []
+
+ for out in intermediate_outputs:
+ if out.name in inputs_set:
+ modified_parts.append(out.name)
+ else:
+ new_output_parts.append(out.name)
+
+ result = []
+ if input_parts:
+ result.append(f" - inputs: {', '.join(input_parts)}")
+ if modified_parts:
+ result.append(f" - modified: {', '.join(modified_parts)}")
+ if new_output_parts:
+ result.append(f" - outputs: {', '.join(new_output_parts)}")
+
+ return "\n".join(result) if result else " (none)"
+
+
+def format_params(params, header="Args", indent_level=4, max_line_length=115):
+ """Format a list of InputParam or OutputParam objects into a readable string representation.
+
+ Args:
+ params: List of InputParam or OutputParam objects to format
+ header: Header text to use (e.g. "Args" or "Returns")
+ indent_level: Number of spaces to indent each parameter line (default: 4)
+ max_line_length: Maximum length for each line before wrapping (default: 115)
+
+ Returns:
+ A formatted string representing all parameters
+ """
+ if not params:
+ return ""
+
+ base_indent = " " * indent_level
+ param_indent = " " * (indent_level + 4)
+ desc_indent = " " * (indent_level + 8)
+ formatted_params = []
+
+ def get_type_str(type_hint):
+ if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union:
+ types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__]
+ return f"Union[{', '.join(types)}]"
+ return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)
+
+ def wrap_text(text, indent, max_length):
+ """Wrap text while preserving markdown links and maintaining indentation."""
+ words = text.split()
+ lines = []
+ current_line = []
+ current_length = 0
+
+ for word in words:
+ word_length = len(word) + (1 if current_line else 0)
+
+ if current_line and current_length + word_length > max_length:
+ lines.append(" ".join(current_line))
+ current_line = [word]
+ current_length = len(word)
+ else:
+ current_line.append(word)
+ current_length += word_length
+
+ if current_line:
+ lines.append(" ".join(current_line))
+
+ return f"\n{indent}".join(lines)
+
+ # Add the header
+ formatted_params.append(f"{base_indent}{header}:")
+
+ for param in params:
+ # Format parameter name and type
+ type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
+ # YiYi Notes: remove this line if we remove kwargs_type
+ name = f"**{param.kwargs_type}" if param.name is None and param.kwargs_type is not None else param.name
+ param_str = f"{param_indent}{name} (`{type_str}`"
+
+ # Add optional tag and default value if parameter is an InputParam and optional
+ if hasattr(param, "required"):
+ if not param.required:
+ param_str += ", *optional*"
+ if param.default is not None:
+ param_str += f", defaults to {param.default}"
+ param_str += "):"
+
+ # Add description on a new line with additional indentation and wrapping
+ if param.description:
+ desc = re.sub(r"\[(.*?)\]\((https?://[^\s\)]+)\)", r"[\1](\2)", param.description)
+ wrapped_desc = wrap_text(desc, desc_indent, max_line_length)
+ param_str += f"\n{desc_indent}{wrapped_desc}"
+
+ formatted_params.append(param_str)
+
+ return "\n\n".join(formatted_params)
+
+
+def format_input_params(input_params, indent_level=4, max_line_length=115):
+ """Format a list of InputParam objects into a readable string representation.
+
+ Args:
+ input_params: List of InputParam objects to format
+ indent_level: Number of spaces to indent each parameter line (default: 4)
+ max_line_length: Maximum length for each line before wrapping (default: 115)
+
+ Returns:
+ A formatted string representing all input parameters
+ """
+ return format_params(input_params, "Inputs", indent_level, max_line_length)
+
+
+def format_output_params(output_params, indent_level=4, max_line_length=115):
+ """Format a list of OutputParam objects into a readable string representation.
+
+ Args:
+ output_params: List of OutputParam objects to format
+ indent_level: Number of spaces to indent each parameter line (default: 4)
+ max_line_length: Maximum length for each line before wrapping (default: 115)
+
+ Returns:
+ A formatted string representing all output parameters
+ """
+ return format_params(output_params, "Outputs", indent_level, max_line_length)
+
+
+def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True):
+ """Format a list of ComponentSpec objects into a readable string representation.
+
+ Args:
+ components: List of ComponentSpec objects to format
+ indent_level: Number of spaces to indent each component line (default: 4)
+ max_line_length: Maximum length for each line before wrapping (default: 115)
+ add_empty_lines: Whether to add empty lines between components (default: True)
+
+ Returns:
+ A formatted string representing all components
+ """
+ if not components:
+ return ""
+
+ base_indent = " " * indent_level
+ component_indent = " " * (indent_level + 4)
+ formatted_components = []
+
+ # Add the header
+ formatted_components.append(f"{base_indent}Components:")
+ if add_empty_lines:
+ formatted_components.append("")
+
+ # Add each component with optional empty lines between them
+ for i, component in enumerate(components):
+ # Get type name, handling special cases
+ type_name = (
+ component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint)
+ )
+
+ component_desc = f"{component_indent}{component.name} (`{type_name}`)"
+ if component.description:
+ component_desc += f": {component.description}"
+
+ # Get the loading fields dynamically
+ loading_field_values = []
+ for field_name in component.loading_fields():
+ field_value = getattr(component, field_name)
+ if field_value is not None:
+ loading_field_values.append(f"{field_name}={field_value}")
+
+ # Add loading field information if available
+ if loading_field_values:
+ component_desc += f" [{', '.join(loading_field_values)}]"
+
+ formatted_components.append(component_desc)
+
+ # Add an empty line after each component except the last one
+ if add_empty_lines and i < len(components) - 1:
+ formatted_components.append("")
+
+ return "\n".join(formatted_components)
+
+
+def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines=True):
+ """Format a list of ConfigSpec objects into a readable string representation.
+
+ Args:
+ configs: List of ConfigSpec objects to format
+ indent_level: Number of spaces to indent each config line (default: 4)
+ max_line_length: Maximum length for each line before wrapping (default: 115)
+ add_empty_lines: Whether to add empty lines between configs (default: True)
+
+ Returns:
+ A formatted string representing all configs
+ """
+ if not configs:
+ return ""
+
+ base_indent = " " * indent_level
+ config_indent = " " * (indent_level + 4)
+ formatted_configs = []
+
+ # Add the header
+ formatted_configs.append(f"{base_indent}Configs:")
+ if add_empty_lines:
+ formatted_configs.append("")
+
+ # Add each config with optional empty lines between them
+ for i, config in enumerate(configs):
+ config_desc = f"{config_indent}{config.name} (default: {config.default})"
+ if config.description:
+ config_desc += f": {config.description}"
+ formatted_configs.append(config_desc)
+
+ # Add an empty line after each config except the last one
+ if add_empty_lines and i < len(configs) - 1:
+ formatted_configs.append("")
+
+ return "\n".join(formatted_configs)
+
+
+def make_doc_string(
+ inputs,
+ outputs,
+ description="",
+ class_name=None,
+ expected_components=None,
+ expected_configs=None,
+):
+ """
+ Generates a formatted documentation string describing the pipeline block's parameters and structure.
+
+ Args:
+ inputs: List of input parameters
+ intermediate_inputs: List of intermediate input parameters
+ outputs: List of output parameters
+ description (str, *optional*): Description of the block
+ class_name (str, *optional*): Name of the class to include in the documentation
+ expected_components (List[ComponentSpec], *optional*): List of expected components
+ expected_configs (List[ConfigSpec], *optional*): List of expected configurations
+
+ Returns:
+ str: A formatted string containing information about components, configs, call parameters,
+ intermediate inputs/outputs, and final outputs.
+ """
+ output = ""
+
+ # Add class name if provided
+ if class_name:
+ output += f"class {class_name}\n\n"
+
+ # Add description
+ if description:
+ desc_lines = description.strip().split("\n")
+ aligned_desc = "\n".join(" " + line for line in desc_lines)
+ output += aligned_desc + "\n\n"
+
+ # Add components section if provided
+ if expected_components and len(expected_components) > 0:
+ components_str = format_components(expected_components, indent_level=2)
+ output += components_str + "\n\n"
+
+ # Add configs section if provided
+ if expected_configs and len(expected_configs) > 0:
+ configs_str = format_configs(expected_configs, indent_level=2)
+ output += configs_str + "\n\n"
+
+ # Add inputs section
+ output += format_input_params(inputs, indent_level=2)
+
+ # Add outputs section
+ output += "\n\n"
+ output += format_output_params(outputs, indent_level=2)
+
+ return output
diff --git a/src/diffusers/modular_pipelines/node_utils.py b/src/diffusers/modular_pipelines/node_utils.py
new file mode 100644
index 000000000000..f7ee1dd3097b
--- /dev/null
+++ b/src/diffusers/modular_pipelines/node_utils.py
@@ -0,0 +1,661 @@
+import json
+import logging
+import os
+from pathlib import Path
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+
+from ..configuration_utils import ConfigMixin
+from ..image_processor import PipelineImageInput
+from .modular_pipeline import ModularPipelineBlocks, SequentialPipelineBlocks
+from .modular_pipeline_utils import InputParam
+
+
+logger = logging.getLogger(__name__)
+
+# YiYi Notes: this is actually for SDXL, put it here for now
+SDXL_INPUTS_SCHEMA = {
+ "prompt": InputParam(
+ "prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"
+ ),
+ "prompt_2": InputParam(
+ "prompt_2",
+ type_hint=Union[str, List[str]],
+ description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2",
+ ),
+ "negative_prompt": InputParam(
+ "negative_prompt",
+ type_hint=Union[str, List[str]],
+ description="The prompt or prompts not to guide the image generation",
+ ),
+ "negative_prompt_2": InputParam(
+ "negative_prompt_2",
+ type_hint=Union[str, List[str]],
+ description="The negative prompt or prompts for text_encoder_2",
+ ),
+ "cross_attention_kwargs": InputParam(
+ "cross_attention_kwargs",
+ type_hint=Optional[dict],
+ description="Kwargs dictionary passed to the AttentionProcessor",
+ ),
+ "clip_skip": InputParam(
+ "clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"
+ ),
+ "image": InputParam(
+ "image",
+ type_hint=PipelineImageInput,
+ required=True,
+ description="The image(s) to modify for img2img or inpainting",
+ ),
+ "mask_image": InputParam(
+ "mask_image",
+ type_hint=PipelineImageInput,
+ required=True,
+ description="Mask image for inpainting, white pixels will be repainted",
+ ),
+ "generator": InputParam(
+ "generator",
+ type_hint=Optional[Union[torch.Generator, List[torch.Generator]]],
+ description="Generator(s) for deterministic generation",
+ ),
+ "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
+ "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
+ "num_images_per_prompt": InputParam(
+ "num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"
+ ),
+ "num_inference_steps": InputParam(
+ "num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"
+ ),
+ "timesteps": InputParam(
+ "timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"
+ ),
+ "sigmas": InputParam(
+ "sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"
+ ),
+ "denoising_end": InputParam(
+ "denoising_end",
+ type_hint=Optional[float],
+ description="Fraction of denoising process to complete before termination",
+ ),
+ # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
+ "strength": InputParam(
+ "strength", type_hint=float, default=0.3, description="How much to transform the reference image"
+ ),
+ "denoising_start": InputParam(
+ "denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"
+ ),
+ "latents": InputParam(
+ "latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"
+ ),
+ "padding_mask_crop": InputParam(
+ "padding_mask_crop",
+ type_hint=Optional[Tuple[int, int]],
+ description="Size of margin in crop for image and mask",
+ ),
+ "original_size": InputParam(
+ "original_size",
+ type_hint=Optional[Tuple[int, int]],
+ description="Original size of the image for SDXL's micro-conditioning",
+ ),
+ "target_size": InputParam(
+ "target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"
+ ),
+ "negative_original_size": InputParam(
+ "negative_original_size",
+ type_hint=Optional[Tuple[int, int]],
+ description="Negative conditioning based on image resolution",
+ ),
+ "negative_target_size": InputParam(
+ "negative_target_size",
+ type_hint=Optional[Tuple[int, int]],
+ description="Negative conditioning based on target resolution",
+ ),
+ "crops_coords_top_left": InputParam(
+ "crops_coords_top_left",
+ type_hint=Tuple[int, int],
+ default=(0, 0),
+ description="Top-left coordinates for SDXL's micro-conditioning",
+ ),
+ "negative_crops_coords_top_left": InputParam(
+ "negative_crops_coords_top_left",
+ type_hint=Tuple[int, int],
+ default=(0, 0),
+ description="Negative conditioning crop coordinates",
+ ),
+ "aesthetic_score": InputParam(
+ "aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"
+ ),
+ "negative_aesthetic_score": InputParam(
+ "negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"
+ ),
+ "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
+ "output_type": InputParam(
+ "output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"
+ ),
+ "ip_adapter_image": InputParam(
+ "ip_adapter_image",
+ type_hint=PipelineImageInput,
+ required=True,
+ description="Image(s) to be used as IP adapter",
+ ),
+ "control_image": InputParam(
+ "control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"
+ ),
+ "control_guidance_start": InputParam(
+ "control_guidance_start",
+ type_hint=Union[float, List[float]],
+ default=0.0,
+ description="When ControlNet starts applying",
+ ),
+ "control_guidance_end": InputParam(
+ "control_guidance_end",
+ type_hint=Union[float, List[float]],
+ default=1.0,
+ description="When ControlNet stops applying",
+ ),
+ "controlnet_conditioning_scale": InputParam(
+ "controlnet_conditioning_scale",
+ type_hint=Union[float, List[float]],
+ default=1.0,
+ description="Scale factor for ControlNet outputs",
+ ),
+ "guess_mode": InputParam(
+ "guess_mode",
+ type_hint=bool,
+ default=False,
+ description="Enables ControlNet encoder to recognize input without prompts",
+ ),
+ "control_mode": InputParam(
+ "control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
+ ),
+}
+
+SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
+ "prompt_embeds": InputParam(
+ "prompt_embeds",
+ type_hint=torch.Tensor,
+ required=True,
+ description="Text embeddings used to guide image generation",
+ ),
+ "negative_prompt_embeds": InputParam(
+ "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
+ ),
+ "pooled_prompt_embeds": InputParam(
+ "pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"
+ ),
+ "negative_pooled_prompt_embeds": InputParam(
+ "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
+ ),
+ "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
+ "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
+ "preprocess_kwargs": InputParam(
+ "preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
+ ),
+ "latents": InputParam(
+ "latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
+ ),
+ "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
+ "num_inference_steps": InputParam(
+ "num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
+ ),
+ "latent_timestep": InputParam(
+ "latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
+ ),
+ "image_latents": InputParam(
+ "image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"
+ ),
+ "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
+ "masked_image_latents": InputParam(
+ "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
+ ),
+ "add_time_ids": InputParam(
+ "add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"
+ ),
+ "negative_add_time_ids": InputParam(
+ "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
+ ),
+ "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
+ "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
+ "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
+ "ip_adapter_embeds": InputParam(
+ "ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
+ ),
+ "negative_ip_adapter_embeds": InputParam(
+ "negative_ip_adapter_embeds",
+ type_hint=List[torch.Tensor],
+ description="Negative image embeddings for IP-Adapter",
+ ),
+ "images": InputParam(
+ "images",
+ type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
+ required=True,
+ description="Generated images",
+ ),
+}
+
+SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA}
+
+
+DEFAULT_PARAM_MAPS = {
+ "prompt": {
+ "label": "Prompt",
+ "type": "string",
+ "default": "a bear sitting in a chair drinking a milkshake",
+ "display": "textarea",
+ },
+ "negative_prompt": {
+ "label": "Negative Prompt",
+ "type": "string",
+ "default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
+ "display": "textarea",
+ },
+ "num_inference_steps": {
+ "label": "Steps",
+ "type": "int",
+ "default": 25,
+ "min": 1,
+ "max": 1000,
+ },
+ "seed": {
+ "label": "Seed",
+ "type": "int",
+ "default": 0,
+ "min": 0,
+ "display": "random",
+ },
+ "width": {
+ "label": "Width",
+ "type": "int",
+ "display": "text",
+ "default": 1024,
+ "min": 8,
+ "max": 8192,
+ "step": 8,
+ "group": "dimensions",
+ },
+ "height": {
+ "label": "Height",
+ "type": "int",
+ "display": "text",
+ "default": 1024,
+ "min": 8,
+ "max": 8192,
+ "step": 8,
+ "group": "dimensions",
+ },
+ "images": {
+ "label": "Images",
+ "type": "image",
+ "display": "output",
+ },
+ "image": {
+ "label": "Image",
+ "type": "image",
+ "display": "input",
+ },
+}
+
+DEFAULT_TYPE_MAPS = {
+ "int": {
+ "type": "int",
+ "default": 0,
+ "min": 0,
+ },
+ "float": {
+ "type": "float",
+ "default": 0.0,
+ "min": 0.0,
+ },
+ "str": {
+ "type": "string",
+ "default": "",
+ },
+ "bool": {
+ "type": "boolean",
+ "default": False,
+ },
+ "image": {
+ "type": "image",
+ },
+}
+
+DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"]
+DEFAULT_CATEGORY = "Modular Diffusers"
+DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"]
+DEFAULT_PARAMS_GROUPS_KEYS = {
+ "text_encoders": ["text_encoder", "tokenizer"],
+ "ip_adapter_embeds": ["ip_adapter_embeds"],
+ "prompt_embeddings": ["prompt_embeds"],
+}
+
+
+def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
+ """
+ Get the group name for a given parameter name, if not part of a group, return None e.g. "prompt_embeds" ->
+ "text_embeds", "text_encoder" -> "text_encoders", "prompt" -> None
+ """
+ if name is None:
+ return None
+ for group_name, group_keys in group_params_keys.items():
+ for group_key in group_keys:
+ if group_key in name:
+ return group_name
+ return None
+
+
+class ModularNode(ConfigMixin):
+ """
+ A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon. It is a wrapper
+ around a ModularPipelineBlocks object.
+
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
+ """
+
+ config_name = "node_config.json"
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: str,
+ trust_remote_code: Optional[bool] = None,
+ **kwargs,
+ ):
+ blocks = ModularPipelineBlocks.from_pretrained(
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+ )
+ return cls(blocks, **kwargs)
+
+ def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):
+ self.blocks = blocks
+
+ if label is None:
+ label = self.blocks.__class__.__name__
+ # blocks param name -> mellon param name
+ self.name_mapping = {}
+
+ input_params = {}
+ # pass or create a default param dict for each input
+ # e.g. for prompt,
+ # prompt = {
+ # "name": "text_input", # the name of the input in node definition, could be different from the input name in diffusers
+ # "label": "Prompt",
+ # "type": "string",
+ # "default": "a bear sitting in a chair drinking a milkshake",
+ # "display": "textarea"}
+ # if type is not specified, it'll be a "custom" param of its own type
+ # e.g. you can pass ModularNode(scheduler = {name :"scheduler"})
+ # it will get this spec in node definition {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
+ # name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}}
+ inputs = self.blocks.inputs + self.blocks.intermediate_inputs
+ for inp in inputs:
+ param = kwargs.pop(inp.name, None)
+ if param:
+ # user can pass a param dict for all inputs, e.g. ModularNode(prompt = {...})
+ input_params[inp.name] = param
+ mellon_name = param.pop("name", inp.name)
+ if mellon_name != inp.name:
+ self.name_mapping[inp.name] = mellon_name
+ continue
+
+ if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
+ continue
+
+ if inp.name in DEFAULT_PARAM_MAPS:
+ # first check if it's in the default param map, if so, directly use that
+ param = DEFAULT_PARAM_MAPS[inp.name].copy()
+ elif get_group_name(inp.name):
+ param = get_group_name(inp.name)
+ if inp.name not in self.name_mapping:
+ self.name_mapping[inp.name] = param
+ else:
+ # if not, check if it's in the SDXL input schema, if so,
+ # 1. use the type hint to determine the type
+ # 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}}
+ if inp.type_hint is not None:
+ type_str = str(inp.type_hint).lower()
+ else:
+ inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None)
+ type_str = str(inp_spec.type_hint).lower() if inp_spec else ""
+ for type_key, type_param in DEFAULT_TYPE_MAPS.items():
+ if type_key in type_str:
+ param = type_param.copy()
+ param["label"] = inp.name
+ param["display"] = "input"
+ break
+ else:
+ param = inp.name
+ # add the param dict to the inp_params dict
+ input_params[inp.name] = param
+
+ component_params = {}
+ for comp in self.blocks.expected_components:
+ param = kwargs.pop(comp.name, None)
+ if param:
+ component_params[comp.name] = param
+ mellon_name = param.pop("name", comp.name)
+ if mellon_name != comp.name:
+ self.name_mapping[comp.name] = mellon_name
+ continue
+
+ to_exclude = False
+ for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS:
+ if exclude_key in comp.name:
+ to_exclude = True
+ break
+ if to_exclude:
+ continue
+
+ if get_group_name(comp.name):
+ param = get_group_name(comp.name)
+ if comp.name not in self.name_mapping:
+ self.name_mapping[comp.name] = param
+ elif comp.name in DEFAULT_MODEL_KEYS:
+ param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"}
+ else:
+ param = comp.name
+ # add the param dict to the model_params dict
+ component_params[comp.name] = param
+
+ output_params = {}
+ if isinstance(self.blocks, SequentialPipelineBlocks):
+ last_block_name = list(self.blocks.sub_blocks.keys())[-1]
+ outputs = self.blocks.sub_blocks[last_block_name].intermediate_outputs
+ else:
+ outputs = self.blocks.intermediate_outputs
+
+ for out in outputs:
+ param = kwargs.pop(out.name, None)
+ if param:
+ output_params[out.name] = param
+ mellon_name = param.pop("name", out.name)
+ if mellon_name != out.name:
+ self.name_mapping[out.name] = mellon_name
+ continue
+
+ if out.name in DEFAULT_PARAM_MAPS:
+ param = DEFAULT_PARAM_MAPS[out.name].copy()
+ param["display"] = "output"
+ else:
+ group_name = get_group_name(out.name)
+ if group_name:
+ param = group_name
+ if out.name not in self.name_mapping:
+ self.name_mapping[out.name] = param
+ else:
+ param = out.name
+ # add the param dict to the outputs dict
+ output_params[out.name] = param
+
+ if len(kwargs) > 0:
+ logger.warning(f"Unused kwargs: {kwargs}")
+
+ register_dict = {
+ "category": category,
+ "label": label,
+ "input_params": input_params,
+ "component_params": component_params,
+ "output_params": output_params,
+ "name_mapping": self.name_mapping,
+ }
+ self.register_to_config(**register_dict)
+
+ def setup(self, components_manager, collection=None):
+ self.pipeline = self.blocks.init_pipeline(components_manager=components_manager, collection=collection)
+ self._components_manager = components_manager
+
+ @property
+ def mellon_config(self):
+ return self._convert_to_mellon_config()
+
+ def _convert_to_mellon_config(self):
+ node = {}
+ node["label"] = self.config.label
+ node["category"] = self.config.category
+
+ node_param = {}
+ for inp_name, inp_param in self.config.input_params.items():
+ if inp_name in self.name_mapping:
+ mellon_name = self.name_mapping[inp_name]
+ else:
+ mellon_name = inp_name
+ if isinstance(inp_param, str):
+ param = {
+ "label": inp_param,
+ "type": inp_param,
+ "display": "input",
+ }
+ else:
+ param = inp_param
+
+ if mellon_name not in node_param:
+ node_param[mellon_name] = param
+ else:
+ logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}")
+
+ for comp_name, comp_param in self.config.component_params.items():
+ if comp_name in self.name_mapping:
+ mellon_name = self.name_mapping[comp_name]
+ else:
+ mellon_name = comp_name
+ if isinstance(comp_param, str):
+ param = {
+ "label": comp_param,
+ "type": comp_param,
+ "display": "input",
+ }
+ else:
+ param = comp_param
+
+ if mellon_name not in node_param:
+ node_param[mellon_name] = param
+ else:
+ logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}")
+
+ for out_name, out_param in self.config.output_params.items():
+ if out_name in self.name_mapping:
+ mellon_name = self.name_mapping[out_name]
+ else:
+ mellon_name = out_name
+ if isinstance(out_param, str):
+ param = {
+ "label": out_param,
+ "type": out_param,
+ "display": "output",
+ }
+ else:
+ param = out_param
+
+ if mellon_name not in node_param:
+ node_param[mellon_name] = param
+ else:
+ logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}")
+ node["params"] = node_param
+ return node
+
+ def save_mellon_config(self, file_path):
+ """
+ Save the Mellon configuration to a JSON file.
+
+ Args:
+ file_path (str or Path): Path where the JSON file will be saved
+
+ Returns:
+ Path: Path to the saved config file
+ """
+ file_path = Path(file_path)
+
+ # Create directory if it doesn't exist
+ os.makedirs(file_path.parent, exist_ok=True)
+
+ # Create a combined dictionary with module definition and name mapping
+ config = {"module": self.mellon_config, "name_mapping": self.name_mapping}
+
+ # Save the config to file
+ with open(file_path, "w", encoding="utf-8") as f:
+ json.dump(config, f, indent=2)
+
+ logger.info(f"Mellon config and name mapping saved to {file_path}")
+
+ return file_path
+
+ @classmethod
+ def load_mellon_config(cls, file_path):
+ """
+ Load a Mellon configuration from a JSON file.
+
+ Args:
+ file_path (str or Path): Path to the JSON file containing Mellon config
+
+ Returns:
+ dict: The loaded combined configuration containing 'module' and 'name_mapping'
+ """
+ file_path = Path(file_path)
+
+ if not file_path.exists():
+ raise FileNotFoundError(f"Config file not found: {file_path}")
+
+ with open(file_path, "r", encoding="utf-8") as f:
+ config = json.load(f)
+
+ logger.info(f"Mellon config loaded from {file_path}")
+
+ return config
+
+ def process_inputs(self, **kwargs):
+ params_components = {}
+ for comp_name, comp_param in self.config.component_params.items():
+ logger.debug(f"component: {comp_name}")
+ mellon_comp_name = self.name_mapping.get(comp_name, comp_name)
+ if mellon_comp_name in kwargs:
+ if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]:
+ comp = kwargs[mellon_comp_name].pop(comp_name)
+ else:
+ comp = kwargs.pop(mellon_comp_name)
+ if comp:
+ params_components[comp_name] = self._components_manager.get_one(comp["model_id"])
+
+ params_run = {}
+ for inp_name, inp_param in self.config.input_params.items():
+ logger.debug(f"input: {inp_name}")
+ mellon_inp_name = self.name_mapping.get(inp_name, inp_name)
+ if mellon_inp_name in kwargs:
+ if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]:
+ inp = kwargs[mellon_inp_name].pop(inp_name)
+ else:
+ inp = kwargs.pop(mellon_inp_name)
+ if inp is not None:
+ params_run[inp_name] = inp
+
+ return_output_names = list(self.config.output_params.keys())
+
+ return params_components, params_run, return_output_names
+
+ def execute(self, **kwargs):
+ params_components, params_run, return_output_names = self.process_inputs(**kwargs)
+
+ self.pipeline.update_components(**params_components)
+ output = self.pipeline(**params_run, output=return_output_names)
+ return output
diff --git a/src/diffusers/modular_pipelines/qwenimage/__init__.py b/src/diffusers/modular_pipelines/qwenimage/__init__.py
new file mode 100644
index 000000000000..ae4ec4799fbc
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/__init__.py
@@ -0,0 +1,89 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["encoders"] = ["QwenImageTextEncoderStep"]
+ _import_structure["modular_blocks"] = [
+ "ALL_BLOCKS",
+ "AUTO_BLOCKS",
+ "CONTROLNET_BLOCKS",
+ "EDIT_AUTO_BLOCKS",
+ "EDIT_BLOCKS",
+ "EDIT_INPAINT_BLOCKS",
+ "EDIT_PLUS_AUTO_BLOCKS",
+ "EDIT_PLUS_BLOCKS",
+ "IMAGE2IMAGE_BLOCKS",
+ "INPAINT_BLOCKS",
+ "TEXT2IMAGE_BLOCKS",
+ "QwenImageAutoBlocks",
+ "QwenImageEditAutoBlocks",
+ "QwenImageEditPlusAutoBlocks",
+ ]
+ _import_structure["modular_pipeline"] = [
+ "QwenImageEditModularPipeline",
+ "QwenImageEditPlusModularPipeline",
+ "QwenImageModularPipeline",
+ ]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ else:
+ from .encoders import (
+ QwenImageTextEncoderStep,
+ )
+ from .modular_blocks import (
+ ALL_BLOCKS,
+ AUTO_BLOCKS,
+ CONTROLNET_BLOCKS,
+ EDIT_AUTO_BLOCKS,
+ EDIT_BLOCKS,
+ EDIT_INPAINT_BLOCKS,
+ EDIT_PLUS_AUTO_BLOCKS,
+ EDIT_PLUS_BLOCKS,
+ IMAGE2IMAGE_BLOCKS,
+ INPAINT_BLOCKS,
+ TEXT2IMAGE_BLOCKS,
+ QwenImageAutoBlocks,
+ QwenImageEditAutoBlocks,
+ QwenImageEditPlusAutoBlocks,
+ )
+ from .modular_pipeline import (
+ QwenImageEditModularPipeline,
+ QwenImageEditPlusModularPipeline,
+ QwenImageModularPipeline,
+ )
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
new file mode 100644
index 000000000000..0e470332c6f4
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
@@ -0,0 +1,725 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ...models import QwenImageControlNetModel, QwenImageMultiControlNetModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils.torch_utils import randn_tensor, unwrap_module
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
+
+
+# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# modified from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+def get_timesteps(scheduler, num_inference_steps, strength):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = scheduler.timesteps[t_start * scheduler.order :]
+ if hasattr(scheduler, "set_begin_index"):
+ scheduler.set_begin_index(t_start * scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+
+# Prepare Latents steps
+
+
+class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Prepare initial random noise for the generation process"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("latents"),
+ InputParam(name="height"),
+ InputParam(name="width"),
+ InputParam(name="num_images_per_prompt", default=1),
+ InputParam(name="generator"),
+ InputParam(
+ name="batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
+ ),
+ InputParam(
+ name="dtype",
+ required=True,
+ type_hint=torch.dtype,
+ description="The dtype of the model inputs, can be generated in input step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="latents",
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(height, width, vae_scale_factor):
+ if height is not None and height % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
+
+ if width is not None and width % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(
+ height=block_state.height,
+ width=block_state.width,
+ vae_scale_factor=components.vae_scale_factor,
+ )
+
+ device = components._execution_device
+ batch_size = block_state.batch_size * block_state.num_images_per_prompt
+
+ # we can update the height and width here since it's used to generate the initial
+ block_state.height = block_state.height or components.default_height
+ block_state.width = block_state.width or components.default_width
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ latent_height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
+ latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
+
+ shape = (batch_size, components.num_channels_latents, 1, latent_height, latent_width)
+ if isinstance(block_state.generator, list) and len(block_state.generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ if block_state.latents is None:
+ block_state.latents = randn_tensor(
+ shape, generator=block_state.generator, device=device, dtype=block_state.dtype
+ )
+ block_state.latents = components.pachifier.pack_latents(block_state.latents)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps, prepare_latents. Both noise and image latents should alreadybe patchified."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ name="latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial random noised, can be generated in prepare latent step.",
+ ),
+ InputParam(
+ name="image_latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The image latents to use for the denoising process. Can be generated in vae encoder and packed in input step.",
+ ),
+ InputParam(
+ name="timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="initial_noise",
+ type_hint=torch.Tensor,
+ description="The initial random noised used for inpainting denoising.",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(image_latents, latents):
+ if image_latents.shape[0] != latents.shape[0]:
+ raise ValueError(
+ f"`image_latents` must have have same batch size as `latents`, but got {image_latents.shape[0]} and {latents.shape[0]}"
+ )
+
+ if image_latents.ndim != 3:
+ raise ValueError(f"`image_latents` must have 3 dimensions (patchified), but got {image_latents.ndim}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(
+ image_latents=block_state.image_latents,
+ latents=block_state.latents,
+ )
+
+ # prepare latent timestep
+ latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0])
+
+ # make copy of initial_noise
+ block_state.initial_noise = block_state.latents
+
+ # scale noise
+ block_state.latents = components.scheduler.scale_noise(
+ block_state.image_latents, latent_timestep, block_state.latents
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Step that creates mask latents from preprocessed mask_image by interpolating to latent space."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ name="processed_mask_image",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The processed mask to use for the inpainting process.",
+ ),
+ InputParam(name="height", required=True),
+ InputParam(name="width", required=True),
+ InputParam(name="dtype", required=True),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process."
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ device = components._execution_device
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+
+ height_latents = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
+ width_latents = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
+
+ block_state.mask = torch.nn.functional.interpolate(
+ block_state.processed_mask_image,
+ size=(height_latents, width_latents),
+ )
+
+ block_state.mask = block_state.mask.unsqueeze(2)
+ block_state.mask = block_state.mask.repeat(1, components.num_channels_latents, 1, 1, 1)
+ block_state.mask = block_state.mask.to(device=device, dtype=block_state.dtype)
+
+ block_state.mask = components.pachifier.pack_latents(block_state.mask)
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+# Set Timesteps steps
+
+
+class QwenImageSetTimestepsStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Step that sets the the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="num_inference_steps", default=50),
+ InputParam(name="sigmas"),
+ InputParam(
+ name="latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The latents to use for the denoising process, used to calculate the image sequence length.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process"
+ ),
+ ]
+
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ device = components._execution_device
+ sigmas = (
+ np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps)
+ if block_state.sigmas is None
+ else block_state.sigmas
+ )
+
+ mu = calculate_shift(
+ image_seq_len=block_state.latents.shape[1],
+ base_seq_len=components.scheduler.config.get("base_image_seq_len", 256),
+ max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096),
+ base_shift=components.scheduler.config.get("base_shift", 0.5),
+ max_shift=components.scheduler.config.get("max_shift", 1.15),
+ )
+ block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
+ scheduler=components.scheduler,
+ num_inference_steps=block_state.num_inference_steps,
+ device=device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+
+ components.scheduler.set_begin_index(0)
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Step that sets the the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare latents step."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="num_inference_steps", default=50),
+ InputParam(name="sigmas"),
+ InputParam(
+ name="latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The latents to use for the denoising process, used to calculate the image sequence length.",
+ ),
+ InputParam(name="strength", default=0.9),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="timesteps",
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ device = components._execution_device
+ sigmas = (
+ np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps)
+ if block_state.sigmas is None
+ else block_state.sigmas
+ )
+
+ mu = calculate_shift(
+ image_seq_len=block_state.latents.shape[1],
+ base_seq_len=components.scheduler.config.get("base_image_seq_len", 256),
+ max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096),
+ base_shift=components.scheduler.config.get("base_shift", 0.5),
+ max_shift=components.scheduler.config.get("max_shift", 1.15),
+ )
+ block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
+ scheduler=components.scheduler,
+ num_inference_steps=block_state.num_inference_steps,
+ device=device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+
+ block_state.timesteps, block_state.num_inference_steps = get_timesteps(
+ scheduler=components.scheduler,
+ num_inference_steps=block_state.num_inference_steps,
+ strength=block_state.strength,
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+# other inputs for denoiser
+
+## RoPE inputs for denoiser
+
+
+class QwenImageRoPEInputsStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step"
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="batch_size", required=True),
+ InputParam(name="height", required=True),
+ InputParam(name="width", required=True),
+ InputParam(name="prompt_embeds_mask"),
+ InputParam(name="negative_prompt_embeds_mask"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="img_shapes",
+ type_hint=List[List[Tuple[int, int, int]]],
+ description="The shapes of the images latents, used for RoPE calculation",
+ ),
+ OutputParam(
+ name="txt_seq_lens",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the prompt embeds, used for RoPE calculation",
+ ),
+ OutputParam(
+ name="negative_txt_seq_lens",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
+ ),
+ ]
+
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.img_shapes = [
+ [
+ (
+ 1,
+ block_state.height // components.vae_scale_factor // 2,
+ block_state.width // components.vae_scale_factor // 2,
+ )
+ ]
+ ] * block_state.batch_size
+ block_state.txt_seq_lens = (
+ block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
+ )
+ block_state.negative_txt_seq_lens = (
+ block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
+ if block_state.negative_prompt_embeds_mask is not None
+ else None
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be placed after prepare_latents step"
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="batch_size", required=True),
+ InputParam(name="image_height", required=True),
+ InputParam(name="image_width", required=True),
+ InputParam(name="height", required=True),
+ InputParam(name="width", required=True),
+ InputParam(name="prompt_embeds_mask"),
+ InputParam(name="negative_prompt_embeds_mask"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="img_shapes",
+ type_hint=List[List[Tuple[int, int, int]]],
+ description="The shapes of the images latents, used for RoPE calculation",
+ ),
+ OutputParam(
+ name="txt_seq_lens",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the prompt embeds, used for RoPE calculation",
+ ),
+ OutputParam(
+ name="negative_txt_seq_lens",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
+ ),
+ ]
+
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ # for edit, image size can be different from the target size (height/width)
+
+ block_state.img_shapes = [
+ [
+ (
+ 1,
+ block_state.height // components.vae_scale_factor // 2,
+ block_state.width // components.vae_scale_factor // 2,
+ ),
+ (
+ 1,
+ block_state.image_height // components.vae_scale_factor // 2,
+ block_state.image_width // components.vae_scale_factor // 2,
+ ),
+ ]
+ ] * block_state.batch_size
+
+ block_state.txt_seq_lens = (
+ block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
+ )
+ block_state.negative_txt_seq_lens = (
+ block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
+ if block_state.negative_prompt_embeds_mask is not None
+ else None
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+## ControlNet inputs for denoiser
+class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("controlnet", QwenImageControlNetModel),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "step that prepare inputs for controlnet. Insert before the Denoise Step, after set_timesteps step."
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("control_guidance_start", default=0.0),
+ InputParam("control_guidance_end", default=1.0),
+ InputParam("controlnet_conditioning_scale", default=1.0),
+ InputParam("control_image_latents", required=True),
+ InputParam(
+ "timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ controlnet = unwrap_module(components.controlnet)
+
+ # control_guidance_start/control_guidance_end (align format)
+ if not isinstance(block_state.control_guidance_start, list) and isinstance(
+ block_state.control_guidance_end, list
+ ):
+ block_state.control_guidance_start = len(block_state.control_guidance_end) * [
+ block_state.control_guidance_start
+ ]
+ elif not isinstance(block_state.control_guidance_end, list) and isinstance(
+ block_state.control_guidance_start, list
+ ):
+ block_state.control_guidance_end = len(block_state.control_guidance_start) * [
+ block_state.control_guidance_end
+ ]
+ elif not isinstance(block_state.control_guidance_start, list) and not isinstance(
+ block_state.control_guidance_end, list
+ ):
+ mult = (
+ len(block_state.control_image_latents) if isinstance(controlnet, QwenImageMultiControlNetModel) else 1
+ )
+ block_state.control_guidance_start, block_state.control_guidance_end = (
+ mult * [block_state.control_guidance_start],
+ mult * [block_state.control_guidance_end],
+ )
+
+ # controlnet_conditioning_scale (align format)
+ if isinstance(controlnet, QwenImageMultiControlNetModel) and isinstance(
+ block_state.controlnet_conditioning_scale, float
+ ):
+ block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * mult
+
+ # controlnet_keep
+ block_state.controlnet_keep = []
+ for i in range(len(block_state.timesteps)):
+ keeps = [
+ 1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e)
+ for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end)
+ ]
+ block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, QwenImageControlNetModel) else keeps)
+
+ self.set_block_state(state, block_state)
+
+ return components, state
diff --git a/src/diffusers/modular_pipelines/qwenimage/decoders.py b/src/diffusers/modular_pipelines/qwenimage/decoders.py
new file mode 100644
index 000000000000..26417162deee
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/decoders.py
@@ -0,0 +1,204 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Union
+
+import numpy as np
+import PIL
+import torch
+
+from ...configuration_utils import FrozenDict
+from ...image_processor import InpaintProcessor, VaeImageProcessor
+from ...models import AutoencoderKLQwenImage
+from ...utils import logging
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
+
+
+logger = logging.get_logger(__name__)
+
+
+class QwenImageDecoderStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Step that decodes the latents to images"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ components = [
+ ComponentSpec("vae", AutoencoderKLQwenImage),
+ ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
+ ]
+
+ return components
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="height", required=True),
+ InputParam(name="width", required=True),
+ InputParam(
+ name="latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The latents to decode, can be generated in the denoise step",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ return [
+ OutputParam(
+ "images",
+ type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
+ description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array",
+ )
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ # YiYi Notes: remove support for output_type = "latents', we can just skip decode/encode step in modular
+ vae_scale_factor = components.vae_scale_factor
+ block_state.latents = components.pachifier.unpack_latents(
+ block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor
+ )
+ block_state.latents = block_state.latents.to(components.vae.dtype)
+
+ latents_mean = (
+ torch.tensor(components.vae.config.latents_mean)
+ .view(1, components.vae.config.z_dim, 1, 1, 1)
+ .to(block_state.latents.device, block_state.latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
+ 1, components.vae.config.z_dim, 1, 1, 1
+ ).to(block_state.latents.device, block_state.latents.dtype)
+ block_state.latents = block_state.latents / latents_std + latents_mean
+ block_state.images = components.vae.decode(block_state.latents, return_dict=False)[0][:, :, 0]
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "postprocess the generated image"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("images", required=True, description="the generated image from decoders step"),
+ InputParam(
+ name="output_type",
+ default="pil",
+ type_hint=str,
+ description="The type of the output images, can be 'pil', 'np', 'pt'",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(output_type):
+ if output_type not in ["pil", "np", "pt"]:
+ raise ValueError(f"Invalid output_type: {output_type}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(block_state.output_type)
+
+ block_state.images = components.image_processor.postprocess(
+ image=block_state.images,
+ output_type=block_state.output_type,
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "postprocess the generated image, optional apply the mask overally to the original image.."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_mask_processor",
+ InpaintProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("images", required=True, description="the generated image from decoders step"),
+ InputParam(
+ name="output_type",
+ default="pil",
+ type_hint=str,
+ description="The type of the output images, can be 'pil', 'np', 'pt'",
+ ),
+ InputParam("mask_overlay_kwargs"),
+ ]
+
+ @staticmethod
+ def check_inputs(output_type, mask_overlay_kwargs):
+ if output_type not in ["pil", "np", "pt"]:
+ raise ValueError(f"Invalid output_type: {output_type}")
+
+ if mask_overlay_kwargs and output_type != "pil":
+ raise ValueError("only support output_type 'pil' for mask overlay")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(block_state.output_type, block_state.mask_overlay_kwargs)
+
+ if block_state.mask_overlay_kwargs is None:
+ mask_overlay_kwargs = {}
+ else:
+ mask_overlay_kwargs = block_state.mask_overlay_kwargs
+
+ block_state.images = components.image_mask_processor.postprocess(
+ image=block_state.images,
+ **mask_overlay_kwargs,
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
diff --git a/src/diffusers/modular_pipelines/qwenimage/denoise.py b/src/diffusers/modular_pipelines/qwenimage/denoise.py
new file mode 100644
index 000000000000..49acd2dc0295
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/denoise.py
@@ -0,0 +1,684 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Tuple
+
+import torch
+
+from ...configuration_utils import FrozenDict
+from ...guiders import ClassifierFreeGuidance
+from ...models import QwenImageControlNetModel, QwenImageTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import logging
+from ..modular_pipeline import BlockState, LoopSequentialPipelineBlocks, ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import QwenImageModularPipeline
+
+
+logger = logging.get_logger(__name__)
+
+
+class QwenImageLoopBeforeDenoiser(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that prepares the latent input for the denoiser. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `QwenImageDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ # one timestep
+ block_state.timestep = t.expand(block_state.latents.shape[0]).to(block_state.latents.dtype)
+ block_state.latent_model_input = block_state.latents
+ return components, block_state
+
+
+class QwenImageEditLoopBeforeDenoiser(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that prepares the latent input for the denoiser. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `QwenImageDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "image_latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial image latents to use for the denoising process. Can be encoded in vae_encoder step and packed in prepare_image_latents step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ # one timestep
+
+ block_state.latent_model_input = torch.cat([block_state.latents, block_state.image_latents], dim=1)
+ block_state.timestep = t.expand(block_state.latents.shape[0]).to(block_state.latents.dtype)
+ return components, block_state
+
+
+class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 4.0}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec("controlnet", QwenImageControlNetModel),
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that runs the controlnet before the denoiser. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `QwenImageDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "control_image_latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
+ ),
+ InputParam(
+ "controlnet_conditioning_scale",
+ type_hint=float,
+ description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
+ ),
+ InputParam(
+ "controlnet_keep",
+ required=True,
+ type_hint=List[float],
+ description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
+ ),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ kwargs_type="denoiser_input_fields",
+ description=(
+ "All conditional model inputs for the denoiser. "
+ "It should contain prompt_embeds/negative_prompt_embeds, txt_seq_lens/negative_txt_seq_lens."
+ ),
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: int):
+ # cond_scale for the timestep (controlnet input)
+ if isinstance(block_state.controlnet_keep[i], list):
+ block_state.cond_scale = [
+ c * s for c, s in zip(block_state.controlnet_conditioning_scale, block_state.controlnet_keep[i])
+ ]
+ else:
+ controlnet_cond_scale = block_state.controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i]
+
+ # run controlnet for the guidance batch
+ controlnet_block_samples = components.controlnet(
+ hidden_states=block_state.latent_model_input,
+ controlnet_cond=block_state.control_image_latents,
+ conditioning_scale=block_state.cond_scale,
+ timestep=block_state.timestep / 1000,
+ img_shapes=block_state.img_shapes,
+ encoder_hidden_states=block_state.prompt_embeds,
+ encoder_hidden_states_mask=block_state.prompt_embeds_mask,
+ txt_seq_lens=block_state.txt_seq_lens,
+ return_dict=False,
+ )
+
+ block_state.additional_cond_kwargs["controlnet_block_samples"] = controlnet_block_samples
+
+ return components, block_state
+
+
+class QwenImageLoopDenoiser(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that denoise the latent input for the denoiser. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `QwenImageDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 4.0}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec("transformer", QwenImageTransformer2DModel),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("attention_kwargs"),
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The latents to use for the denoising process. Can be generated in prepare_latents step.",
+ ),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ kwargs_type="denoiser_input_fields",
+ description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
+ ),
+ InputParam(
+ "img_shapes",
+ required=True,
+ type_hint=List[Tuple[int, int]],
+ description="The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ guider_inputs = {
+ "encoder_hidden_states": (
+ getattr(block_state, "prompt_embeds", None),
+ getattr(block_state, "negative_prompt_embeds", None),
+ ),
+ "encoder_hidden_states_mask": (
+ getattr(block_state, "prompt_embeds_mask", None),
+ getattr(block_state, "negative_prompt_embeds_mask", None),
+ ),
+ "txt_seq_lens": (
+ getattr(block_state, "txt_seq_lens", None),
+ getattr(block_state, "negative_txt_seq_lens", None),
+ ),
+ }
+
+ components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
+ guider_state = components.guider.prepare_inputs(guider_inputs)
+
+ for guider_state_batch in guider_state:
+ components.guider.prepare_models(components.transformer)
+ cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
+
+ # YiYi TODO: add cache context
+ guider_state_batch.noise_pred = components.transformer(
+ hidden_states=block_state.latent_model_input,
+ timestep=block_state.timestep / 1000,
+ img_shapes=block_state.img_shapes,
+ attention_kwargs=block_state.attention_kwargs,
+ return_dict=False,
+ **cond_kwargs,
+ **block_state.additional_cond_kwargs,
+ )[0]
+
+ components.guider.cleanup_models(components.transformer)
+
+ guider_output = components.guider(guider_state)
+
+ # apply guidance rescale
+ pred_cond_norm = torch.norm(guider_output.pred_cond, dim=-1, keepdim=True)
+ pred_norm = torch.norm(guider_output.pred, dim=-1, keepdim=True)
+ block_state.noise_pred = guider_output.pred * (pred_cond_norm / pred_norm)
+
+ return components, block_state
+
+
+class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that denoise the latent input for the denoiser. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `QwenImageDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 4.0}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec("transformer", QwenImageTransformer2DModel),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("attention_kwargs"),
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The latents to use for the denoising process. Can be generated in prepare_latents step.",
+ ),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ kwargs_type="denoiser_input_fields",
+ description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
+ ),
+ InputParam(
+ "img_shapes",
+ required=True,
+ type_hint=List[Tuple[int, int]],
+ description="The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ guider_inputs = {
+ "encoder_hidden_states": (
+ getattr(block_state, "prompt_embeds", None),
+ getattr(block_state, "negative_prompt_embeds", None),
+ ),
+ "encoder_hidden_states_mask": (
+ getattr(block_state, "prompt_embeds_mask", None),
+ getattr(block_state, "negative_prompt_embeds_mask", None),
+ ),
+ "txt_seq_lens": (
+ getattr(block_state, "txt_seq_lens", None),
+ getattr(block_state, "negative_txt_seq_lens", None),
+ ),
+ }
+
+ components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
+ guider_state = components.guider.prepare_inputs(guider_inputs)
+
+ for guider_state_batch in guider_state:
+ components.guider.prepare_models(components.transformer)
+ cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
+
+ # YiYi TODO: add cache context
+ guider_state_batch.noise_pred = components.transformer(
+ hidden_states=block_state.latent_model_input,
+ timestep=block_state.timestep / 1000,
+ img_shapes=block_state.img_shapes,
+ attention_kwargs=block_state.attention_kwargs,
+ return_dict=False,
+ **cond_kwargs,
+ **block_state.additional_cond_kwargs,
+ )[0]
+
+ components.guider.cleanup_models(components.transformer)
+
+ guider_output = components.guider(guider_state)
+
+ pred = guider_output.pred[:, : block_state.latents.size(1)]
+ pred_cond = guider_output.pred_cond[:, : block_state.latents.size(1)]
+
+ # apply guidance rescale
+ pred_cond_norm = torch.norm(pred_cond, dim=-1, keepdim=True)
+ pred_norm = torch.norm(pred, dim=-1, keepdim=True)
+ block_state.noise_pred = pred * (pred_cond_norm / pred_norm)
+
+ return components, block_state
+
+
+class QwenImageLoopAfterDenoiser(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that updates the latents. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `QwenImageDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents."),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ latents_dtype = block_state.latents.dtype
+ block_state.latents = components.scheduler.step(
+ block_state.noise_pred,
+ t,
+ block_state.latents,
+ return_dict=False,
+ )[0]
+
+ if block_state.latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ block_state.latents = block_state.latents.to(latents_dtype)
+
+ return components, block_state
+
+
+class QwenImageLoopAfterDenoiserInpaint(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that updates the latents using mask and image_latents for inpainting. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `QwenImageDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "mask",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.",
+ ),
+ InputParam(
+ "image_latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The image latents to use for the inpainting process. Can be generated in inpaint prepare latents step.",
+ ),
+ InputParam(
+ "initial_noise",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.",
+ ),
+ InputParam(
+ "timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ block_state.init_latents_proper = block_state.image_latents
+ if i < len(block_state.timesteps) - 1:
+ block_state.noise_timestep = block_state.timesteps[i + 1]
+ block_state.init_latents_proper = components.scheduler.scale_noise(
+ block_state.init_latents_proper, torch.tensor([block_state.noise_timestep]), block_state.initial_noise
+ )
+
+ block_state.latents = (
+ 1 - block_state.mask
+ ) * block_state.init_latents_proper + block_state.mask * block_state.latents
+
+ return components, block_state
+
+
+class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Pipeline block that iteratively denoise the latents over `timesteps`. "
+ "The specific steps with each iteration can be customized with `sub_blocks` attributes"
+ )
+
+ @property
+ def loop_expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
+ ]
+
+ @property
+ def loop_inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.num_warmup_steps = max(
+ len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
+ )
+
+ block_state.additional_cond_kwargs = {}
+
+ with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
+ for i, t in enumerate(block_state.timesteps):
+ components, block_state = self.loop_step(components, block_state, i=i, t=t)
+ if i == len(block_state.timesteps) - 1 or (
+ (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
+ ):
+ progress_bar.update()
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+# composing the denoising loops
+class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
+ block_classes = [
+ QwenImageLoopBeforeDenoiser,
+ QwenImageLoopDenoiser,
+ QwenImageLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ " - `QwenImageLoopBeforeDenoiser`\n"
+ " - `QwenImageLoopDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiser`\n"
+ "This block supports text2image and image2image tasks for QwenImage."
+ )
+
+
+# composing the inpainting denoising loops
+class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
+ block_classes = [
+ QwenImageLoopBeforeDenoiser,
+ QwenImageLoopDenoiser,
+ QwenImageLoopAfterDenoiser,
+ QwenImageLoopAfterDenoiserInpaint,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser", "after_denoiser_inpaint"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ " - `QwenImageLoopBeforeDenoiser`\n"
+ " - `QwenImageLoopDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiserInpaint`\n"
+ "This block supports inpainting tasks for QwenImage."
+ )
+
+
+# composing the controlnet denoising loops
+class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
+ block_classes = [
+ QwenImageLoopBeforeDenoiser,
+ QwenImageLoopBeforeDenoiserControlNet,
+ QwenImageLoopDenoiser,
+ QwenImageLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "before_denoiser_controlnet", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ " - `QwenImageLoopBeforeDenoiser`\n"
+ " - `QwenImageLoopBeforeDenoiserControlNet`\n"
+ " - `QwenImageLoopDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiser`\n"
+ "This block supports text2img/img2img tasks with controlnet for QwenImage."
+ )
+
+
+# composing the controlnet denoising loops
+class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
+ block_classes = [
+ QwenImageLoopBeforeDenoiser,
+ QwenImageLoopBeforeDenoiserControlNet,
+ QwenImageLoopDenoiser,
+ QwenImageLoopAfterDenoiser,
+ QwenImageLoopAfterDenoiserInpaint,
+ ]
+ block_names = [
+ "before_denoiser",
+ "before_denoiser_controlnet",
+ "denoiser",
+ "after_denoiser",
+ "after_denoiser_inpaint",
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ " - `QwenImageLoopBeforeDenoiser`\n"
+ " - `QwenImageLoopBeforeDenoiserControlNet`\n"
+ " - `QwenImageLoopDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiserInpaint`\n"
+ "This block supports inpainting tasks with controlnet for QwenImage."
+ )
+
+
+# composing the denoising loops
+class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):
+ block_classes = [
+ QwenImageEditLoopBeforeDenoiser,
+ QwenImageEditLoopDenoiser,
+ QwenImageLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ " - `QwenImageEditLoopBeforeDenoiser`\n"
+ " - `QwenImageEditLoopDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiser`\n"
+ "This block supports QwenImage Edit."
+ )
+
+
+class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
+ block_classes = [
+ QwenImageEditLoopBeforeDenoiser,
+ QwenImageEditLoopDenoiser,
+ QwenImageLoopAfterDenoiser,
+ QwenImageLoopAfterDenoiserInpaint,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser", "after_denoiser_inpaint"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ " - `QwenImageEditLoopBeforeDenoiser`\n"
+ " - `QwenImageEditLoopDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiserInpaint`\n"
+ "This block supports inpainting tasks for QwenImage Edit."
+ )
diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py
new file mode 100644
index 000000000000..3b56981e5290
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py
@@ -0,0 +1,1085 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Optional, Union
+
+import PIL
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
+
+from ...configuration_utils import FrozenDict
+from ...guiders import ClassifierFreeGuidance
+from ...image_processor import InpaintProcessor, VaeImageProcessor, is_valid_image, is_valid_image_imagelist
+from ...models import AutoencoderKLQwenImage, QwenImageControlNetModel, QwenImageMultiControlNetModel
+from ...pipelines.qwenimage.pipeline_qwenimage_edit import calculate_dimensions
+from ...utils import logging
+from ...utils.torch_utils import unwrap_module
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
+from .modular_pipeline import QwenImageModularPipeline
+
+
+logger = logging.get_logger(__name__)
+
+
+def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor):
+ bool_mask = mask.bool()
+ valid_lengths = bool_mask.sum(dim=1)
+ selected = hidden_states[bool_mask]
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+ return split_result
+
+
+def get_qwen_prompt_embeds(
+ text_encoder,
+ tokenizer,
+ prompt: Union[str, List[str]] = None,
+ prompt_template_encode: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+ prompt_template_encode_start_idx: int = 34,
+ tokenizer_max_length: int = 1024,
+ device: Optional[torch.device] = None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ template = prompt_template_encode
+ drop_idx = prompt_template_encode_start_idx
+ txt = [template.format(e) for e in prompt]
+ txt_tokens = tokenizer(
+ txt, max_length=tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
+ ).to(device)
+ encoder_hidden_states = text_encoder(
+ input_ids=txt_tokens.input_ids,
+ attention_mask=txt_tokens.attention_mask,
+ output_hidden_states=True,
+ )
+ hidden_states = encoder_hidden_states.hidden_states[-1]
+
+ split_hidden_states = _extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+
+def get_qwen_prompt_embeds_edit(
+ text_encoder,
+ processor,
+ prompt: Union[str, List[str]] = None,
+ image: Optional[torch.Tensor] = None,
+ prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
+ prompt_template_encode_start_idx: int = 64,
+ device: Optional[torch.device] = None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ template = prompt_template_encode
+ drop_idx = prompt_template_encode_start_idx
+ txt = [template.format(e) for e in prompt]
+
+ model_inputs = processor(
+ text=txt,
+ images=image,
+ padding=True,
+ return_tensors="pt",
+ ).to(device)
+
+ outputs = text_encoder(
+ input_ids=model_inputs.input_ids,
+ attention_mask=model_inputs.attention_mask,
+ pixel_values=model_inputs.pixel_values,
+ image_grid_thw=model_inputs.image_grid_thw,
+ output_hidden_states=True,
+ )
+
+ hidden_states = outputs.hidden_states[-1]
+ split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+
+def get_qwen_prompt_embeds_edit_plus(
+ text_encoder,
+ processor,
+ prompt: Union[str, List[str]] = None,
+ image: Optional[Union[torch.Tensor, List[PIL.Image.Image], PIL.Image.Image]] = None,
+ prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+ img_template_encode: str = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>",
+ prompt_template_encode_start_idx: int = 64,
+ device: Optional[torch.device] = None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if isinstance(image, list):
+ base_img_prompt = ""
+ for i, img in enumerate(image):
+ base_img_prompt += img_template_encode.format(i + 1)
+ elif image is not None:
+ base_img_prompt = img_template_encode.format(1)
+ else:
+ base_img_prompt = ""
+
+ template = prompt_template_encode
+
+ drop_idx = prompt_template_encode_start_idx
+ txt = [template.format(base_img_prompt + e) for e in prompt]
+
+ model_inputs = processor(
+ text=txt,
+ images=image,
+ padding=True,
+ return_tensors="pt",
+ ).to(device)
+ outputs = text_encoder(
+ input_ids=model_inputs.input_ids,
+ attention_mask=model_inputs.attention_mask,
+ pixel_values=model_inputs.pixel_values,
+ image_grid_thw=model_inputs.image_grid_thw,
+ output_hidden_states=True,
+ )
+
+ hidden_states = outputs.hidden_states[-1]
+ split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(device=device)
+ return prompt_embeds, encoder_attention_mask
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Modified from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._encode_vae_image
+def encode_vae_image(
+ image: torch.Tensor,
+ vae: AutoencoderKLQwenImage,
+ generator: torch.Generator,
+ device: torch.device,
+ dtype: torch.dtype,
+ latent_channels: int = 16,
+ sample_mode: str = "argmax",
+):
+ if not isinstance(image, torch.Tensor):
+ raise ValueError(f"Expected image to be a tensor, got {type(image)}.")
+
+ # preprocessed image should be a 4D tensor: batch_size, num_channels, height, width
+ if image.dim() == 4:
+ image = image.unsqueeze(2)
+ elif image.dim() != 5:
+ raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.")
+
+ image = image.to(device=device, dtype=dtype)
+
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode)
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode)
+ latents_mean = (
+ torch.tensor(vae.config.latents_mean)
+ .view(1, latent_channels, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(vae.config.latents_std)
+ .view(1, latent_channels, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ image_latents = (image_latents - latents_mean) / latents_std
+
+ return image_latents
+
+
+class QwenImageEditResizeDynamicStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ def __init__(self, input_name: str = "image", output_name: str = "resized_image"):
+ """Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio.
+
+ This block resizes an input image tensor and exposes the resized result under configurable input and output
+ names. Use this when you need to wire the resize step to different image fields (e.g., "image",
+ "control_image")
+
+ Args:
+ input_name (str, optional): Name of the image field to read from the
+ pipeline state. Defaults to "image".
+ output_name (str, optional): Name of the resized image field to write
+ back to the pipeline state. Defaults to "resized_image".
+ """
+ if not isinstance(input_name, str) or not isinstance(output_name, str):
+ raise ValueError(
+ f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}"
+ )
+ self._image_input_name = input_name
+ self._resized_image_output_name = output_name
+ super().__init__()
+
+ @property
+ def description(self) -> str:
+ return f"Image Resize step that resize the {self._image_input_name} to the target area (1024 * 1024) while maintaining the aspect ratio."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_resize_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ name=self._image_input_name, required=True, type_hint=torch.Tensor, description="The image to resize"
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images"
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ images = getattr(block_state, self._image_input_name)
+
+ if not is_valid_image_imagelist(images):
+ raise ValueError(f"Images must be image or list of images but are {type(images)}")
+
+ if is_valid_image(images):
+ images = [images]
+
+ image_width, image_height = images[0].size
+ calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_width / image_height)
+
+ resized_images = [
+ components.image_resize_processor.resize(image, height=calculated_height, width=calculated_width)
+ for image in images
+ ]
+
+ setattr(block_state, self._resized_image_output_name, resized_images)
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageEditPlusResizeDynamicStep(QwenImageEditResizeDynamicStep):
+ model_name = "qwenimage"
+
+ def __init__(
+ self,
+ input_name: str = "image",
+ output_name: str = "resized_image",
+ vae_image_output_name: str = "vae_image",
+ ):
+ """Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio.
+
+ This block resizes an input image or a list input images and exposes the resized result under configurable
+ input and output names. Use this when you need to wire the resize step to different image fields (e.g.,
+ "image", "control_image")
+
+ Args:
+ input_name (str, optional): Name of the image field to read from the
+ pipeline state. Defaults to "image".
+ output_name (str, optional): Name of the resized image field to write
+ back to the pipeline state. Defaults to "resized_image".
+ vae_image_output_name (str, optional): Name of the image field
+ to write back to the pipeline state. This is used by the VAE encoder step later on. QwenImage Edit Plus
+ processes the input image(s) differently for the VL and the VAE.
+ """
+ if not isinstance(input_name, str) or not isinstance(output_name, str):
+ raise ValueError(
+ f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}"
+ )
+ self.condition_image_size = 384 * 384
+ self._image_input_name = input_name
+ self._resized_image_output_name = output_name
+ self._vae_image_output_name = vae_image_output_name
+ super().__init__()
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return super().intermediate_outputs + [
+ OutputParam(
+ name=self._vae_image_output_name,
+ type_hint=List[PIL.Image.Image],
+ description="The images to be processed which will be further used by the VAE encoder.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ images = getattr(block_state, self._image_input_name)
+
+ if not is_valid_image_imagelist(images):
+ raise ValueError(f"Images must be image or list of images but are {type(images)}")
+
+ if (
+ not isinstance(images, torch.Tensor)
+ and isinstance(images, PIL.Image.Image)
+ and not isinstance(images, list)
+ ):
+ images = [images]
+
+ # TODO (sayakpaul): revisit this when the inputs are `torch.Tensor`s
+ condition_images = []
+ vae_images = []
+ for img in images:
+ image_width, image_height = img.size
+ condition_width, condition_height, _ = calculate_dimensions(
+ self.condition_image_size, image_width / image_height
+ )
+ condition_images.append(components.image_resize_processor.resize(img, condition_height, condition_width))
+ vae_images.append(img)
+
+ setattr(block_state, self._resized_image_output_name, condition_images)
+ setattr(block_state, self._vae_image_output_name, vae_images)
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageTextEncoderStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Text Encoder step that generate text_embeddings to guide the image generation"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration, description="The text encoder to use"),
+ ComponentSpec("tokenizer", Qwen2Tokenizer, description="The tokenizer to use"),
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 4.0}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def expected_configs(self) -> List[ConfigSpec]:
+ return [
+ ConfigSpec(
+ name="prompt_template_encode",
+ default="<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+ ),
+ ConfigSpec(name="prompt_template_encode_start_idx", default=34),
+ ConfigSpec(name="tokenizer_max_length", default=1024),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"),
+ InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"),
+ InputParam(
+ name="max_sequence_length", type_hint=int, description="The max sequence length to use", default=1024
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="prompt_embeds",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The prompt embeddings",
+ ),
+ OutputParam(
+ name="prompt_embeds_mask",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The encoder attention mask",
+ ),
+ OutputParam(
+ name="negative_prompt_embeds",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The negative prompt embeddings",
+ ),
+ OutputParam(
+ name="negative_prompt_embeds_mask",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The negative prompt embeddings mask",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(prompt, negative_prompt, max_sequence_length):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if (
+ negative_prompt is not None
+ and not isinstance(negative_prompt, str)
+ and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ device = components._execution_device
+ self.check_inputs(block_state.prompt, block_state.negative_prompt, block_state.max_sequence_length)
+
+ block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds(
+ components.text_encoder,
+ components.tokenizer,
+ prompt=block_state.prompt,
+ prompt_template_encode=components.config.prompt_template_encode,
+ prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+ tokenizer_max_length=components.config.tokenizer_max_length,
+ device=device,
+ )
+
+ block_state.prompt_embeds = block_state.prompt_embeds[:, : block_state.max_sequence_length]
+ block_state.prompt_embeds_mask = block_state.prompt_embeds_mask[:, : block_state.max_sequence_length]
+
+ block_state.negative_prompt_embeds = None
+ block_state.negative_prompt_embeds_mask = None
+ if components.requires_unconditional_embeds:
+ negative_prompt = block_state.negative_prompt or ""
+ block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds(
+ components.text_encoder,
+ components.tokenizer,
+ prompt=negative_prompt,
+ prompt_template_encode=components.config.prompt_template_encode,
+ prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+ tokenizer_max_length=components.config.tokenizer_max_length,
+ device=device,
+ )
+ block_state.negative_prompt_embeds = block_state.negative_prompt_embeds[
+ :, : block_state.max_sequence_length
+ ]
+ block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask[
+ :, : block_state.max_sequence_length
+ ]
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageEditTextEncoderStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image generation"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration),
+ ComponentSpec("processor", Qwen2VLProcessor),
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 4.0}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def expected_configs(self) -> List[ConfigSpec]:
+ return [
+ ConfigSpec(
+ name="prompt_template_encode",
+ default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
+ ),
+ ConfigSpec(name="prompt_template_encode_start_idx", default=64),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"),
+ InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"),
+ InputParam(
+ name="resized_image",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The image prompt to encode, should be resized using resize step",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="prompt_embeds",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The prompt embeddings",
+ ),
+ OutputParam(
+ name="prompt_embeds_mask",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The encoder attention mask",
+ ),
+ OutputParam(
+ name="negative_prompt_embeds",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The negative prompt embeddings",
+ ),
+ OutputParam(
+ name="negative_prompt_embeds_mask",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The negative prompt embeddings mask",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(prompt, negative_prompt):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if (
+ negative_prompt is not None
+ and not isinstance(negative_prompt, str)
+ and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(block_state.prompt, block_state.negative_prompt)
+
+ device = components._execution_device
+
+ block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds_edit(
+ components.text_encoder,
+ components.processor,
+ prompt=block_state.prompt,
+ image=block_state.resized_image,
+ prompt_template_encode=components.config.prompt_template_encode,
+ prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+ device=device,
+ )
+
+ block_state.negative_prompt_embeds = None
+ block_state.negative_prompt_embeds_mask = None
+ if components.requires_unconditional_embeds:
+ negative_prompt = block_state.negative_prompt or " "
+ block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit(
+ components.text_encoder,
+ components.processor,
+ prompt=negative_prompt,
+ image=block_state.resized_image,
+ prompt_template_encode=components.config.prompt_template_encode,
+ prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+ device=device,
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageEditPlusTextEncoderStep(QwenImageEditTextEncoderStep):
+ model_name = "qwenimage"
+
+ @property
+ def expected_configs(self) -> List[ConfigSpec]:
+ return [
+ ConfigSpec(
+ name="prompt_template_encode",
+ default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+ ),
+ ConfigSpec(
+ name="img_template_encode",
+ default="Picture {}: <|vision_start|><|image_pad|><|vision_end|>",
+ ),
+ ConfigSpec(name="prompt_template_encode_start_idx", default=64),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(block_state.prompt, block_state.negative_prompt)
+
+ device = components._execution_device
+
+ block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds_edit_plus(
+ components.text_encoder,
+ components.processor,
+ prompt=block_state.prompt,
+ image=block_state.resized_image,
+ prompt_template_encode=components.config.prompt_template_encode,
+ img_template_encode=components.config.img_template_encode,
+ prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+ device=device,
+ )
+
+ block_state.negative_prompt_embeds = None
+ block_state.negative_prompt_embeds_mask = None
+ if components.requires_unconditional_embeds:
+ negative_prompt = block_state.negative_prompt or " "
+ block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = (
+ get_qwen_prompt_embeds_edit_plus(
+ components.text_encoder,
+ components.processor,
+ prompt=negative_prompt,
+ image=block_state.resized_image,
+ prompt_template_encode=components.config.prompt_template_encode,
+ img_template_encode=components.config.img_template_encode,
+ prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+ device=device,
+ )
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Image Preprocess step for inpainting task. This processes the image and mask inputs together. Images can be resized first using QwenImageEditResizeDynamicStep."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_mask_processor",
+ InpaintProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("mask_image", required=True),
+ InputParam("resized_image"),
+ InputParam("image"),
+ InputParam("height"),
+ InputParam("width"),
+ InputParam("padding_mask_crop"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(name="processed_image"),
+ OutputParam(name="processed_mask_image"),
+ OutputParam(
+ name="mask_overlay_kwargs",
+ type_hint=Dict,
+ description="The kwargs for the postprocess step to apply the mask overlay",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(height, width, vae_scale_factor):
+ if height is not None and height % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
+
+ if width is not None and width % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ if block_state.resized_image is None and block_state.image is None:
+ raise ValueError("resized_image and image cannot be None at the same time")
+
+ if block_state.resized_image is None:
+ image = block_state.image
+ self.check_inputs(
+ height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
+ )
+ height = block_state.height or components.default_height
+ width = block_state.width or components.default_width
+ else:
+ width, height = block_state.resized_image[0].size
+ image = block_state.resized_image
+
+ block_state.processed_image, block_state.processed_mask_image, block_state.mask_overlay_kwargs = (
+ components.image_mask_processor.preprocess(
+ image=image,
+ mask=block_state.mask_image,
+ height=height,
+ width=width,
+ padding_mask_crop=block_state.padding_mask_crop,
+ )
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Image Preprocess step. Images can be resized first using QwenImageEditResizeDynamicStep."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(name="processed_image"),
+ ]
+
+ @staticmethod
+ def check_inputs(height, width, vae_scale_factor):
+ if height is not None and height % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
+
+ if width is not None and width % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ if block_state.resized_image is None and block_state.image is None:
+ raise ValueError("resized_image and image cannot be None at the same time")
+
+ if block_state.resized_image is None:
+ image = block_state.image
+ self.check_inputs(
+ height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
+ )
+ height = block_state.height or components.default_height
+ width = block_state.width or components.default_width
+ else:
+ width, height = block_state.resized_image[0].size
+ image = block_state.resized_image
+
+ block_state.processed_image = components.image_processor.preprocess(
+ image=image,
+ height=height,
+ width=width,
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
+ model_name = "qwenimage-edit-plus"
+ vae_image_size = 1024 * 1024
+
+ @property
+ def description(self) -> str:
+ return "Image Preprocess step for QwenImage Edit Plus. Unlike QwenImage Edit, QwenImage Edit Plus doesn't use the same resized image for further preprocessing."
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [InputParam("vae_image"), InputParam("image"), InputParam("height"), InputParam("width")]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ if block_state.vae_image is None and block_state.image is None:
+ raise ValueError("`vae_image` and `image` cannot be None at the same time")
+
+ if block_state.vae_image is None:
+ image = block_state.image
+ self.check_inputs(
+ height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
+ )
+ height = block_state.height or components.default_height
+ width = block_state.width or components.default_width
+ block_state.processed_image = components.image_processor.preprocess(
+ image=image, height=height, width=width
+ )
+ else:
+ width, height = block_state.vae_image[0].size
+ image = block_state.vae_image
+
+ block_state.processed_image = components.image_processor.preprocess(
+ image=image, height=height, width=width
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ def __init__(
+ self,
+ input_name: str = "processed_image",
+ output_name: str = "image_latents",
+ ):
+ """Initialize a VAE encoder step for converting images to latent representations.
+
+ Both the input and output names are configurable so this block can be configured to process to different image
+ inputs (e.g., "processed_image" -> "image_latents", "processed_control_image" -> "control_image_latents").
+
+ Args:
+ input_name (str, optional): Name of the input image tensor. Defaults to "processed_image".
+ Examples: "processed_image" or "processed_control_image"
+ output_name (str, optional): Name of the output latent tensor. Defaults to "image_latents".
+ Examples: "image_latents" or "control_image_latents"
+
+ Examples:
+ # Basic usage with default settings (includes image processor) QwenImageVaeEncoderDynamicStep()
+
+ # Custom input/output names for control image QwenImageVaeEncoderDynamicStep(
+ input_name="processed_control_image", output_name="control_image_latents"
+ )
+ """
+ self._image_input_name = input_name
+ self._image_latents_output_name = output_name
+ super().__init__()
+
+ @property
+ def description(self) -> str:
+ return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ components = [
+ ComponentSpec("vae", AutoencoderKLQwenImage),
+ ]
+ return components
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ inputs = [
+ InputParam(self._image_input_name, required=True),
+ InputParam("generator"),
+ ]
+ return inputs
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ self._image_latents_output_name,
+ type_hint=torch.Tensor,
+ description="The latents representing the reference image",
+ )
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ device = components._execution_device
+ dtype = components.vae.dtype
+
+ image = getattr(block_state, self._image_input_name)
+
+ # Encode image into latents
+ image_latents = encode_vae_image(
+ image=image,
+ vae=components.vae,
+ generator=block_state.generator,
+ device=device,
+ dtype=dtype,
+ latent_channels=components.num_channels_latents,
+ )
+ setattr(block_state, self._image_latents_output_name, image_latents)
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "VAE Encoder step that converts `control_image` into latent representations control_image_latents.\n"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ components = [
+ ComponentSpec("vae", AutoencoderKLQwenImage),
+ ComponentSpec("controlnet", QwenImageControlNetModel),
+ ComponentSpec(
+ "control_image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+ return components
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ inputs = [
+ InputParam("control_image", required=True),
+ InputParam("height"),
+ InputParam("width"),
+ InputParam("generator"),
+ ]
+ return inputs
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "control_image_latents",
+ type_hint=torch.Tensor,
+ description="The latents representing the control image",
+ )
+ ]
+
+ @staticmethod
+ def check_inputs(height, width, vae_scale_factor):
+ if height is not None and height % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
+
+ if width is not None and width % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(block_state.height, block_state.width, components.vae_scale_factor)
+
+ device = components._execution_device
+ dtype = components.vae.dtype
+
+ height = block_state.height or components.default_height
+ width = block_state.width or components.default_width
+
+ controlnet = unwrap_module(components.controlnet)
+ if isinstance(controlnet, QwenImageMultiControlNetModel) and not isinstance(block_state.control_image, list):
+ block_state.control_image = [block_state.control_image]
+
+ if isinstance(controlnet, QwenImageMultiControlNetModel):
+ block_state.control_image_latents = []
+ for control_image_ in block_state.control_image:
+ control_image_ = components.control_image_processor.preprocess(
+ image=control_image_,
+ height=height,
+ width=width,
+ )
+
+ control_image_latents_ = encode_vae_image(
+ image=control_image_,
+ vae=components.vae,
+ generator=block_state.generator,
+ device=device,
+ dtype=dtype,
+ latent_channels=components.num_channels_latents,
+ sample_mode="sample",
+ )
+ block_state.control_image_latents.append(control_image_latents_)
+
+ elif isinstance(controlnet, QwenImageControlNetModel):
+ control_image = components.control_image_processor.preprocess(
+ image=block_state.control_image,
+ height=height,
+ width=width,
+ )
+ block_state.control_image_latents = encode_vae_image(
+ image=control_image,
+ vae=components.vae,
+ generator=block_state.generator,
+ device=device,
+ dtype=dtype,
+ latent_channels=components.num_channels_latents,
+ sample_mode="sample",
+ )
+
+ else:
+ raise ValueError(
+ f"Expected controlnet to be a QwenImageControlNetModel or QwenImageMultiControlNetModel, got {type(controlnet)}"
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
diff --git a/src/diffusers/modular_pipelines/qwenimage/inputs.py b/src/diffusers/modular_pipelines/qwenimage/inputs.py
new file mode 100644
index 000000000000..2b229c040b89
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/inputs.py
@@ -0,0 +1,443 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Tuple
+
+import torch
+
+from ...models import QwenImageMultiControlNetModel
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
+
+
+def repeat_tensor_to_batch_size(
+ input_name: str,
+ input_tensor: torch.Tensor,
+ batch_size: int,
+ num_images_per_prompt: int = 1,
+) -> torch.Tensor:
+ """Repeat tensor elements to match the final batch size.
+
+ This function expands a tensor's batch dimension to match the final batch size (batch_size * num_images_per_prompt)
+ by repeating each element along dimension 0.
+
+ The input tensor must have batch size 1 or batch_size. The function will:
+ - If batch size is 1: repeat each element (batch_size * num_images_per_prompt) times
+ - If batch size equals batch_size: repeat each element num_images_per_prompt times
+
+ Args:
+ input_name (str): Name of the input tensor (used for error messages)
+ input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size.
+ batch_size (int): The base batch size (number of prompts)
+ num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1.
+
+ Returns:
+ torch.Tensor: The repeated tensor with final batch size (batch_size * num_images_per_prompt)
+
+ Raises:
+ ValueError: If input_tensor is not a torch.Tensor or has invalid batch size
+
+ Examples:
+ tensor = torch.tensor([[1, 2, 3]]) # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor,
+ batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape:
+ [4, 3]
+
+ tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image",
+ tensor, batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]])
+ - shape: [4, 3]
+ """
+ # make sure input is a tensor
+ if not isinstance(input_tensor, torch.Tensor):
+ raise ValueError(f"`{input_name}` must be a tensor")
+
+ # make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts
+ if input_tensor.shape[0] == 1:
+ repeat_by = batch_size * num_images_per_prompt
+ elif input_tensor.shape[0] == batch_size:
+ repeat_by = num_images_per_prompt
+ else:
+ raise ValueError(
+ f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}"
+ )
+
+ # expand the tensor to match the batch_size * num_images_per_prompt
+ input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0)
+
+ return input_tensor
+
+
+def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: int) -> Tuple[int, int]:
+ """Calculate image dimensions from latent tensor dimensions.
+
+ This function converts latent space dimensions to image space dimensions by multiplying the latent height and width
+ by the VAE scale factor.
+
+ Args:
+ latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions.
+ Expected shapes: [batch, channels, height, width] or [batch, channels, frames, height, width]
+ vae_scale_factor (int): The scale factor used by the VAE to compress images.
+ Typically 8 for most VAEs (image is 8x larger than latents in each dimension)
+
+ Returns:
+ Tuple[int, int]: The calculated image dimensions as (height, width)
+
+ Raises:
+ ValueError: If latents tensor doesn't have 4 or 5 dimensions
+
+ """
+ # make sure the latents are not packed
+ if latents.ndim != 4 and latents.ndim != 5:
+ raise ValueError(f"unpacked latents must have 4 or 5 dimensions, but got {latents.ndim}")
+
+ latent_height, latent_width = latents.shape[-2:]
+
+ height = latent_height * vae_scale_factor
+ width = latent_width * vae_scale_factor
+
+ return height, width
+
+
+class QwenImageTextInputsStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ summary_section = (
+ "Text input processing step that standardizes text embeddings for the pipeline.\n"
+ "This step:\n"
+ " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
+ " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
+ )
+
+ # Placement guidance
+ placement_section = "\n\nThis block should be placed after all encoder steps to process the text embeddings before they are used in subsequent pipeline steps."
+
+ return summary_section + placement_section
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="num_images_per_prompt", default=1),
+ InputParam(name="prompt_embeds", required=True, kwargs_type="denoiser_input_fields"),
+ InputParam(name="prompt_embeds_mask", required=True, kwargs_type="denoiser_input_fields"),
+ InputParam(name="negative_prompt_embeds", kwargs_type="denoiser_input_fields"),
+ InputParam(name="negative_prompt_embeds_mask", kwargs_type="denoiser_input_fields"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ return [
+ OutputParam(
+ "batch_size",
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
+ ),
+ OutputParam(
+ "dtype",
+ type_hint=torch.dtype,
+ description="Data type of model tensor inputs (determined by `prompt_embeds`)",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(
+ prompt_embeds,
+ prompt_embeds_mask,
+ negative_prompt_embeds,
+ negative_prompt_embeds_mask,
+ ):
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError("`negative_prompt_embeds_mask` is required when `negative_prompt_embeds` is not None")
+
+ if negative_prompt_embeds is None and negative_prompt_embeds_mask is not None:
+ raise ValueError("cannot pass `negative_prompt_embeds_mask` without `negative_prompt_embeds`")
+
+ if prompt_embeds_mask.shape[0] != prompt_embeds.shape[0]:
+ raise ValueError("`prompt_embeds_mask` must have the same batch size as `prompt_embeds`")
+
+ elif negative_prompt_embeds is not None and negative_prompt_embeds.shape[0] != prompt_embeds.shape[0]:
+ raise ValueError("`negative_prompt_embeds` must have the same batch size as `prompt_embeds`")
+
+ elif (
+ negative_prompt_embeds_mask is not None and negative_prompt_embeds_mask.shape[0] != prompt_embeds.shape[0]
+ ):
+ raise ValueError("`negative_prompt_embeds_mask` must have the same batch size as `prompt_embeds`")
+
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(
+ prompt_embeds=block_state.prompt_embeds,
+ prompt_embeds_mask=block_state.prompt_embeds_mask,
+ negative_prompt_embeds=block_state.negative_prompt_embeds,
+ negative_prompt_embeds_mask=block_state.negative_prompt_embeds_mask,
+ )
+
+ block_state.batch_size = block_state.prompt_embeds.shape[0]
+ block_state.dtype = block_state.prompt_embeds.dtype
+
+ _, seq_len, _ = block_state.prompt_embeds.shape
+
+ block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
+ block_state.prompt_embeds = block_state.prompt_embeds.view(
+ block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
+ )
+
+ block_state.prompt_embeds_mask = block_state.prompt_embeds_mask.repeat(1, block_state.num_images_per_prompt, 1)
+ block_state.prompt_embeds_mask = block_state.prompt_embeds_mask.view(
+ block_state.batch_size * block_state.num_images_per_prompt, seq_len
+ )
+
+ if block_state.negative_prompt_embeds is not None:
+ _, seq_len, _ = block_state.negative_prompt_embeds.shape
+ block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
+ 1, block_state.num_images_per_prompt, 1
+ )
+ block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
+ block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
+ )
+
+ block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask.repeat(
+ 1, block_state.num_images_per_prompt, 1
+ )
+ block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask.view(
+ block_state.batch_size * block_state.num_images_per_prompt, seq_len
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class QwenImageInputsDynamicStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ def __init__(
+ self,
+ image_latent_inputs: List[str] = ["image_latents"],
+ additional_batch_inputs: List[str] = [],
+ ):
+ """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
+
+ This step handles multiple common tasks to prepare inputs for the denoising step:
+ 1. For encoded image latents, use it update height/width if None, patchifies, and expands batch size
+ 2. For additional_batch_inputs: Only expands batch dimensions to match final batch size
+
+ This is a dynamic block that allows you to configure which inputs to process.
+
+ Args:
+ image_latent_inputs (List[str], optional): Names of image latent tensors to process.
+ These will be used to determine height/width, patchified, and batch-expanded. Can be a single string or
+ list of strings. Defaults to ["image_latents"]. Examples: ["image_latents"], ["control_image_latents"]
+ additional_batch_inputs (List[str], optional):
+ Names of additional conditional input tensors to expand batch size. These tensors will only have their
+ batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
+ Defaults to []. Examples: ["processed_mask_image"]
+
+ Examples:
+ # Configure to process image_latents (default behavior) QwenImageInputsDynamicStep()
+
+ # Configure to process multiple image latent inputs
+ QwenImageInputsDynamicStep(image_latent_inputs=["image_latents", "control_image_latents"])
+
+ # Configure to process image latents and additional batch inputs QwenImageInputsDynamicStep(
+ image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
+ )
+ """
+ if not isinstance(image_latent_inputs, list):
+ image_latent_inputs = [image_latent_inputs]
+ if not isinstance(additional_batch_inputs, list):
+ additional_batch_inputs = [additional_batch_inputs]
+
+ self._image_latent_inputs = image_latent_inputs
+ self._additional_batch_inputs = additional_batch_inputs
+ super().__init__()
+
+ @property
+ def description(self) -> str:
+ # Functionality section
+ summary_section = (
+ "Input processing step that:\n"
+ " 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n"
+ " 2. For additional batch inputs: Expands batch dimensions to match final batch size"
+ )
+
+ # Inputs info
+ inputs_info = ""
+ if self._image_latent_inputs or self._additional_batch_inputs:
+ inputs_info = "\n\nConfigured inputs:"
+ if self._image_latent_inputs:
+ inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
+ if self._additional_batch_inputs:
+ inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
+
+ # Placement guidance
+ placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
+
+ return summary_section + inputs_info + placement_section
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ inputs = [
+ InputParam(name="num_images_per_prompt", default=1),
+ InputParam(name="batch_size", required=True),
+ InputParam(name="height"),
+ InputParam(name="width"),
+ ]
+
+ # Add image latent inputs
+ for image_latent_input_name in self._image_latent_inputs:
+ inputs.append(InputParam(name=image_latent_input_name))
+
+ # Add additional batch inputs
+ for input_name in self._additional_batch_inputs:
+ inputs.append(InputParam(name=input_name))
+
+ return inputs
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(name="image_height", type_hint=int, description="The height of the image latents"),
+ OutputParam(name="image_width", type_hint=int, description="The width of the image latents"),
+ ]
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
+ ]
+
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ # Process image latent inputs (height/width calculation, patchify, and batch expansion)
+ for image_latent_input_name in self._image_latent_inputs:
+ image_latent_tensor = getattr(block_state, image_latent_input_name)
+ if image_latent_tensor is None:
+ continue
+
+ # 1. Calculate height/width from latents
+ height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
+ block_state.height = block_state.height or height
+ block_state.width = block_state.width or width
+
+ if not hasattr(block_state, "image_height"):
+ block_state.image_height = height
+ if not hasattr(block_state, "image_width"):
+ block_state.image_width = width
+
+ # 2. Patchify the image latent tensor
+ image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor)
+
+ # 3. Expand batch size
+ image_latent_tensor = repeat_tensor_to_batch_size(
+ input_name=image_latent_input_name,
+ input_tensor=image_latent_tensor,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ setattr(block_state, image_latent_input_name, image_latent_tensor)
+
+ # Process additional batch inputs (only batch expansion)
+ for input_name in self._additional_batch_inputs:
+ input_tensor = getattr(block_state, input_name)
+ if input_tensor is None:
+ continue
+
+ # Only expand batch size
+ input_tensor = repeat_tensor_to_batch_size(
+ input_name=input_name,
+ input_tensor=input_tensor,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ setattr(block_state, input_name, input_tensor)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageControlNetInputsStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "prepare the `control_image_latents` for controlnet. Insert after all the other inputs steps."
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="control_image_latents", required=True),
+ InputParam(name="batch_size", required=True),
+ InputParam(name="num_images_per_prompt", default=1),
+ InputParam(name="height"),
+ InputParam(name="width"),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ if isinstance(components.controlnet, QwenImageMultiControlNetModel):
+ control_image_latents = []
+ # loop through each control_image_latents
+ for i, control_image_latents_ in enumerate(block_state.control_image_latents):
+ # 1. update height/width if not provided
+ height, width = calculate_dimension_from_latents(control_image_latents_, components.vae_scale_factor)
+ block_state.height = block_state.height or height
+ block_state.width = block_state.width or width
+
+ # 2. pack
+ control_image_latents_ = components.pachifier.pack_latents(control_image_latents_)
+
+ # 3. repeat to match the batch size
+ control_image_latents_ = repeat_tensor_to_batch_size(
+ input_name=f"control_image_latents[{i}]",
+ input_tensor=control_image_latents_,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ control_image_latents.append(control_image_latents_)
+
+ block_state.control_image_latents = control_image_latents
+
+ else:
+ # 1. update height/width if not provided
+ height, width = calculate_dimension_from_latents(
+ block_state.control_image_latents, components.vae_scale_factor
+ )
+ block_state.height = block_state.height or height
+ block_state.width = block_state.width or width
+
+ # 2. pack
+ block_state.control_image_latents = components.pachifier.pack_latents(block_state.control_image_latents)
+
+ # 3. repeat to match the batch size
+ block_state.control_image_latents = repeat_tensor_to_batch_size(
+ input_name="control_image_latents",
+ input_tensor=block_state.control_image_latents,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ block_state.control_image_latents = block_state.control_image_latents
+
+ self.set_block_state(state, block_state)
+
+ return components, state
diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
new file mode 100644
index 000000000000..419894164389
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
@@ -0,0 +1,1035 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...utils import logging
+from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
+from ..modular_pipeline_utils import InsertableDict
+from .before_denoise import (
+ QwenImageControlNetBeforeDenoiserStep,
+ QwenImageCreateMaskLatentsStep,
+ QwenImageEditRoPEInputsStep,
+ QwenImagePrepareLatentsStep,
+ QwenImagePrepareLatentsWithStrengthStep,
+ QwenImageRoPEInputsStep,
+ QwenImageSetTimestepsStep,
+ QwenImageSetTimestepsWithStrengthStep,
+)
+from .decoders import QwenImageDecoderStep, QwenImageInpaintProcessImagesOutputStep, QwenImageProcessImagesOutputStep
+from .denoise import (
+ QwenImageControlNetDenoiseStep,
+ QwenImageDenoiseStep,
+ QwenImageEditDenoiseStep,
+ QwenImageEditInpaintDenoiseStep,
+ QwenImageInpaintControlNetDenoiseStep,
+ QwenImageInpaintDenoiseStep,
+ QwenImageLoopBeforeDenoiserControlNet,
+)
+from .encoders import (
+ QwenImageControlNetVaeEncoderStep,
+ QwenImageEditPlusProcessImagesInputStep,
+ QwenImageEditPlusResizeDynamicStep,
+ QwenImageEditPlusTextEncoderStep,
+ QwenImageEditResizeDynamicStep,
+ QwenImageEditTextEncoderStep,
+ QwenImageInpaintProcessImagesInputStep,
+ QwenImageProcessImagesInputStep,
+ QwenImageTextEncoderStep,
+ QwenImageVaeEncoderDynamicStep,
+)
+from .inputs import QwenImageControlNetInputsStep, QwenImageInputsDynamicStep, QwenImageTextInputsStep
+
+
+logger = logging.get_logger(__name__)
+
+# 1. QwenImage
+
+## 1.1 QwenImage/text2image
+
+#### QwenImage/decode
+#### (standard decode step works for most tasks except for inpaint)
+QwenImageDecodeBlocks = InsertableDict(
+ [
+ ("decode", QwenImageDecoderStep()),
+ ("postprocess", QwenImageProcessImagesOutputStep()),
+ ]
+)
+
+
+class QwenImageDecodeStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageDecodeBlocks.values()
+ block_names = QwenImageDecodeBlocks.keys()
+
+ @property
+ def description(self):
+ return "Decode step that decodes the latents to images and postprocess the generated image."
+
+
+#### QwenImage/text2image presets
+TEXT2IMAGE_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageTextEncoderStep()),
+ ("input", QwenImageTextInputsStep()),
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsStep()),
+ ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
+ ("denoise", QwenImageDenoiseStep()),
+ ("decode", QwenImageDecodeStep()),
+ ]
+)
+
+
+## 1.2 QwenImage/inpaint
+
+#### QwenImage/inpaint vae encoder
+QwenImageInpaintVaeEncoderBlocks = InsertableDict(
+ [
+ (
+ "preprocess",
+ QwenImageInpaintProcessImagesInputStep,
+ ), # image, mask_image -> processed_image, processed_mask_image, mask_overlay_kwargs
+ ("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents
+ ]
+)
+
+
+class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageInpaintVaeEncoderBlocks.values()
+ block_names = QwenImageInpaintVaeEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+ return (
+ "This step is used for processing image and mask inputs for inpainting tasks. It:\n"
+ " - Resizes the image to the target size, based on `height` and `width`.\n"
+ " - Processes and updates `image` and `mask_image`.\n"
+ " - Creates `image_latents`."
+ )
+
+
+#### QwenImage/inpaint inputs
+QwenImageInpaintInputBlocks = InsertableDict(
+ [
+ ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
+ (
+ "additional_inputs",
+ QwenImageInputsDynamicStep(
+ image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
+ ),
+ ),
+ ]
+)
+
+
+class QwenImageInpaintInputStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageInpaintInputBlocks.values()
+ block_names = QwenImageInpaintInputBlocks.keys()
+
+ @property
+ def description(self):
+ return "Input step that prepares the inputs for the inpainting denoising step. It:\n"
+ " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents` and `processed_mask_image`).\n"
+ " - update height/width based `image_latents`, patchify `image_latents`."
+
+
+# QwenImage/inpaint prepare latents
+QwenImageInpaintPrepareLatentsBlocks = InsertableDict(
+ [
+ ("add_noise_to_latents", QwenImagePrepareLatentsWithStrengthStep()),
+ ("create_mask_latents", QwenImageCreateMaskLatentsStep()),
+ ]
+)
+
+
+class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageInpaintPrepareLatentsBlocks.values()
+ block_names = QwenImageInpaintPrepareLatentsBlocks.keys()
+
+ @property
+ def description(self) -> str:
+ return (
+ "This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:\n"
+ " - Add noise to the image latents to create the latents input for the denoiser.\n"
+ " - Create the pachified latents `mask` based on the processedmask image.\n"
+ )
+
+
+#### QwenImage/inpaint decode
+QwenImageInpaintDecodeBlocks = InsertableDict(
+ [
+ ("decode", QwenImageDecoderStep()),
+ ("postprocess", QwenImageInpaintProcessImagesOutputStep()),
+ ]
+)
+
+
+class QwenImageInpaintDecodeStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageInpaintDecodeBlocks.values()
+ block_names = QwenImageInpaintDecodeBlocks.keys()
+
+ @property
+ def description(self):
+ return "Decode step that decodes the latents to images and postprocess the generated image, optional apply the mask overally to the original image."
+
+
+#### QwenImage/inpaint presets
+INPAINT_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageTextEncoderStep()),
+ ("vae_encoder", QwenImageInpaintVaeEncoderStep()),
+ ("input", QwenImageInpaintInputStep()),
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
+ ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
+ ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
+ ("denoise", QwenImageInpaintDenoiseStep()),
+ ("decode", QwenImageInpaintDecodeStep()),
+ ]
+)
+
+
+## 1.3 QwenImage/img2img
+
+#### QwenImage/img2img vae encoder
+QwenImageImg2ImgVaeEncoderBlocks = InsertableDict(
+ [
+ ("preprocess", QwenImageProcessImagesInputStep()),
+ ("encode", QwenImageVaeEncoderDynamicStep()),
+ ]
+)
+
+
+class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+
+ block_classes = QwenImageImg2ImgVaeEncoderBlocks.values()
+ block_names = QwenImageImg2ImgVaeEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+ return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
+
+
+#### QwenImage/img2img inputs
+QwenImageImg2ImgInputBlocks = InsertableDict(
+ [
+ ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
+ ("additional_inputs", QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"])),
+ ]
+)
+
+
+class QwenImageImg2ImgInputStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageImg2ImgInputBlocks.values()
+ block_names = QwenImageImg2ImgInputBlocks.keys()
+
+ @property
+ def description(self):
+ return "Input step that prepares the inputs for the img2img denoising step. It:\n"
+ " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
+ " - update height/width based `image_latents`, patchify `image_latents`."
+
+
+#### QwenImage/img2img presets
+IMAGE2IMAGE_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageTextEncoderStep()),
+ ("vae_encoder", QwenImageImg2ImgVaeEncoderStep()),
+ ("input", QwenImageImg2ImgInputStep()),
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
+ ("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()),
+ ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
+ ("denoise", QwenImageDenoiseStep()),
+ ("decode", QwenImageDecodeStep()),
+ ]
+)
+
+
+## 1.4 QwenImage/controlnet
+
+#### QwenImage/controlnet presets
+CONTROLNET_BLOCKS = InsertableDict(
+ [
+ ("controlnet_vae_encoder", QwenImageControlNetVaeEncoderStep()), # vae encoder step for control_image
+ ("controlnet_inputs", QwenImageControlNetInputsStep()), # additional input step for controlnet
+ (
+ "controlnet_before_denoise",
+ QwenImageControlNetBeforeDenoiserStep(),
+ ), # before denoise step (after set_timesteps step)
+ (
+ "controlnet_denoise_loop_before",
+ QwenImageLoopBeforeDenoiserControlNet(),
+ ), # controlnet loop step (insert before the denoiseloop_denoiser)
+ ]
+)
+
+
+## 1.5 QwenImage/auto encoders
+
+
+#### for inpaint and img2img tasks
+class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
+ block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep]
+ block_names = ["inpaint", "img2img"]
+ block_trigger_inputs = ["mask_image", "image"]
+
+ @property
+ def description(self):
+ return (
+ "Vae encoder step that encode the image inputs into their latent representations.\n"
+ + "This is an auto pipeline block.\n"
+ + " - `QwenImageInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n"
+ + " - `QwenImageImg2ImgVaeEncoderStep` (img2img) is used when `image` is provided.\n"
+ + " - if `mask_image` or `image` is not provided, step will be skipped."
+ )
+
+
+# for controlnet tasks
+class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks):
+ block_classes = [QwenImageControlNetVaeEncoderStep]
+ block_names = ["controlnet"]
+ block_trigger_inputs = ["control_image"]
+
+ @property
+ def description(self):
+ return (
+ "Vae encoder step that encode the image inputs into their latent representations.\n"
+ + "This is an auto pipeline block.\n"
+ + " - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided.\n"
+ + " - if `control_image` is not provided, step will be skipped."
+ )
+
+
+## 1.6 QwenImage/auto inputs
+
+
+# text2image/inpaint/img2img
+class QwenImageAutoInputStep(AutoPipelineBlocks):
+ block_classes = [QwenImageInpaintInputStep, QwenImageImg2ImgInputStep, QwenImageTextInputsStep]
+ block_names = ["inpaint", "img2img", "text2image"]
+ block_trigger_inputs = ["processed_mask_image", "image_latents", None]
+
+ @property
+ def description(self):
+ return (
+ "Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n"
+ " This is an auto pipeline block that works for text2image/inpaint/img2img tasks.\n"
+ + " - `QwenImageInpaintInputStep` (inpaint) is used when `processed_mask_image` is provided.\n"
+ + " - `QwenImageImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n"
+ + " - `QwenImageTextInputsStep` (text2image) is used when both `processed_mask_image` and `image_latents` are not provided.\n"
+ )
+
+
+# controlnet
+class QwenImageOptionalControlNetInputStep(AutoPipelineBlocks):
+ block_classes = [QwenImageControlNetInputsStep]
+ block_names = ["controlnet"]
+ block_trigger_inputs = ["control_image_latents"]
+
+ @property
+ def description(self):
+ return (
+ "Controlnet input step that prepare the control_image_latents input.\n"
+ + "This is an auto pipeline block.\n"
+ + " - `QwenImageControlNetInputsStep` (controlnet) is used when `control_image_latents` is provided.\n"
+ + " - if `control_image_latents` is not provided, step will be skipped."
+ )
+
+
+## 1.7 QwenImage/auto before denoise step
+# compose the steps into a BeforeDenoiseStep for text2image/img2img/inpaint tasks before combine into an auto step
+
+# QwenImage/text2image before denoise
+QwenImageText2ImageBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsStep()),
+ ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
+ ]
+)
+
+
+class QwenImageText2ImageBeforeDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageText2ImageBeforeDenoiseBlocks.values()
+ block_names = QwenImageText2ImageBeforeDenoiseBlocks.keys()
+
+ @property
+ def description(self):
+ return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for text2image task."
+
+
+# QwenImage/inpaint before denoise
+QwenImageInpaintBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
+ ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
+ ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
+ ]
+)
+
+
+class QwenImageInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageInpaintBeforeDenoiseBlocks.values()
+ block_names = QwenImageInpaintBeforeDenoiseBlocks.keys()
+
+ @property
+ def description(self):
+ return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task."
+
+
+# QwenImage/img2img before denoise
+QwenImageImg2ImgBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
+ ("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()),
+ ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
+ ]
+)
+
+
+class QwenImageImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageImg2ImgBeforeDenoiseBlocks.values()
+ block_names = QwenImageImg2ImgBeforeDenoiseBlocks.keys()
+
+ @property
+ def description(self):
+ return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task."
+
+
+# auto before_denoise step for text2image, inpaint, img2img tasks
+class QwenImageAutoBeforeDenoiseStep(AutoPipelineBlocks):
+ block_classes = [
+ QwenImageInpaintBeforeDenoiseStep,
+ QwenImageImg2ImgBeforeDenoiseStep,
+ QwenImageText2ImageBeforeDenoiseStep,
+ ]
+ block_names = ["inpaint", "img2img", "text2image"]
+ block_trigger_inputs = ["processed_mask_image", "image_latents", None]
+
+ @property
+ def description(self):
+ return (
+ "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n"
+ + "This is an auto pipeline block that works for text2img, inpainting, img2img tasks.\n"
+ + " - `QwenImageInpaintBeforeDenoiseStep` (inpaint) is used when `processed_mask_image` is provided.\n"
+ + " - `QwenImageImg2ImgBeforeDenoiseStep` (img2img) is used when `image_latents` is provided.\n"
+ + " - `QwenImageText2ImageBeforeDenoiseStep` (text2image) is used when both `processed_mask_image` and `image_latents` are not provided.\n"
+ )
+
+
+# auto before_denoise step for controlnet tasks
+class QwenImageOptionalControlNetBeforeDenoiseStep(AutoPipelineBlocks):
+ block_classes = [QwenImageControlNetBeforeDenoiserStep]
+ block_names = ["controlnet"]
+ block_trigger_inputs = ["control_image_latents"]
+
+ @property
+ def description(self):
+ return (
+ "Controlnet before denoise step that prepare the controlnet input.\n"
+ + "This is an auto pipeline block.\n"
+ + " - `QwenImageControlNetBeforeDenoiserStep` (controlnet) is used when `control_image_latents` is provided.\n"
+ + " - if `control_image_latents` is not provided, step will be skipped."
+ )
+
+
+## 1.8 QwenImage/auto denoise
+
+
+# auto denoise step for controlnet tasks: works for all tasks with controlnet
+class QwenImageControlNetAutoDenoiseStep(AutoPipelineBlocks):
+ block_classes = [QwenImageInpaintControlNetDenoiseStep, QwenImageControlNetDenoiseStep]
+ block_names = ["inpaint_denoise", "denoise"]
+ block_trigger_inputs = ["mask", None]
+
+ @property
+ def description(self):
+ return (
+ "Controlnet step during the denoising process. \n"
+ " This is an auto pipeline block that works for inpaint and text2image/img2img tasks with controlnet.\n"
+ + " - `QwenImageInpaintControlNetDenoiseStep` (inpaint) is used when `mask` is provided.\n"
+ + " - `QwenImageControlNetDenoiseStep` (text2image/img2img) is used when `mask` is not provided.\n"
+ )
+
+
+# auto denoise step for everything: works for all tasks with or without controlnet
+class QwenImageAutoDenoiseStep(AutoPipelineBlocks):
+ block_classes = [
+ QwenImageControlNetAutoDenoiseStep,
+ QwenImageInpaintDenoiseStep,
+ QwenImageDenoiseStep,
+ ]
+ block_names = ["controlnet_denoise", "inpaint_denoise", "denoise"]
+ block_trigger_inputs = ["control_image_latents", "mask", None]
+
+ @property
+ def description(self):
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ " This is an auto pipeline block that works for inpaint/text2image/img2img tasks. It also works with controlnet\n"
+ + " - `QwenImageControlNetAutoDenoiseStep` (controlnet) is used when `control_image_latents` is provided.\n"
+ + " - `QwenImageInpaintDenoiseStep` (inpaint) is used when `mask` is provided and `control_image_latents` is not provided.\n"
+ + " - `QwenImageDenoiseStep` (text2image/img2img) is used when `mask` is not provided and `control_image_latents` is not provided.\n"
+ )
+
+
+## 1.9 QwenImage/auto decode
+# auto decode step for inpaint and text2image tasks
+
+
+class QwenImageAutoDecodeStep(AutoPipelineBlocks):
+ block_classes = [QwenImageInpaintDecodeStep, QwenImageDecodeStep]
+ block_names = ["inpaint_decode", "decode"]
+ block_trigger_inputs = ["mask", None]
+
+ @property
+ def description(self):
+ return (
+ "Decode step that decode the latents into images. \n"
+ " This is an auto pipeline block that works for inpaint/text2image/img2img tasks, for both QwenImage and QwenImage-Edit.\n"
+ + " - `QwenImageInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n"
+ + " - `QwenImageDecodeStep` (text2image/img2img) is used when `mask` is not provided.\n"
+ )
+
+
+class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = [
+ QwenImageAutoInputStep,
+ QwenImageOptionalControlNetInputStep,
+ QwenImageAutoBeforeDenoiseStep,
+ QwenImageOptionalControlNetBeforeDenoiseStep,
+ QwenImageAutoDenoiseStep,
+ ]
+ block_names = ["input", "controlnet_input", "before_denoise", "controlnet_before_denoise", "denoise"]
+
+ @property
+ def description(self):
+ return (
+ "Core step that performs the denoising process. \n"
+ + " - `QwenImageAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ + " - `QwenImageOptionalControlNetInputStep` (controlnet_input) prepares the controlnet input.\n"
+ + " - `QwenImageAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ + " - `QwenImageOptionalControlNetBeforeDenoiseStep` (controlnet_before_denoise) prepares the controlnet input for the denoising step.\n"
+ + " - `QwenImageAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ + "This step support text-to-image, image-to-image, inpainting, and controlnet tasks for QwenImage:\n"
+ + " - for image-to-image generation, you need to provide `image_latents`\n"
+ + " - for inpainting, you need to provide `processed_mask_image` and `image_latents`\n"
+ + " - to run the controlnet workflow, you need to provide `control_image_latents`\n"
+ + " - for text-to-image generation, all you need to provide is prompt embeddings"
+ )
+
+
+## 1.10 QwenImage/auto block & presets
+AUTO_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageTextEncoderStep()),
+ ("vae_encoder", QwenImageAutoVaeEncoderStep()),
+ ("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()),
+ ("denoise", QwenImageCoreDenoiseStep()),
+ ("decode", QwenImageAutoDecodeStep()),
+ ]
+)
+
+
+class QwenImageAutoBlocks(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+
+ block_classes = AUTO_BLOCKS.values()
+ block_names = AUTO_BLOCKS.keys()
+
+ @property
+ def description(self):
+ return (
+ "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n"
+ + "- for image-to-image generation, you need to provide `image`\n"
+ + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
+ + "- to run the controlnet workflow, you need to provide `control_image`\n"
+ + "- for text-to-image generation, all you need to provide is `prompt`"
+ )
+
+
+# 2. QwenImage-Edit
+
+## 2.1 QwenImage-Edit/edit
+
+#### QwenImage-Edit/edit vl encoder: take both image and text prompts
+QwenImageEditVLEncoderBlocks = InsertableDict(
+ [
+ ("resize", QwenImageEditResizeDynamicStep()),
+ ("encode", QwenImageEditTextEncoderStep()),
+ ]
+)
+
+
+class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditVLEncoderBlocks.values()
+ block_names = QwenImageEditVLEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+ return "QwenImage-Edit VL encoder step that encode the image an text prompts together."
+
+
+#### QwenImage-Edit/edit vae encoder
+QwenImageEditVaeEncoderBlocks = InsertableDict(
+ [
+ ("resize", QwenImageEditResizeDynamicStep()), # edit has a different resize step
+ ("preprocess", QwenImageProcessImagesInputStep()), # resized_image -> processed_image
+ ("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents
+ ]
+)
+
+
+class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditVaeEncoderBlocks.values()
+ block_names = QwenImageEditVaeEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+ return "Vae encoder step that encode the image inputs into their latent representations."
+
+
+#### QwenImage-Edit/edit input
+QwenImageEditInputBlocks = InsertableDict(
+ [
+ ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
+ ("additional_inputs", QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"])),
+ ]
+)
+
+
+class QwenImageEditInputStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditInputBlocks.values()
+ block_names = QwenImageEditInputBlocks.keys()
+
+ @property
+ def description(self):
+ return "Input step that prepares the inputs for the edit denoising step. It:\n"
+ " - make sure the text embeddings have consistent batch size as well as the additional inputs: \n"
+ " - `image_latents`.\n"
+ " - update height/width based `image_latents`, patchify `image_latents`."
+
+
+#### QwenImage/edit presets
+EDIT_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageEditVLEncoderStep()),
+ ("vae_encoder", QwenImageEditVaeEncoderStep()),
+ ("input", QwenImageEditInputStep()),
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsStep()),
+ ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
+ ("denoise", QwenImageEditDenoiseStep()),
+ ("decode", QwenImageDecodeStep()),
+ ]
+)
+
+
+## 2.2 QwenImage-Edit/edit inpaint
+
+#### QwenImage-Edit/edit inpaint vae encoder: the difference from regular inpaint is the resize step
+QwenImageEditInpaintVaeEncoderBlocks = InsertableDict(
+ [
+ ("resize", QwenImageEditResizeDynamicStep()), # image -> resized_image
+ (
+ "preprocess",
+ QwenImageInpaintProcessImagesInputStep,
+ ), # resized_image, mask_image -> processed_image, processed_mask_image, mask_overlay_kwargs
+ (
+ "encode",
+ QwenImageVaeEncoderDynamicStep(input_name="processed_image", output_name="image_latents"),
+ ), # processed_image -> image_latents
+ ]
+)
+
+
+class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditInpaintVaeEncoderBlocks.values()
+ block_names = QwenImageEditInpaintVaeEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+ return (
+ "This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It:\n"
+ " - resize the image for target area (1024 * 1024) while maintaining the aspect ratio.\n"
+ " - process the resized image and mask image.\n"
+ " - create image latents."
+ )
+
+
+#### QwenImage-Edit/edit inpaint presets
+EDIT_INPAINT_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageEditVLEncoderStep()),
+ ("vae_encoder", QwenImageEditInpaintVaeEncoderStep()),
+ ("input", QwenImageInpaintInputStep()),
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
+ ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
+ ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
+ ("denoise", QwenImageEditInpaintDenoiseStep()),
+ ("decode", QwenImageInpaintDecodeStep()),
+ ]
+)
+
+
+## 2.3 QwenImage-Edit/auto encoders
+
+
+class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks):
+ block_classes = [
+ QwenImageEditInpaintVaeEncoderStep,
+ QwenImageEditVaeEncoderStep,
+ ]
+ block_names = ["edit_inpaint", "edit"]
+ block_trigger_inputs = ["mask_image", "image"]
+
+ @property
+ def description(self):
+ return (
+ "Vae encoder step that encode the image inputs into their latent representations. \n"
+ " This is an auto pipeline block that works for edit and edit_inpaint tasks.\n"
+ + " - `QwenImageEditInpaintVaeEncoderStep` (edit_inpaint) is used when `mask_image` is provided.\n"
+ + " - `QwenImageEditVaeEncoderStep` (edit) is used when `image` is provided.\n"
+ + " - if `mask_image` or `image` is not provided, step will be skipped."
+ )
+
+
+## 2.4 QwenImage-Edit/auto inputs
+class QwenImageEditAutoInputStep(AutoPipelineBlocks):
+ block_classes = [QwenImageInpaintInputStep, QwenImageEditInputStep]
+ block_names = ["edit_inpaint", "edit"]
+ block_trigger_inputs = ["processed_mask_image", "image_latents"]
+
+ @property
+ def description(self):
+ return (
+ "Input step that prepares the inputs for the edit denoising step.\n"
+ + " It is an auto pipeline block that works for edit and edit_inpaint tasks.\n"
+ + " - `QwenImageInpaintInputStep` (edit_inpaint) is used when `processed_mask_image` is provided.\n"
+ + " - `QwenImageEditInputStep` (edit) is used when `image_latents` is provided.\n"
+ + " - if `processed_mask_image` or `image_latents` is not provided, step will be skipped."
+ )
+
+
+## 2.5 QwenImage-Edit/auto before denoise
+# compose the steps into a BeforeDenoiseStep for edit and edit_inpaint tasks before combine into an auto step
+
+#### QwenImage-Edit/edit before denoise
+QwenImageEditBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsStep()),
+ ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
+ ]
+)
+
+
+class QwenImageEditBeforeDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditBeforeDenoiseBlocks.values()
+ block_names = QwenImageEditBeforeDenoiseBlocks.keys()
+
+ @property
+ def description(self):
+ return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit task."
+
+
+#### QwenImage-Edit/edit inpaint before denoise
+QwenImageEditInpaintBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
+ ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
+ ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
+ ]
+)
+
+
+class QwenImageEditInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditInpaintBeforeDenoiseBlocks.values()
+ block_names = QwenImageEditInpaintBeforeDenoiseBlocks.keys()
+
+ @property
+ def description(self):
+ return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit inpaint task."
+
+
+# auto before_denoise step for edit and edit_inpaint tasks
+class QwenImageEditAutoBeforeDenoiseStep(AutoPipelineBlocks):
+ model_name = "qwenimage-edit"
+ block_classes = [
+ QwenImageEditInpaintBeforeDenoiseStep,
+ QwenImageEditBeforeDenoiseStep,
+ ]
+ block_names = ["edit_inpaint", "edit"]
+ block_trigger_inputs = ["processed_mask_image", "image_latents"]
+
+ @property
+ def description(self):
+ return (
+ "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n"
+ + "This is an auto pipeline block that works for edit (img2img) and edit inpaint tasks.\n"
+ + " - `QwenImageEditInpaintBeforeDenoiseStep` (edit_inpaint) is used when `processed_mask_image` is provided.\n"
+ + " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n"
+ + " - if `image_latents` or `processed_mask_image` is not provided, step will be skipped."
+ )
+
+
+## 2.6 QwenImage-Edit/auto denoise
+
+
+class QwenImageEditAutoDenoiseStep(AutoPipelineBlocks):
+ model_name = "qwenimage-edit"
+
+ block_classes = [QwenImageEditInpaintDenoiseStep, QwenImageEditDenoiseStep]
+ block_names = ["inpaint_denoise", "denoise"]
+ block_trigger_inputs = ["processed_mask_image", "image_latents"]
+
+ @property
+ def description(self):
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ + "This block supports edit (img2img) and edit inpaint tasks for QwenImage Edit. \n"
+ + " - `QwenImageEditInpaintDenoiseStep` (inpaint) is used when `processed_mask_image` is provided.\n"
+ + " - `QwenImageEditDenoiseStep` (img2img) is used when `image_latents` is provided.\n"
+ + " - if `processed_mask_image` or `image_latents` is not provided, step will be skipped."
+ )
+
+
+## 2.7 QwenImage-Edit/auto blocks & presets
+
+
+class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit"
+ block_classes = [
+ QwenImageEditAutoInputStep,
+ QwenImageEditAutoBeforeDenoiseStep,
+ QwenImageEditAutoDenoiseStep,
+ ]
+ block_names = ["input", "before_denoise", "denoise"]
+
+ @property
+ def description(self):
+ return (
+ "Core step that performs the denoising process. \n"
+ + " - `QwenImageEditAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ + " - `QwenImageEditAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ + " - `QwenImageEditAutoDenoiseStep` (denoise) iteratively denoises the latents.\n\n"
+ + "This step support edit (img2img) and edit inpainting workflow for QwenImage Edit:\n"
+ + " - When `processed_mask_image` is provided, it will be used for edit inpainting task.\n"
+ + " - When `image_latents` is provided, it will be used for edit (img2img) task.\n"
+ )
+
+
+EDIT_AUTO_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageEditVLEncoderStep()),
+ ("vae_encoder", QwenImageEditAutoVaeEncoderStep()),
+ ("denoise", QwenImageEditCoreDenoiseStep()),
+ ("decode", QwenImageAutoDecodeStep()),
+ ]
+)
+
+
+class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit"
+ block_classes = EDIT_AUTO_BLOCKS.values()
+ block_names = EDIT_AUTO_BLOCKS.keys()
+
+ @property
+ def description(self):
+ return (
+ "Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.\n"
+ + "- for edit (img2img) generation, you need to provide `image`\n"
+ + "- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
+ )
+
+
+#################### QwenImage Edit Plus #####################
+
+# 3. QwenImage-Edit Plus
+
+## 3.1 QwenImage-Edit Plus / edit
+
+#### QwenImage-Edit Plus vl encoder: take both image and text prompts
+QwenImageEditPlusVLEncoderBlocks = InsertableDict(
+ [
+ ("resize", QwenImageEditPlusResizeDynamicStep()),
+ ("encode", QwenImageEditPlusTextEncoderStep()),
+ ]
+)
+
+
+class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditPlusVLEncoderBlocks.values()
+ block_names = QwenImageEditPlusVLEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+ return "QwenImage-Edit Plus VL encoder step that encode the image an text prompts together."
+
+
+#### QwenImage-Edit Plus vae encoder
+QwenImageEditPlusVaeEncoderBlocks = InsertableDict(
+ [
+ ("resize", QwenImageEditPlusResizeDynamicStep()), # edit plus has a different resize step
+ ("preprocess", QwenImageEditPlusProcessImagesInputStep()), # vae_image -> processed_image
+ ("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents
+ ]
+)
+
+
+class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditPlusVaeEncoderBlocks.values()
+ block_names = QwenImageEditPlusVaeEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+ return "Vae encoder step that encode the image inputs into their latent representations."
+
+
+#### QwenImage Edit Plus presets
+EDIT_PLUS_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageEditPlusVLEncoderStep()),
+ ("vae_encoder", QwenImageEditPlusVaeEncoderStep()),
+ ("input", QwenImageEditInputStep()),
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsStep()),
+ ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
+ ("denoise", QwenImageEditDenoiseStep()),
+ ("decode", QwenImageDecodeStep()),
+ ]
+)
+
+
+# auto before_denoise step for edit tasks
+class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks):
+ model_name = "qwenimage-edit-plus"
+ block_classes = [QwenImageEditBeforeDenoiseStep]
+ block_names = ["edit"]
+ block_trigger_inputs = ["image_latents"]
+
+ @property
+ def description(self):
+ return (
+ "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n"
+ + "This is an auto pipeline block that works for edit (img2img) task.\n"
+ + " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n"
+ + " - if `image_latents` is not provided, step will be skipped."
+ )
+
+
+## 3.2 QwenImage-Edit Plus/auto encoders
+
+
+class QwenImageEditPlusAutoVaeEncoderStep(AutoPipelineBlocks):
+ block_classes = [
+ QwenImageEditPlusVaeEncoderStep,
+ ]
+ block_names = ["edit"]
+ block_trigger_inputs = ["image"]
+
+ @property
+ def description(self):
+ return (
+ "Vae encoder step that encode the image inputs into their latent representations. \n"
+ " This is an auto pipeline block that works for edit task.\n"
+ + " - `QwenImageEditPlusVaeEncoderStep` (edit) is used when `image` is provided.\n"
+ + " - if `image` is not provided, step will be skipped."
+ )
+
+
+## 3.3 QwenImage-Edit/auto blocks & presets
+
+
+class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit-plus"
+ block_classes = [
+ QwenImageEditAutoInputStep,
+ QwenImageEditPlusAutoBeforeDenoiseStep,
+ QwenImageEditAutoDenoiseStep,
+ ]
+ block_names = ["input", "before_denoise", "denoise"]
+
+ @property
+ def description(self):
+ return (
+ "Core step that performs the denoising process. \n"
+ + " - `QwenImageEditAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ + " - `QwenImageEditPlusAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ + " - `QwenImageEditAutoDenoiseStep` (denoise) iteratively denoises the latents.\n\n"
+ + "This step support edit (img2img) workflow for QwenImage Edit Plus:\n"
+ + " - When `image_latents` is provided, it will be used for edit (img2img) task.\n"
+ )
+
+
+EDIT_PLUS_AUTO_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageEditPlusVLEncoderStep()),
+ ("vae_encoder", QwenImageEditPlusAutoVaeEncoderStep()),
+ ("denoise", QwenImageEditPlusCoreDenoiseStep()),
+ ("decode", QwenImageAutoDecodeStep()),
+ ]
+)
+
+
+class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit-plus"
+ block_classes = EDIT_PLUS_AUTO_BLOCKS.values()
+ block_names = EDIT_PLUS_AUTO_BLOCKS.keys()
+
+ @property
+ def description(self):
+ return (
+ "Auto Modular pipeline for edit (img2img) and edit tasks using QwenImage-Edit Plus.\n"
+ + "- for edit (img2img) generation, you need to provide `image`\n"
+ )
+
+
+# 3. all block presets supported in QwenImage, QwenImage-Edit, QwenImage-Edit Plus
+
+
+ALL_BLOCKS = {
+ "text2image": TEXT2IMAGE_BLOCKS,
+ "img2img": IMAGE2IMAGE_BLOCKS,
+ "edit": EDIT_BLOCKS,
+ "edit_inpaint": EDIT_INPAINT_BLOCKS,
+ "edit_plus": EDIT_PLUS_BLOCKS,
+ "inpaint": INPAINT_BLOCKS,
+ "controlnet": CONTROLNET_BLOCKS,
+ "auto": AUTO_BLOCKS,
+ "edit_auto": EDIT_AUTO_BLOCKS,
+ "edit_plus_auto": EDIT_PLUS_AUTO_BLOCKS,
+}
diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py b/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py
new file mode 100644
index 000000000000..59e1a13a5db2
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py
@@ -0,0 +1,205 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import QwenImageLoraLoaderMixin
+from ..modular_pipeline import ModularPipeline
+
+
+class QwenImagePachifier(ConfigMixin):
+ """
+ A class to pack and unpack latents for QwenImage.
+ """
+
+ config_name = "config.json"
+
+ @register_to_config
+ def __init__(self, patch_size: int = 2):
+ super().__init__()
+
+ def pack_latents(self, latents):
+ if latents.ndim != 4 and latents.ndim != 5:
+ raise ValueError(f"Latents must have 4 or 5 dimensions, but got {latents.ndim}")
+
+ if latents.ndim == 4:
+ latents = latents.unsqueeze(2)
+
+ batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width = latents.shape
+ patch_size = self.config.patch_size
+
+ if latent_height % patch_size != 0 or latent_width % patch_size != 0:
+ raise ValueError(
+ f"Latent height and width must be divisible by {patch_size}, but got {latent_height} and {latent_width}"
+ )
+
+ latents = latents.view(
+ batch_size,
+ num_channels_latents,
+ latent_height // patch_size,
+ patch_size,
+ latent_width // patch_size,
+ patch_size,
+ )
+ latents = latents.permute(
+ 0, 2, 4, 1, 3, 5
+ ) # Batch_size, num_patches_height, num_patches_width, num_channels_latents, patch_size, patch_size
+ latents = latents.reshape(
+ batch_size,
+ (latent_height // patch_size) * (latent_width // patch_size),
+ num_channels_latents * patch_size * patch_size,
+ )
+
+ return latents
+
+ def unpack_latents(self, latents, height, width, vae_scale_factor=8):
+ if latents.ndim != 3:
+ raise ValueError(f"Latents must have 3 dimensions, but got {latents.ndim}")
+
+ batch_size, num_patches, channels = latents.shape
+ patch_size = self.config.patch_size
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = patch_size * (int(height) // (vae_scale_factor * patch_size))
+ width = patch_size * (int(width) // (vae_scale_factor * patch_size))
+
+ latents = latents.view(
+ batch_size,
+ height // patch_size,
+ width // patch_size,
+ channels // (patch_size * patch_size),
+ patch_size,
+ patch_size,
+ )
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (patch_size * patch_size), 1, height, width)
+
+ return latents
+
+
+class QwenImageModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
+ """
+ A ModularPipeline for QwenImage.
+
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
+ """
+
+ default_blocks_name = "QwenImageAutoBlocks"
+
+ @property
+ def default_height(self):
+ return self.default_sample_size * self.vae_scale_factor
+
+ @property
+ def default_width(self):
+ return self.default_sample_size * self.vae_scale_factor
+
+ @property
+ def default_sample_size(self):
+ return 128
+
+ @property
+ def vae_scale_factor(self):
+ vae_scale_factor = 8
+ if hasattr(self, "vae") and self.vae is not None:
+ vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
+ return vae_scale_factor
+
+ @property
+ def num_channels_latents(self):
+ num_channels_latents = 16
+ if hasattr(self, "transformer") and self.transformer is not None:
+ num_channels_latents = self.transformer.config.in_channels // 4
+ return num_channels_latents
+
+ @property
+ def is_guidance_distilled(self):
+ is_guidance_distilled = False
+ if hasattr(self, "transformer") and self.transformer is not None:
+ is_guidance_distilled = self.transformer.config.guidance_embeds
+ return is_guidance_distilled
+
+ @property
+ def requires_unconditional_embeds(self):
+ requires_unconditional_embeds = False
+
+ if hasattr(self, "guider") and self.guider is not None:
+ requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
+
+ return requires_unconditional_embeds
+
+
+class QwenImageEditModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
+ """
+ A ModularPipeline for QwenImage-Edit.
+
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
+ """
+
+ default_blocks_name = "QwenImageEditAutoBlocks"
+
+ # YiYi TODO: qwen edit should not provide default height/width, should be derived from the resized input image (after adjustment) produced by the resize step.
+ @property
+ def default_height(self):
+ return self.default_sample_size * self.vae_scale_factor
+
+ @property
+ def default_width(self):
+ return self.default_sample_size * self.vae_scale_factor
+
+ @property
+ def default_sample_size(self):
+ return 128
+
+ @property
+ def vae_scale_factor(self):
+ vae_scale_factor = 8
+ if hasattr(self, "vae") and self.vae is not None:
+ vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
+ return vae_scale_factor
+
+ @property
+ def num_channels_latents(self):
+ num_channels_latents = 16
+ if hasattr(self, "transformer") and self.transformer is not None:
+ num_channels_latents = self.transformer.config.in_channels // 4
+ return num_channels_latents
+
+ @property
+ def is_guidance_distilled(self):
+ is_guidance_distilled = False
+ if hasattr(self, "transformer") and self.transformer is not None:
+ is_guidance_distilled = self.transformer.config.guidance_embeds
+ return is_guidance_distilled
+
+ @property
+ def requires_unconditional_embeds(self):
+ requires_unconditional_embeds = False
+
+ if hasattr(self, "guider") and self.guider is not None:
+ requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
+
+ return requires_unconditional_embeds
+
+
+class QwenImageEditPlusModularPipeline(QwenImageEditModularPipeline):
+ """
+ A ModularPipeline for QwenImage-Edit Plus.
+
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
+ """
+
+ default_blocks_name = "QwenImageEditPlusAutoBlocks"
diff --git a/src/diffusers/modular_pipelines/qwenimage/node_utils.py b/src/diffusers/modular_pipelines/qwenimage/node_utils.py
new file mode 100644
index 000000000000..3230ece68abc
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/node_utils.py
@@ -0,0 +1,95 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# mellon nodes
+QwenImage_NODE_TYPES_PARAMS_MAP = {
+ "controlnet": {
+ "inputs": [
+ "control_image",
+ "controlnet_conditioning_scale",
+ "control_guidance_start",
+ "control_guidance_end",
+ "height",
+ "width",
+ ],
+ "model_inputs": [
+ "controlnet",
+ "vae",
+ ],
+ "outputs": [
+ "controlnet_out",
+ ],
+ "block_names": ["controlnet_vae_encoder"],
+ },
+ "denoise": {
+ "inputs": [
+ "embeddings",
+ "width",
+ "height",
+ "seed",
+ "num_inference_steps",
+ "guidance_scale",
+ "image_latents",
+ "strength",
+ "controlnet",
+ ],
+ "model_inputs": [
+ "unet",
+ "guider",
+ "scheduler",
+ ],
+ "outputs": [
+ "latents",
+ "latents_preview",
+ ],
+ "block_names": ["denoise"],
+ },
+ "vae_encoder": {
+ "inputs": [
+ "image",
+ "width",
+ "height",
+ ],
+ "model_inputs": [
+ "vae",
+ ],
+ "outputs": [
+ "image_latents",
+ ],
+ },
+ "text_encoder": {
+ "inputs": [
+ "prompt",
+ "negative_prompt",
+ ],
+ "model_inputs": [
+ "text_encoders",
+ ],
+ "outputs": [
+ "embeddings",
+ ],
+ },
+ "decoder": {
+ "inputs": [
+ "latents",
+ ],
+ "model_inputs": [
+ "vae",
+ ],
+ "outputs": [
+ "images",
+ ],
+ },
+}
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py
new file mode 100644
index 000000000000..59ec46dc6d36
--- /dev/null
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py
@@ -0,0 +1,77 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["encoders"] = ["StableDiffusionXLTextEncoderStep"]
+ _import_structure["modular_blocks"] = [
+ "ALL_BLOCKS",
+ "AUTO_BLOCKS",
+ "CONTROLNET_BLOCKS",
+ "IMAGE2IMAGE_BLOCKS",
+ "INPAINT_BLOCKS",
+ "IP_ADAPTER_BLOCKS",
+ "TEXT2IMAGE_BLOCKS",
+ "StableDiffusionXLAutoBlocks",
+ "StableDiffusionXLAutoControlnetStep",
+ "StableDiffusionXLAutoDecodeStep",
+ "StableDiffusionXLAutoIPAdapterStep",
+ "StableDiffusionXLAutoVaeEncoderStep",
+ ]
+ _import_structure["modular_pipeline"] = ["StableDiffusionXLModularPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ else:
+ from .encoders import (
+ StableDiffusionXLTextEncoderStep,
+ )
+ from .modular_blocks import (
+ ALL_BLOCKS,
+ AUTO_BLOCKS,
+ CONTROLNET_BLOCKS,
+ IMAGE2IMAGE_BLOCKS,
+ INPAINT_BLOCKS,
+ IP_ADAPTER_BLOCKS,
+ TEXT2IMAGE_BLOCKS,
+ StableDiffusionXLAutoBlocks,
+ StableDiffusionXLAutoControlnetStep,
+ StableDiffusionXLAutoDecodeStep,
+ StableDiffusionXLAutoIPAdapterStep,
+ StableDiffusionXLAutoVaeEncoderStep,
+ )
+ from .modular_pipeline import StableDiffusionXLModularPipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
new file mode 100644
index 000000000000..70cbf0c1c78d
--- /dev/null
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
@@ -0,0 +1,1874 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, List, Optional, Tuple, Union
+
+import PIL
+import torch
+
+from ...configuration_utils import FrozenDict
+from ...guiders import ClassifierFreeGuidance
+from ...image_processor import VaeImageProcessor
+from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, UNet2DConditionModel
+from ...models.controlnets.multicontrolnet import MultiControlNetModel
+from ...schedulers import EulerDiscreteScheduler
+from ...utils import logging
+from ...utils.torch_utils import randn_tensor, unwrap_module
+from ..modular_pipeline import (
+ ModularPipelineBlocks,
+ PipelineState,
+)
+from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
+from .modular_pipeline import StableDiffusionXLModularPipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that
+# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by
+# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the
+# configuration of guider is.
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+def prepare_latents_img2img(
+ vae, scheduler, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
+):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}")
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
+ latents_mean = latents_std = None
+ if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
+ latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
+ if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
+ latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if vae.config.force_upcast:
+ image = image.float()
+ vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ elif isinstance(generator, list):
+ if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+ image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+ elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+ )
+
+ init_latents = [
+ retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = retrieve_latents(vae.encode(image), generator=generator)
+
+ if vae.config.force_upcast:
+ vae.to(dtype)
+
+ init_latents = init_latents.to(dtype)
+ if latents_mean is not None and latents_std is not None:
+ latents_mean = latents_mean.to(device=device, dtype=dtype)
+ latents_std = latents_std.to(device=device, dtype=dtype)
+ init_latents = (init_latents - latents_mean) * vae.config.scaling_factor / latents_std
+ else:
+ init_latents = vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ if add_noise:
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # get latents
+ init_latents = scheduler.add_noise(init_latents, noise, timestep)
+
+ latents = init_latents
+
+ return latents
+
+
+class StableDiffusionXLInputStep(ModularPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Input processing step that:\n"
+ " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
+ " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n"
+ "All input tensors are expected to have either batch_size=1 or match the batch_size\n"
+ "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n"
+ "have a final batch_size of batch_size * num_images_per_prompt."
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("num_images_per_prompt", default=1),
+ InputParam(
+ "prompt_embeds",
+ required=True,
+ type_hint=torch.Tensor,
+ description="Pre-generated text embeddings. Can be generated from text_encoder step.",
+ ),
+ InputParam(
+ "negative_prompt_embeds",
+ type_hint=torch.Tensor,
+ description="Pre-generated negative text embeddings. Can be generated from text_encoder step.",
+ ),
+ InputParam(
+ "pooled_prompt_embeds",
+ required=True,
+ type_hint=torch.Tensor,
+ description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.",
+ ),
+ InputParam(
+ "negative_pooled_prompt_embeds",
+ description="Pre-generated negative pooled text embeddings. Can be generated from text_encoder step.",
+ ),
+ InputParam(
+ "ip_adapter_embeds",
+ type_hint=List[torch.Tensor],
+ description="Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step.",
+ ),
+ InputParam(
+ "negative_ip_adapter_embeds",
+ type_hint=List[torch.Tensor],
+ description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ return [
+ OutputParam(
+ "batch_size",
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
+ ),
+ OutputParam(
+ "dtype",
+ type_hint=torch.dtype,
+ description="Data type of model tensor inputs (determined by `prompt_embeds`)",
+ ),
+ OutputParam(
+ "prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
+ description="text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "negative_prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
+ description="negative text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "pooled_prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
+ description="pooled text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "negative_pooled_prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
+ description="negative pooled text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "ip_adapter_embeds",
+ type_hint=List[torch.Tensor],
+ kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
+ description="image embeddings for IP-Adapter",
+ ),
+ OutputParam(
+ "negative_ip_adapter_embeds",
+ type_hint=List[torch.Tensor],
+ kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
+ description="negative image embeddings for IP-Adapter",
+ ),
+ ]
+
+ def check_inputs(self, components, block_state):
+ if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None:
+ if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {block_state.negative_prompt_embeds.shape}."
+ )
+
+ if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if block_state.negative_prompt_embeds is not None and block_state.negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if block_state.ip_adapter_embeds is not None and not isinstance(block_state.ip_adapter_embeds, list):
+ raise ValueError("`ip_adapter_embeds` must be a list")
+
+ if block_state.negative_ip_adapter_embeds is not None and not isinstance(
+ block_state.negative_ip_adapter_embeds, list
+ ):
+ raise ValueError("`negative_ip_adapter_embeds` must be a list")
+
+ if block_state.ip_adapter_embeds is not None and block_state.negative_ip_adapter_embeds is not None:
+ for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds):
+ if ip_adapter_embed.shape != block_state.negative_ip_adapter_embeds[i].shape:
+ raise ValueError(
+ "`ip_adapter_embeds` and `negative_ip_adapter_embeds` must have the same shape when passed directly, but"
+ f" got: `ip_adapter_embeds` {ip_adapter_embed.shape} != `negative_ip_adapter_embeds`"
+ f" {block_state.negative_ip_adapter_embeds[i].shape}."
+ )
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ self.check_inputs(components, block_state)
+
+ block_state.batch_size = block_state.prompt_embeds.shape[0]
+ block_state.dtype = block_state.prompt_embeds.dtype
+
+ _, seq_len, _ = block_state.prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
+ block_state.prompt_embeds = block_state.prompt_embeds.view(
+ block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
+ )
+
+ if block_state.negative_prompt_embeds is not None:
+ _, seq_len, _ = block_state.negative_prompt_embeds.shape
+ block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
+ 1, block_state.num_images_per_prompt, 1
+ )
+ block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
+ block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
+ )
+
+ block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(
+ 1, block_state.num_images_per_prompt, 1
+ )
+ block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.view(
+ block_state.batch_size * block_state.num_images_per_prompt, -1
+ )
+
+ if block_state.negative_pooled_prompt_embeds is not None:
+ block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat(
+ 1, block_state.num_images_per_prompt, 1
+ )
+ block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.view(
+ block_state.batch_size * block_state.num_images_per_prompt, -1
+ )
+
+ if block_state.ip_adapter_embeds is not None:
+ for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds):
+ block_state.ip_adapter_embeds[i] = torch.cat(
+ [ip_adapter_embed] * block_state.num_images_per_prompt, dim=0
+ )
+
+ if block_state.negative_ip_adapter_embeds is not None:
+ for i, negative_ip_adapter_embed in enumerate(block_state.negative_ip_adapter_embeds):
+ block_state.negative_ip_adapter_embeds[i] = torch.cat(
+ [negative_ip_adapter_embed] * block_state.num_images_per_prompt, dim=0
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class StableDiffusionXLImg2ImgSetTimestepsStep(ModularPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Step that sets the timesteps for the scheduler and determines the initial noise level (latent_timestep) for image-to-image/inpainting generation.\n"
+ + "The latent_timestep is calculated from the `strength` parameter - higher strength means starting from a noisier version of the input image."
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("num_inference_steps", default=50),
+ InputParam("timesteps"),
+ InputParam("sigmas"),
+ InputParam("denoising_end"),
+ InputParam("strength", default=0.3),
+ InputParam("denoising_start"),
+ # YiYi TODO: do we need num_images_per_prompt here?
+ InputParam("num_images_per_prompt", default=1),
+ InputParam(
+ "batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ return [
+ OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
+ OutputParam(
+ "num_inference_steps",
+ type_hint=int,
+ description="The number of denoising steps to perform at inference time",
+ ),
+ OutputParam(
+ "latent_timestep",
+ type_hint=torch.Tensor,
+ description="The timestep that represents the initial noise level for image-to-image generation",
+ ),
+ ]
+
+ @staticmethod
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps with self->components
+ def get_timesteps(components, num_inference_steps, strength, device, denoising_start=None):
+ # get the original timestep using init_timestep
+ if denoising_start is None:
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+ t_start = max(num_inference_steps - init_timestep, 0)
+
+ timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :]
+ if hasattr(components.scheduler, "set_begin_index"):
+ components.scheduler.set_begin_index(t_start * components.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ else:
+ # Strength is irrelevant if we directly request a timestep to start at;
+ # that is, strength is determined by the denoising_start instead.
+ discrete_timestep_cutoff = int(
+ round(
+ components.scheduler.config.num_train_timesteps
+ - (denoising_start * components.scheduler.config.num_train_timesteps)
+ )
+ )
+
+ num_inference_steps = (components.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
+ if components.scheduler.order == 2 and num_inference_steps % 2 == 0:
+ # if the scheduler is a 2nd order scheduler we might have to do +1
+ # because `num_inference_steps` might be even given that every timestep
+ # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+ # mean that we cut the timesteps in the middle of the denoising step
+ # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
+ # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
+ num_inference_steps = num_inference_steps + 1
+
+ # because t_n+1 >= t_n, we slice the timesteps starting from the end
+ t_start = len(components.scheduler.timesteps) - num_inference_steps
+ timesteps = components.scheduler.timesteps[t_start:]
+ if hasattr(components.scheduler, "set_begin_index"):
+ components.scheduler.set_begin_index(t_start)
+ return timesteps, num_inference_steps
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.device = components._execution_device
+
+ block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
+ components.scheduler,
+ block_state.num_inference_steps,
+ block_state.device,
+ block_state.timesteps,
+ block_state.sigmas,
+ )
+
+ def denoising_value_valid(dnv):
+ return isinstance(dnv, float) and 0 < dnv < 1
+
+ block_state.timesteps, block_state.num_inference_steps = self.get_timesteps(
+ components,
+ block_state.num_inference_steps,
+ block_state.strength,
+ block_state.device,
+ denoising_start=block_state.denoising_start
+ if denoising_value_valid(block_state.denoising_start)
+ else None,
+ )
+ block_state.latent_timestep = block_state.timesteps[:1].repeat(
+ block_state.batch_size * block_state.num_images_per_prompt
+ )
+
+ if (
+ block_state.denoising_end is not None
+ and isinstance(block_state.denoising_end, float)
+ and block_state.denoising_end > 0
+ and block_state.denoising_end < 1
+ ):
+ block_state.discrete_timestep_cutoff = int(
+ round(
+ components.scheduler.config.num_train_timesteps
+ - (block_state.denoising_end * components.scheduler.config.num_train_timesteps)
+ )
+ )
+ block_state.num_inference_steps = len(
+ list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))
+ )
+ block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps]
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class StableDiffusionXLSetTimestepsStep(ModularPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "Step that sets the scheduler's timesteps for inference"
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("num_inference_steps", default=50),
+ InputParam("timesteps"),
+ InputParam("sigmas"),
+ InputParam("denoising_end"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
+ OutputParam(
+ "num_inference_steps",
+ type_hint=int,
+ description="The number of denoising steps to perform at inference time",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.device = components._execution_device
+
+ block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
+ components.scheduler,
+ block_state.num_inference_steps,
+ block_state.device,
+ block_state.timesteps,
+ block_state.sigmas,
+ )
+
+ if (
+ block_state.denoising_end is not None
+ and isinstance(block_state.denoising_end, float)
+ and block_state.denoising_end > 0
+ and block_state.denoising_end < 1
+ ):
+ block_state.discrete_timestep_cutoff = int(
+ round(
+ components.scheduler.config.num_train_timesteps
+ - (block_state.denoising_end * components.scheduler.config.num_train_timesteps)
+ )
+ )
+ block_state.num_inference_steps = len(
+ list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))
+ )
+ block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps]
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class StableDiffusionXLInpaintPrepareLatentsStep(ModularPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "Step that prepares the latents for the inpainting process"
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("latents"),
+ InputParam("num_images_per_prompt", default=1),
+ InputParam("denoising_start"),
+ InputParam(
+ "strength",
+ default=0.9999,
+ description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` "
+ "will be used as a starting point, adding more noise to it the larger the `strength`. The number of "
+ "denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will "
+ "be maximum and the denoising process will run for the full number of iterations specified in "
+ "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of "
+ "`denoising_start` being declared as an integer, the value of `strength` will be ignored.",
+ ),
+ InputParam("generator"),
+ InputParam(
+ "batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
+ ),
+ InputParam(
+ "latent_timestep",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ "image_latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.",
+ ),
+ InputParam(
+ "mask",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The mask for the inpainting generation. Can be generated in vae_encode step.",
+ ),
+ InputParam(
+ "masked_image_latents",
+ type_hint=torch.Tensor,
+ description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step.",
+ ),
+ InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ return [
+ OutputParam(
+ "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
+ ),
+ OutputParam(
+ "noise",
+ type_hint=torch.Tensor,
+ description="The noise added to the image latents, used for inpainting generation",
+ ),
+ ]
+
+ # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self->components
+ # YiYi TODO: update the _encode_vae_image so that we can use #Coped from
+ @staticmethod
+ def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator):
+ latents_mean = latents_std = None
+ if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
+ latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
+ if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
+ latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
+
+ dtype = image.dtype
+ if components.vae.config.force_upcast:
+ image = image.float()
+ components.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
+
+ if components.vae.config.force_upcast:
+ components.vae.to(dtype)
+
+ image_latents = image_latents.to(dtype)
+ if latents_mean is not None and latents_std is not None:
+ latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
+ latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
+ image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
+ else:
+ image_latents = components.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument
+ def prepare_latents_inpaint(
+ self,
+ components,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ image=None,
+ timestep=None,
+ is_strength_max=True,
+ add_noise=True,
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // components.vae_scale_factor,
+ int(width) // components.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if (image is None or timestep is None) and not is_strength_max:
+ raise ValueError(
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
+ "However, either the image or the noise timestep has not been provided."
+ )
+
+ if image.shape[1] == 4:
+ image_latents = image.to(device=device, dtype=dtype)
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+ elif latents is None and not is_strength_max:
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(components, image=image, generator=generator)
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+
+ if latents is None and add_noise:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
+ latents = noise if is_strength_max else components.scheduler.add_noise(image_latents, noise, timestep)
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
+ latents = latents * components.scheduler.init_noise_sigma if is_strength_max else latents
+ elif add_noise:
+ noise = latents.to(device)
+ latents = noise * components.scheduler.init_noise_sigma
+ else:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = image_latents.to(device)
+
+ outputs = (latents, noise, image_latents)
+
+ return outputs
+
+ # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
+ # do not accept do_classifier_free_guidance
+ def prepare_mask_latents(
+ self, components, mask, masked_image, batch_size, height, width, dtype, device, generator
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(
+ mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor)
+ )
+ mask = mask.to(device=device, dtype=dtype)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
+ if masked_image is not None and masked_image.shape[1] == 4:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = None
+
+ if masked_image is not None:
+ if masked_image_latents is None:
+ masked_image = masked_image.to(device=device, dtype=dtype)
+ masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)
+
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(
+ batch_size // masked_image_latents.shape[0], 1, 1, 1
+ )
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ return mask, masked_image_latents
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
+ block_state.device = components._execution_device
+
+ block_state.is_strength_max = block_state.strength == 1.0
+
+ # for non-inpainting specific unet, we do not need masked_image_latents
+ if hasattr(components, "unet") and components.unet is not None:
+ if components.unet.config.in_channels == 4:
+ block_state.masked_image_latents = None
+
+ block_state.add_noise = True if block_state.denoising_start is None else False
+
+ block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor
+ block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor
+
+ block_state.latents, block_state.noise, block_state.image_latents = self.prepare_latents_inpaint(
+ components,
+ block_state.batch_size * block_state.num_images_per_prompt,
+ components.num_channels_latents,
+ block_state.height,
+ block_state.width,
+ block_state.dtype,
+ block_state.device,
+ block_state.generator,
+ block_state.latents,
+ image=block_state.image_latents,
+ timestep=block_state.latent_timestep,
+ is_strength_max=block_state.is_strength_max,
+ add_noise=block_state.add_noise,
+ )
+
+ # 7. Prepare mask latent variables
+ block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
+ components,
+ block_state.mask,
+ block_state.masked_image_latents,
+ block_state.batch_size * block_state.num_images_per_prompt,
+ block_state.height,
+ block_state.width,
+ block_state.dtype,
+ block_state.device,
+ block_state.generator,
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class StableDiffusionXLImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("vae", AutoencoderKL),
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "Step that prepares the latents for the image-to-image generation process"
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("latents"),
+ InputParam("num_images_per_prompt", default=1),
+ InputParam("denoising_start"),
+ InputParam("generator"),
+ InputParam(
+ "latent_timestep",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ "image_latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.",
+ ),
+ InputParam(
+ "batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
+ ),
+ InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
+ )
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
+ block_state.device = components._execution_device
+ block_state.add_noise = True if block_state.denoising_start is None else False
+ if block_state.latents is None:
+ block_state.latents = prepare_latents_img2img(
+ components.vae,
+ components.scheduler,
+ block_state.image_latents,
+ block_state.latent_timestep,
+ block_state.batch_size,
+ block_state.num_images_per_prompt,
+ block_state.dtype,
+ block_state.device,
+ block_state.generator,
+ block_state.add_noise,
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class StableDiffusionXLPrepareLatentsStep(ModularPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
+ ComponentSpec("vae", AutoencoderKL),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "Prepare latents step that prepares the latents for the text-to-image generation process"
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("height"),
+ InputParam("width"),
+ InputParam("latents"),
+ InputParam("num_images_per_prompt", default=1),
+ InputParam("generator"),
+ InputParam(
+ "batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
+ ),
+ InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
+ )
+ ]
+
+ @staticmethod
+ def check_inputs(components, block_state):
+ if (
+ block_state.height is not None
+ and block_state.height % components.vae_scale_factor != 0
+ or block_state.width is not None
+ and block_state.width % components.vae_scale_factor != 0
+ ):
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}."
+ )
+
+ @staticmethod
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self->comp
+ def prepare_latents(comp, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // comp.vae_scale_factor,
+ int(width) // comp.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * comp.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ if block_state.dtype is None:
+ block_state.dtype = components.vae.dtype
+
+ block_state.device = components._execution_device
+
+ self.check_inputs(components, block_state)
+
+ block_state.height = block_state.height or components.default_sample_size * components.vae_scale_factor
+ block_state.width = block_state.width or components.default_sample_size * components.vae_scale_factor
+ block_state.num_channels_latents = components.num_channels_latents
+ block_state.latents = self.prepare_latents(
+ components,
+ block_state.batch_size * block_state.num_images_per_prompt,
+ block_state.num_channels_latents,
+ block_state.height,
+ block_state.width,
+ block_state.dtype,
+ block_state.device,
+ block_state.generator,
+ block_state.latents,
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(ModularPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_configs(self) -> List[ConfigSpec]:
+ return [
+ ConfigSpec("requires_aesthetics_score", False),
+ ]
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("unet", UNet2DConditionModel),
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 7.5}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "Step that prepares the additional conditioning for the image-to-image/inpainting generation process"
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("original_size"),
+ InputParam("target_size"),
+ InputParam("negative_original_size"),
+ InputParam("negative_target_size"),
+ InputParam("crops_coords_top_left", default=(0, 0)),
+ InputParam("negative_crops_coords_top_left", default=(0, 0)),
+ InputParam("num_images_per_prompt", default=1),
+ InputParam("aesthetic_score", default=6.0),
+ InputParam("negative_aesthetic_score", default=2.0),
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "pooled_prompt_embeds",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step.",
+ ),
+ InputParam(
+ "batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "add_time_ids",
+ type_hint=torch.Tensor,
+ kwargs_type="denoiser_input_fields",
+ description="The time ids to condition the denoising process",
+ ),
+ OutputParam(
+ "negative_add_time_ids",
+ type_hint=torch.Tensor,
+ kwargs_type="denoiser_input_fields",
+ description="The negative time ids to condition the denoising process",
+ ),
+ OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"),
+ ]
+
+ @staticmethod
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self->components
+ def _get_add_time_ids(
+ components,
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype,
+ text_encoder_projection_dim=None,
+ ):
+ if components.config.requires_aesthetics_score:
+ add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+ add_neg_time_ids = list(
+ negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
+ )
+ else:
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
+
+ passed_add_embed_dim = (
+ components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features
+
+ if (
+ expected_add_embed_dim > passed_add_embed_dim
+ and (expected_add_embed_dim - passed_add_embed_dim) == components.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
+ )
+ elif (
+ expected_add_embed_dim < passed_add_embed_dim
+ and (passed_add_embed_dim - expected_add_embed_dim) == components.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
+ )
+ elif expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
+
+ return add_time_ids, add_neg_time_ids
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
+ ) -> torch.Tensor:
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
+ embedding_dim (`int`, *optional*, defaults to 512):
+ Dimension of the embeddings to generate.
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+ Data type of the generated embeddings.
+
+ Returns:
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ block_state.device = components._execution_device
+
+ block_state.vae_scale_factor = components.vae_scale_factor
+
+ block_state.height, block_state.width = block_state.latents.shape[-2:]
+ block_state.height = block_state.height * block_state.vae_scale_factor
+ block_state.width = block_state.width * block_state.vae_scale_factor
+
+ block_state.original_size = block_state.original_size or (block_state.height, block_state.width)
+ block_state.target_size = block_state.target_size or (block_state.height, block_state.width)
+
+ block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1])
+
+ if block_state.negative_original_size is None:
+ block_state.negative_original_size = block_state.original_size
+ if block_state.negative_target_size is None:
+ block_state.negative_target_size = block_state.target_size
+
+ block_state.add_time_ids, block_state.negative_add_time_ids = self._get_add_time_ids(
+ components,
+ block_state.original_size,
+ block_state.crops_coords_top_left,
+ block_state.target_size,
+ block_state.aesthetic_score,
+ block_state.negative_aesthetic_score,
+ block_state.negative_original_size,
+ block_state.negative_crops_coords_top_left,
+ block_state.negative_target_size,
+ dtype=block_state.pooled_prompt_embeds.dtype,
+ text_encoder_projection_dim=block_state.text_encoder_projection_dim,
+ )
+ block_state.add_time_ids = block_state.add_time_ids.repeat(
+ block_state.batch_size * block_state.num_images_per_prompt, 1
+ ).to(device=block_state.device)
+ block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(
+ block_state.batch_size * block_state.num_images_per_prompt, 1
+ ).to(device=block_state.device)
+
+ # Optionally get Guidance Scale Embedding for LCM
+ block_state.timestep_cond = None
+ if (
+ hasattr(components, "unet")
+ and components.unet is not None
+ and components.unet.config.time_cond_proj_dim is not None
+ ):
+ # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this!
+ block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(
+ block_state.batch_size * block_state.num_images_per_prompt
+ )
+ block_state.timestep_cond = self.get_guidance_scale_embedding(
+ block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim
+ ).to(device=block_state.device, dtype=block_state.latents.dtype)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class StableDiffusionXLPrepareAdditionalConditioningStep(ModularPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def description(self) -> str:
+ return "Step that prepares the additional conditioning for the text-to-image generation process"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("unet", UNet2DConditionModel),
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 7.5}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("original_size"),
+ InputParam("target_size"),
+ InputParam("negative_original_size"),
+ InputParam("negative_target_size"),
+ InputParam("crops_coords_top_left", default=(0, 0)),
+ InputParam("negative_crops_coords_top_left", default=(0, 0)),
+ InputParam("num_images_per_prompt", default=1),
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "pooled_prompt_embeds",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step.",
+ ),
+ InputParam(
+ "batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "add_time_ids",
+ type_hint=torch.Tensor,
+ kwargs_type="denoiser_input_fields",
+ description="The time ids to condition the denoising process",
+ ),
+ OutputParam(
+ "negative_add_time_ids",
+ type_hint=torch.Tensor,
+ kwargs_type="denoiser_input_fields",
+ description="The negative time ids to condition the denoising process",
+ ),
+ OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"),
+ ]
+
+ @staticmethod
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self->components
+ def _get_add_time_ids(
+ components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
+ ):
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ passed_add_embed_dim = (
+ components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ return add_time_ids
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
+ ) -> torch.Tensor:
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
+ embedding_dim (`int`, *optional*, defaults to 512):
+ Dimension of the embeddings to generate.
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+ Data type of the generated embeddings.
+
+ Returns:
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ block_state.device = components._execution_device
+
+ block_state.height, block_state.width = block_state.latents.shape[-2:]
+ block_state.height = block_state.height * components.vae_scale_factor
+ block_state.width = block_state.width * components.vae_scale_factor
+
+ block_state.original_size = block_state.original_size or (block_state.height, block_state.width)
+ block_state.target_size = block_state.target_size or (block_state.height, block_state.width)
+
+ block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1])
+
+ block_state.add_time_ids = self._get_add_time_ids(
+ components,
+ block_state.original_size,
+ block_state.crops_coords_top_left,
+ block_state.target_size,
+ block_state.pooled_prompt_embeds.dtype,
+ text_encoder_projection_dim=block_state.text_encoder_projection_dim,
+ )
+ if block_state.negative_original_size is not None and block_state.negative_target_size is not None:
+ block_state.negative_add_time_ids = self._get_add_time_ids(
+ components,
+ block_state.negative_original_size,
+ block_state.negative_crops_coords_top_left,
+ block_state.negative_target_size,
+ block_state.pooled_prompt_embeds.dtype,
+ text_encoder_projection_dim=block_state.text_encoder_projection_dim,
+ )
+ else:
+ block_state.negative_add_time_ids = block_state.add_time_ids
+
+ block_state.add_time_ids = block_state.add_time_ids.repeat(
+ block_state.batch_size * block_state.num_images_per_prompt, 1
+ ).to(device=block_state.device)
+ block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(
+ block_state.batch_size * block_state.num_images_per_prompt, 1
+ ).to(device=block_state.device)
+
+ # Optionally get Guidance Scale Embedding for LCM
+ block_state.timestep_cond = None
+ if (
+ hasattr(components, "unet")
+ and components.unet is not None
+ and components.unet.config.time_cond_proj_dim is not None
+ ):
+ # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this!
+ block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(
+ block_state.batch_size * block_state.num_images_per_prompt
+ )
+ block_state.timestep_cond = self.get_guidance_scale_embedding(
+ block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim
+ ).to(device=block_state.device, dtype=block_state.latents.dtype)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class StableDiffusionXLControlNetInputStep(ModularPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("controlnet", ControlNetModel),
+ ComponentSpec(
+ "control_image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "step that prepare inputs for controlnet"
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("control_image", required=True),
+ InputParam("control_guidance_start", default=0.0),
+ InputParam("control_guidance_end", default=1.0),
+ InputParam("controlnet_conditioning_scale", default=1.0),
+ InputParam("guess_mode", default=False),
+ InputParam("num_images_per_prompt", default=1),
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
+ ),
+ InputParam(
+ "timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ "crops_coords",
+ type_hint=Optional[Tuple[int]],
+ description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image"),
+ OutputParam(
+ "control_guidance_start", type_hint=List[float], description="The controlnet guidance start values"
+ ),
+ OutputParam(
+ "control_guidance_end", type_hint=List[float], description="The controlnet guidance end values"
+ ),
+ OutputParam(
+ "conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"
+ ),
+ OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"),
+ OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"),
+ ]
+
+ # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
+ # 1. return image without apply any guidance
+ # 2. add crops_coords and resize_mode to preprocess()
+ @staticmethod
+ def prepare_control_image(
+ components,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ crops_coords=None,
+ ):
+ if crops_coords is not None:
+ image = components.control_image_processor.preprocess(
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill"
+ ).to(dtype=torch.float32)
+ else:
+ image = components.control_image_processor.preprocess(image, height=height, width=width).to(
+ dtype=torch.float32
+ )
+
+ image_batch_size = image.shape[0]
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+ image = image.to(device=device, dtype=dtype)
+ return image
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ # (1) prepare controlnet inputs
+ block_state.device = components._execution_device
+ block_state.height, block_state.width = block_state.latents.shape[-2:]
+ block_state.height = block_state.height * components.vae_scale_factor
+ block_state.width = block_state.width * components.vae_scale_factor
+
+ controlnet = unwrap_module(components.controlnet)
+
+ # (1.1)
+ # control_guidance_start/control_guidance_end (align format)
+ if not isinstance(block_state.control_guidance_start, list) and isinstance(
+ block_state.control_guidance_end, list
+ ):
+ block_state.control_guidance_start = len(block_state.control_guidance_end) * [
+ block_state.control_guidance_start
+ ]
+ elif not isinstance(block_state.control_guidance_end, list) and isinstance(
+ block_state.control_guidance_start, list
+ ):
+ block_state.control_guidance_end = len(block_state.control_guidance_start) * [
+ block_state.control_guidance_end
+ ]
+ elif not isinstance(block_state.control_guidance_start, list) and not isinstance(
+ block_state.control_guidance_end, list
+ ):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+ block_state.control_guidance_start, block_state.control_guidance_end = (
+ mult * [block_state.control_guidance_start],
+ mult * [block_state.control_guidance_end],
+ )
+
+ # (1.2)
+ # controlnet_conditioning_scale (align format)
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(
+ block_state.controlnet_conditioning_scale, float
+ ):
+ block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(
+ controlnet.nets
+ )
+
+ # (1.3)
+ # global_pool_conditions
+ block_state.global_pool_conditions = (
+ controlnet.config.global_pool_conditions
+ if isinstance(controlnet, ControlNetModel)
+ else controlnet.nets[0].config.global_pool_conditions
+ )
+ # (1.4)
+ # guess_mode
+ block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions
+
+ # (1.5)
+ # control_image
+ if isinstance(controlnet, ControlNetModel):
+ block_state.control_image = self.prepare_control_image(
+ components,
+ image=block_state.control_image,
+ width=block_state.width,
+ height=block_state.height,
+ batch_size=block_state.batch_size * block_state.num_images_per_prompt,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ device=block_state.device,
+ dtype=controlnet.dtype,
+ crops_coords=block_state.crops_coords,
+ )
+ elif isinstance(controlnet, MultiControlNetModel):
+ control_images = []
+
+ for control_image_ in block_state.control_image:
+ control_image = self.prepare_control_image(
+ components,
+ image=control_image_,
+ width=block_state.width,
+ height=block_state.height,
+ batch_size=block_state.batch_size * block_state.num_images_per_prompt,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ device=block_state.device,
+ dtype=controlnet.dtype,
+ crops_coords=block_state.crops_coords,
+ )
+
+ control_images.append(control_image)
+
+ block_state.control_image = control_images
+ else:
+ assert False
+
+ # (1.6)
+ # controlnet_keep
+ block_state.controlnet_keep = []
+ for i in range(len(block_state.timesteps)):
+ keeps = [
+ 1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e)
+ for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end)
+ ]
+ block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
+
+ block_state.controlnet_cond = block_state.control_image
+ block_state.conditioning_scale = block_state.controlnet_conditioning_scale
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class StableDiffusionXLControlNetUnionInputStep(ModularPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("controlnet", ControlNetUnionModel),
+ ComponentSpec(
+ "control_image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "step that prepares inputs for the ControlNetUnion model"
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("control_image", required=True),
+ InputParam("control_mode", required=True),
+ InputParam("control_guidance_start", default=0.0),
+ InputParam("control_guidance_end", default=1.0),
+ InputParam("controlnet_conditioning_scale", default=1.0),
+ InputParam("guess_mode", default=False),
+ InputParam("num_images_per_prompt", default=1),
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Used to determine the shape of the control images. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
+ ),
+ InputParam(
+ "dtype",
+ required=True,
+ type_hint=torch.dtype,
+ description="The dtype of model tensor inputs. Can be generated in input step.",
+ ),
+ InputParam(
+ "timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Needed to determine `controlnet_keep`. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ "crops_coords",
+ type_hint=Optional[Tuple[int]],
+ description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam("controlnet_cond", type_hint=List[torch.Tensor], description="The processed control images"),
+ OutputParam(
+ "control_type_idx",
+ type_hint=List[int],
+ description="The control mode indices",
+ kwargs_type="controlnet_kwargs",
+ ),
+ OutputParam(
+ "control_type",
+ type_hint=torch.Tensor,
+ description="The control type tensor that specifies which control type is active",
+ kwargs_type="controlnet_kwargs",
+ ),
+ OutputParam("control_guidance_start", type_hint=float, description="The controlnet guidance start value"),
+ OutputParam("control_guidance_end", type_hint=float, description="The controlnet guidance end value"),
+ OutputParam(
+ "conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"
+ ),
+ OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"),
+ OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"),
+ ]
+
+ # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
+ # 1. return image without apply any guidance
+ # 2. add crops_coords and resize_mode to preprocess()
+ @staticmethod
+ def prepare_control_image(
+ components,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ crops_coords=None,
+ ):
+ if crops_coords is not None:
+ image = components.control_image_processor.preprocess(
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill"
+ ).to(dtype=torch.float32)
+ else:
+ image = components.control_image_processor.preprocess(image, height=height, width=width).to(
+ dtype=torch.float32
+ )
+
+ image_batch_size = image.shape[0]
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+ image = image.to(device=device, dtype=dtype)
+ return image
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ controlnet = unwrap_module(components.controlnet)
+
+ device = components._execution_device
+ dtype = block_state.dtype or components.controlnet.dtype
+
+ block_state.height, block_state.width = block_state.latents.shape[-2:]
+ block_state.height = block_state.height * components.vae_scale_factor
+ block_state.width = block_state.width * components.vae_scale_factor
+
+ # control_guidance_start/control_guidance_end (align format)
+ if not isinstance(block_state.control_guidance_start, list) and isinstance(
+ block_state.control_guidance_end, list
+ ):
+ block_state.control_guidance_start = len(block_state.control_guidance_end) * [
+ block_state.control_guidance_start
+ ]
+ elif not isinstance(block_state.control_guidance_end, list) and isinstance(
+ block_state.control_guidance_start, list
+ ):
+ block_state.control_guidance_end = len(block_state.control_guidance_start) * [
+ block_state.control_guidance_end
+ ]
+
+ # guess_mode
+ block_state.global_pool_conditions = controlnet.config.global_pool_conditions
+ block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions
+
+ # control_image
+ if not isinstance(block_state.control_image, list):
+ block_state.control_image = [block_state.control_image]
+ # control_mode
+ if not isinstance(block_state.control_mode, list):
+ block_state.control_mode = [block_state.control_mode]
+
+ if len(block_state.control_image) != len(block_state.control_mode):
+ raise ValueError("Expected len(control_image) == len(control_type)")
+
+ # control_type
+ block_state.num_control_type = controlnet.config.num_control_type
+ block_state.control_type = [0 for _ in range(block_state.num_control_type)]
+ for control_idx in block_state.control_mode:
+ block_state.control_type[control_idx] = 1
+ block_state.control_type = torch.Tensor(block_state.control_type)
+
+ block_state.control_type = block_state.control_type.reshape(1, -1).to(device, dtype=block_state.dtype)
+ repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0]
+ block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0)
+
+ # prepare control_image
+ for idx, _ in enumerate(block_state.control_image):
+ block_state.control_image[idx] = self.prepare_control_image(
+ components,
+ image=block_state.control_image[idx],
+ width=block_state.width,
+ height=block_state.height,
+ batch_size=block_state.batch_size * block_state.num_images_per_prompt,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ device=device,
+ dtype=dtype,
+ crops_coords=block_state.crops_coords,
+ )
+ block_state.height, block_state.width = block_state.control_image[idx].shape[-2:]
+
+ # controlnet_keep
+ block_state.controlnet_keep = []
+ for i in range(len(block_state.timesteps)):
+ block_state.controlnet_keep.append(
+ 1.0
+ - float(
+ i / len(block_state.timesteps) < block_state.control_guidance_start
+ or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end
+ )
+ )
+ block_state.control_type_idx = block_state.control_mode
+ block_state.controlnet_cond = block_state.control_image
+ block_state.conditioning_scale = block_state.controlnet_conditioning_scale
+
+ self.set_block_state(state, block_state)
+
+ return components, state
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py
new file mode 100644
index 000000000000..6e0307260d1d
--- /dev/null
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py
@@ -0,0 +1,198 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, List, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+
+from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
+from ...models import AutoencoderKL
+from ...utils import deprecate, logging
+from ..modular_pipeline import (
+ ModularPipelineBlocks,
+ PipelineState,
+)
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class StableDiffusionXLDecodeStep(ModularPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("vae", AutoencoderKL),
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 8}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "Step that decodes the denoised latents into images"
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("output_type", default="pil"),
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The denoised latents from the denoising step",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ return [
+ OutputParam(
+ "images",
+ type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
+ description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array",
+ )
+ ]
+
+ @staticmethod
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self->components
+ def upcast_vae(components):
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
+ )
+ components.vae.to(dtype=torch.float32)
+
+ @torch.no_grad()
+ def __call__(self, components, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ if not block_state.output_type == "latent":
+ latents = block_state.latents
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast
+
+ if block_state.needs_upcasting:
+ self.upcast_vae(components)
+ latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype)
+ elif latents.dtype != components.vae.dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ components.vae = components.vae.to(latents.dtype)
+
+ # unscale/denormalize the latents
+ # denormalize with the mean and std if available and not None
+ block_state.has_latents_mean = (
+ hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None
+ )
+ block_state.has_latents_std = (
+ hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None
+ )
+ if block_state.has_latents_mean and block_state.has_latents_std:
+ block_state.latents_mean = (
+ torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ block_state.latents_std = (
+ torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents = (
+ latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean
+ )
+ else:
+ latents = latents / components.vae.config.scaling_factor
+
+ block_state.images = components.vae.decode(latents, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if block_state.needs_upcasting:
+ components.vae.to(dtype=torch.float16)
+ else:
+ block_state.images = block_state.latents
+
+ # apply watermark if available
+ if hasattr(components, "watermark") and components.watermark is not None:
+ block_state.images = components.watermark.apply_watermark(block_state.images)
+
+ block_state.images = components.image_processor.postprocess(
+ block_state.images, output_type=block_state.output_type
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class StableDiffusionXLInpaintOverlayMaskStep(ModularPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def description(self) -> str:
+ return (
+ "A post-processing step that overlays the mask on the image (inpainting task only).\n"
+ + "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask"
+ )
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 8}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("image"),
+ InputParam("mask_image"),
+ InputParam("padding_mask_crop"),
+ InputParam(
+ "images",
+ type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
+ description="The generated images from the decode step",
+ ),
+ InputParam(
+ "crops_coords",
+ type_hint=Tuple[int, int],
+ description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ if block_state.padding_mask_crop is not None and block_state.crops_coords is not None:
+ block_state.images = [
+ components.image_processor.apply_overlay(
+ block_state.mask_image, block_state.image, i, block_state.crops_coords
+ )
+ for i in block_state.images
+ ]
+
+ self.set_block_state(state, block_state)
+
+ return components, state
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py
new file mode 100644
index 000000000000..862315e59169
--- /dev/null
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py
@@ -0,0 +1,798 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, List, Optional, Tuple
+
+import torch
+
+from ...configuration_utils import FrozenDict
+from ...guiders import ClassifierFreeGuidance
+from ...models import ControlNetModel, UNet2DConditionModel
+from ...schedulers import EulerDiscreteScheduler
+from ...utils import logging
+from ..modular_pipeline import (
+ BlockState,
+ LoopSequentialPipelineBlocks,
+ ModularPipelineBlocks,
+ PipelineState,
+)
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import StableDiffusionXLModularPipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# YiYi experimenting composible denoise loop
+# loop step (1): prepare latent input for denoiser
+class StableDiffusionXLLoopBeforeDenoiser(ModularPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that prepare the latent input for the denoiser. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[str]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
+ block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
+
+ return components, block_state
+
+
+# loop step (1): prepare latent input for denoiser (with inpainting)
+class StableDiffusionXLInpaintLoopBeforeDenoiser(ModularPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
+ ComponentSpec("unet", UNet2DConditionModel),
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that prepare the latent input for the denoiser (for inpainting workflow only). "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object"
+ )
+
+ @property
+ def inputs(self) -> List[str]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "mask",
+ type_hint=Optional[torch.Tensor],
+ description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.",
+ ),
+ InputParam(
+ "masked_image_latents",
+ type_hint=Optional[torch.Tensor],
+ description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(components, block_state):
+ num_channels_unet = components.num_channels_unet
+ if num_channels_unet == 9:
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
+ if block_state.mask is None or block_state.masked_image_latents is None:
+ raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet")
+ num_channels_latents = block_state.latents.shape[1]
+ num_channels_mask = block_state.mask.shape[1]
+ num_channels_masked_image = block_state.masked_image_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects"
+ f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
+ " `components.unet` or your `mask_image` or `image` input."
+ )
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
+ self.check_inputs(components, block_state)
+
+ block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
+ if components.num_channels_unet == 9:
+ block_state.scaled_latents = torch.cat(
+ [block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1
+ )
+
+ return components, block_state
+
+
+# loop step (2): denoise the latents with guidance
+class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 7.5}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec("unet", UNet2DConditionModel),
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Step within the denoising loop that denoise the latents with guidance. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("cross_attention_kwargs"),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ "timestep_cond",
+ type_hint=Optional[torch.Tensor],
+ description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step.",
+ ),
+ InputParam(
+ kwargs_type="denoiser_input_fields",
+ description=(
+ "All conditional model inputs that need to be prepared with guider. "
+ "It should contain prompt_embeds/negative_prompt_embeds, "
+ "add_time_ids/negative_add_time_ids, "
+ "pooled_prompt_embeds/negative_pooled_prompt_embeds, "
+ "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)."
+ "please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
+ ),
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(
+ self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int
+ ) -> PipelineState:
+ # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
+ # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
+ guider_inputs = {
+ "prompt_embeds": (
+ getattr(block_state, "prompt_embeds", None),
+ getattr(block_state, "negative_prompt_embeds", None),
+ ),
+ "time_ids": (
+ getattr(block_state, "add_time_ids", None),
+ getattr(block_state, "negative_add_time_ids", None),
+ ),
+ "text_embeds": (
+ getattr(block_state, "pooled_prompt_embeds", None),
+ getattr(block_state, "negative_pooled_prompt_embeds", None),
+ ),
+ "image_embeds": (
+ getattr(block_state, "ip_adapter_embeds", None),
+ getattr(block_state, "negative_ip_adapter_embeds", None),
+ ),
+ }
+
+ components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
+
+ # The guider splits model inputs into separate batches for conditional/unconditional predictions.
+ # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
+ # you will get a guider_state with two batches:
+ # guider_state = [
+ # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
+ # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
+ # ]
+ # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
+ guider_state = components.guider.prepare_inputs(guider_inputs)
+
+ # run the denoiser for each guidance batch
+ for guider_state_batch in guider_state:
+ components.guider.prepare_models(components.unet)
+ cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
+ prompt_embeds = cond_kwargs.pop("prompt_embeds")
+
+ # Predict the noise residual
+ # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
+ guider_state_batch.noise_pred = components.unet(
+ block_state.scaled_latents,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=block_state.timestep_cond,
+ cross_attention_kwargs=block_state.cross_attention_kwargs,
+ added_cond_kwargs=cond_kwargs,
+ return_dict=False,
+ )[0]
+ components.guider.cleanup_models(components.unet)
+
+ # Perform guidance
+ block_state.noise_pred = components.guider(guider_state)[0]
+
+ return components, block_state
+
+
+# loop step (2): denoise the latents with guidance (with controlnet)
+class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 7.5}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec("unet", UNet2DConditionModel),
+ ComponentSpec("controlnet", ControlNetModel),
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that denoise the latents with guidance (with controlnet). "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("cross_attention_kwargs"),
+ InputParam(
+ "controlnet_cond",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
+ ),
+ InputParam(
+ "conditioning_scale",
+ type_hint=float,
+ description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
+ ),
+ InputParam(
+ "guess_mode",
+ required=True,
+ type_hint=bool,
+ description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
+ ),
+ InputParam(
+ "controlnet_keep",
+ required=True,
+ type_hint=List[float],
+ description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
+ ),
+ InputParam(
+ "timestep_cond",
+ type_hint=Optional[torch.Tensor],
+ description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step",
+ ),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ kwargs_type="denoiser_input_fields",
+ description=(
+ "All conditional model inputs that need to be prepared with guider. "
+ "It should contain prompt_embeds/negative_prompt_embeds, "
+ "add_time_ids/negative_add_time_ids, "
+ "pooled_prompt_embeds/negative_pooled_prompt_embeds, "
+ "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)."
+ "please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
+ ),
+ ),
+ InputParam(
+ kwargs_type="controlnet_kwargs",
+ description=(
+ "additional kwargs for controlnet (e.g. control_type_idx and control_type from the controlnet union input step )"
+ "please add `kwargs_type=controlnet_kwargs` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
+ ),
+ ),
+ ]
+
+ @staticmethod
+ def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs):
+ accepted_kwargs = set(inspect.signature(func).parameters.keys())
+ extra_kwargs = {}
+ for key, value in kwargs.items():
+ if key in accepted_kwargs and key not in exclude_kwargs:
+ extra_kwargs[key] = value
+
+ return extra_kwargs
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
+ extra_controlnet_kwargs = self.prepare_extra_kwargs(
+ components.controlnet.forward, **block_state.controlnet_kwargs
+ )
+
+ # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
+ # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
+ guider_inputs = {
+ "prompt_embeds": (
+ getattr(block_state, "prompt_embeds", None),
+ getattr(block_state, "negative_prompt_embeds", None),
+ ),
+ "time_ids": (
+ getattr(block_state, "add_time_ids", None),
+ getattr(block_state, "negative_add_time_ids", None),
+ ),
+ "text_embeds": (
+ getattr(block_state, "pooled_prompt_embeds", None),
+ getattr(block_state, "negative_pooled_prompt_embeds", None),
+ ),
+ "image_embeds": (
+ getattr(block_state, "ip_adapter_embeds", None),
+ getattr(block_state, "negative_ip_adapter_embeds", None),
+ ),
+ }
+
+ # cond_scale for the timestep (controlnet input)
+ if isinstance(block_state.controlnet_keep[i], list):
+ block_state.cond_scale = [
+ c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])
+ ]
+ else:
+ controlnet_cond_scale = block_state.conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i]
+
+ # default controlnet output/unet input for guess mode + conditional path
+ block_state.down_block_res_samples_zeros = None
+ block_state.mid_block_res_sample_zeros = None
+
+ # guided denoiser step
+ components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
+
+ # The guider splits model inputs into separate batches for conditional/unconditional predictions.
+ # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
+ # you will get a guider_state with two batches:
+ # guider_state = [
+ # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
+ # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
+ # ]
+ # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
+ guider_state = components.guider.prepare_inputs(guider_inputs)
+
+ # run the denoiser for each guidance batch
+ for guider_state_batch in guider_state:
+ components.guider.prepare_models(components.unet)
+
+ # Prepare additional conditionings
+ added_cond_kwargs = {
+ "text_embeds": guider_state_batch.text_embeds,
+ "time_ids": guider_state_batch.time_ids,
+ }
+ if hasattr(guider_state_batch, "image_embeds") and guider_state_batch.image_embeds is not None:
+ added_cond_kwargs["image_embeds"] = guider_state_batch.image_embeds
+
+ # Prepare controlnet additional conditionings
+ controlnet_added_cond_kwargs = {
+ "text_embeds": guider_state_batch.text_embeds,
+ "time_ids": guider_state_batch.time_ids,
+ }
+ # run controlnet for the guidance batch
+ if block_state.guess_mode and not components.guider.is_conditional:
+ # guider always run uncond batch first, so these tensors should be set already
+ down_block_res_samples = block_state.down_block_res_samples_zeros
+ mid_block_res_sample = block_state.mid_block_res_sample_zeros
+ else:
+ down_block_res_samples, mid_block_res_sample = components.controlnet(
+ block_state.scaled_latents,
+ t,
+ encoder_hidden_states=guider_state_batch.prompt_embeds,
+ controlnet_cond=block_state.controlnet_cond,
+ conditioning_scale=block_state.cond_scale,
+ guess_mode=block_state.guess_mode,
+ added_cond_kwargs=controlnet_added_cond_kwargs,
+ return_dict=False,
+ **extra_controlnet_kwargs,
+ )
+
+ # assign it to block_state so it will be available for the uncond guidance batch
+ if block_state.down_block_res_samples_zeros is None:
+ block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in down_block_res_samples]
+ if block_state.mid_block_res_sample_zeros is None:
+ block_state.mid_block_res_sample_zeros = torch.zeros_like(mid_block_res_sample)
+
+ # Predict the noise
+ # store the noise_pred in guider_state_batch so we can apply guidance across all batches
+ guider_state_batch.noise_pred = components.unet(
+ block_state.scaled_latents,
+ t,
+ encoder_hidden_states=guider_state_batch.prompt_embeds,
+ timestep_cond=block_state.timestep_cond,
+ cross_attention_kwargs=block_state.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ return_dict=False,
+ )[0]
+ components.guider.cleanup_models(components.unet)
+
+ # Perform guidance
+ block_state.noise_pred = components.guider(guider_state)[0]
+
+ return components, block_state
+
+
+# loop step (3): scheduler step to update latents
+class StableDiffusionXLLoopAfterDenoiser(ModularPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that update the latents. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("eta", default=0.0),
+ InputParam("generator"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
+
+ # YiYi TODO: move this out of here
+ @staticmethod
+ def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs):
+ accepted_kwargs = set(inspect.signature(func).parameters.keys())
+ extra_kwargs = {}
+ for key, value in kwargs.items():
+ if key in accepted_kwargs and key not in exclude_kwargs:
+ extra_kwargs[key] = value
+
+ return extra_kwargs
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
+ # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ block_state.extra_step_kwargs = self.prepare_extra_kwargs(
+ components.scheduler.step, generator=block_state.generator, eta=block_state.eta
+ )
+
+ # Perform scheduler step using the predicted output
+ block_state.latents_dtype = block_state.latents.dtype
+ block_state.latents = components.scheduler.step(
+ block_state.noise_pred,
+ t,
+ block_state.latents,
+ **block_state.extra_step_kwargs,
+ return_dict=False,
+ )[0]
+
+ if block_state.latents.dtype != block_state.latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ block_state.latents = block_state.latents.to(block_state.latents_dtype)
+
+ return components, block_state
+
+
+# loop step (3): scheduler step to update latents (with inpainting)
+class StableDiffusionXLInpaintLoopAfterDenoiser(ModularPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
+ ComponentSpec("unet", UNet2DConditionModel),
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that update the latents (for inpainting workflow only). "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("eta", default=0.0),
+ InputParam("generator"),
+ InputParam(
+ "timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ "mask",
+ type_hint=Optional[torch.Tensor],
+ description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.",
+ ),
+ InputParam(
+ "noise",
+ type_hint=Optional[torch.Tensor],
+ description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "image_latents",
+ type_hint=Optional[torch.Tensor],
+ description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
+
+ @staticmethod
+ def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs):
+ accepted_kwargs = set(inspect.signature(func).parameters.keys())
+ extra_kwargs = {}
+ for key, value in kwargs.items():
+ if key in accepted_kwargs and key not in exclude_kwargs:
+ extra_kwargs[key] = value
+
+ return extra_kwargs
+
+ def check_inputs(self, components, block_state):
+ if components.num_channels_unet == 4:
+ if block_state.image_latents is None:
+ raise ValueError(f"image_latents is required for this step {self.__class__.__name__}")
+ if block_state.mask is None:
+ raise ValueError(f"mask is required for this step {self.__class__.__name__}")
+ if block_state.noise is None:
+ raise ValueError(f"noise is required for this step {self.__class__.__name__}")
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
+ self.check_inputs(components, block_state)
+
+ # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ block_state.extra_step_kwargs = self.prepare_extra_kwargs(
+ components.scheduler.step, generator=block_state.generator, eta=block_state.eta
+ )
+
+ # Perform scheduler step using the predicted output
+ block_state.latents_dtype = block_state.latents.dtype
+ block_state.latents = components.scheduler.step(
+ block_state.noise_pred,
+ t,
+ block_state.latents,
+ **block_state.extra_step_kwargs,
+ return_dict=False,
+ )[0]
+
+ if block_state.latents.dtype != block_state.latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ block_state.latents = block_state.latents.to(block_state.latents_dtype)
+
+ # adjust latent for inpainting
+ if components.num_channels_unet == 4:
+ block_state.init_latents_proper = block_state.image_latents
+ if i < len(block_state.timesteps) - 1:
+ block_state.noise_timestep = block_state.timesteps[i + 1]
+ block_state.init_latents_proper = components.scheduler.add_noise(
+ block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep])
+ )
+
+ block_state.latents = (
+ 1 - block_state.mask
+ ) * block_state.init_latents_proper + block_state.mask * block_state.latents
+
+ return components, block_state
+
+
+# the loop wrapper that iterates over the timesteps
+class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Pipeline block that iteratively denoise the latents over `timesteps`. "
+ "The specific steps with each iteration can be customized with `sub_blocks` attributes"
+ )
+
+ @property
+ def loop_expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 7.5}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
+ ComponentSpec("unet", UNet2DConditionModel),
+ ]
+
+ @property
+ def loop_inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False
+ if block_state.disable_guidance:
+ components.guider.disable()
+ else:
+ components.guider.enable()
+
+ block_state.num_warmup_steps = max(
+ len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
+ )
+
+ with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
+ for i, t in enumerate(block_state.timesteps):
+ components, block_state = self.loop_step(components, block_state, i=i, t=t)
+ if i == len(block_state.timesteps) - 1 or (
+ (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
+ ):
+ progress_bar.update()
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+# composing the denoising loops
+class StableDiffusionXLDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
+ block_classes = [
+ StableDiffusionXLLoopBeforeDenoiser,
+ StableDiffusionXLLoopDenoiser,
+ StableDiffusionXLLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
+ " - `StableDiffusionXLLoopBeforeDenoiser`\n"
+ " - `StableDiffusionXLLoopDenoiser`\n"
+ " - `StableDiffusionXLLoopAfterDenoiser`\n"
+ "This block supports both text2img and img2img tasks."
+ )
+
+
+# control_cond
+class StableDiffusionXLControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
+ block_classes = [
+ StableDiffusionXLLoopBeforeDenoiser,
+ StableDiffusionXLControlNetLoopDenoiser,
+ StableDiffusionXLLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents with controlnet. \n"
+ "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
+ " - `StableDiffusionXLLoopBeforeDenoiser`\n"
+ " - `StableDiffusionXLControlNetLoopDenoiser`\n"
+ " - `StableDiffusionXLLoopAfterDenoiser`\n"
+ "This block supports using controlnet for both text2img and img2img tasks."
+ )
+
+
+# mask
+class StableDiffusionXLInpaintDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
+ block_classes = [
+ StableDiffusionXLInpaintLoopBeforeDenoiser,
+ StableDiffusionXLLoopDenoiser,
+ StableDiffusionXLInpaintLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents(for inpainting task only). \n"
+ "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
+ " - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n"
+ " - `StableDiffusionXLLoopDenoiser`\n"
+ " - `StableDiffusionXLInpaintLoopAfterDenoiser`\n"
+ "This block onlysupports inpainting tasks."
+ )
+
+
+# control_cond + mask
+class StableDiffusionXLInpaintControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
+ block_classes = [
+ StableDiffusionXLInpaintLoopBeforeDenoiser,
+ StableDiffusionXLControlNetLoopDenoiser,
+ StableDiffusionXLInpaintLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents(for inpainting task only) with controlnet. \n"
+ "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
+ " - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n"
+ " - `StableDiffusionXLControlNetLoopDenoiser`\n"
+ " - `StableDiffusionXLInpaintLoopAfterDenoiser`\n"
+ "This block only supports using controlnet for inpainting tasks."
+ )
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py
new file mode 100644
index 000000000000..90b254b6f5d4
--- /dev/null
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py
@@ -0,0 +1,887 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple
+
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from ...configuration_utils import FrozenDict
+from ...guiders import ClassifierFreeGuidance
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from ...models.lora import adjust_lora_scale_text_encoder
+from ...utils import (
+ USE_PEFT_BACKEND,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
+from .modular_pipeline import StableDiffusionXLModularPipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class StableDiffusionXLIPAdapterStep(ModularPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def description(self) -> str:
+ return (
+ "IP Adapter step that prepares ip adapter image embeddings.\n"
+ "Note that this step only prepares the embeddings - in order for it to work correctly, "
+ "you need to load ip adapter weights into unet via ModularPipeline.load_ip_adapter() and pipeline.set_ip_adapter_scale().\n"
+ "See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)"
+ " for more details"
+ )
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("image_encoder", CLIPVisionModelWithProjection),
+ ComponentSpec(
+ "feature_extractor",
+ CLIPImageProcessor,
+ config=FrozenDict({"size": 224, "crop_size": 224}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec("unet", UNet2DConditionModel),
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 7.5}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "ip_adapter_image",
+ PipelineImageInput,
+ required=True,
+ description="The image(s) to be used as ip adapter",
+ )
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"),
+ OutputParam(
+ "negative_ip_adapter_embeds",
+ type_hint=torch.Tensor,
+ description="Negative IP adapter image embeddings",
+ ),
+ ]
+
+ @staticmethod
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self->components
+ def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(components.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = components.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = components.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = components.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self,
+ components,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ num_images_per_prompt,
+ prepare_unconditional_embeds,
+ ):
+ image_embeds = []
+ if prepare_unconditional_embeds:
+ negative_image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ components, single_ip_adapter_image, device, 1, output_hidden_state
+ )
+
+ image_embeds.append(single_image_embeds[None, :])
+ if prepare_unconditional_embeds:
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
+ else:
+ for single_image_embeds in ip_adapter_image_embeds:
+ if prepare_unconditional_embeds:
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+ negative_image_embeds.append(single_negative_image_embeds)
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for i, single_image_embeds in enumerate(image_embeds):
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ if prepare_unconditional_embeds:
+ single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
+
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
+ block_state.device = components._execution_device
+
+ block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds(
+ components,
+ ip_adapter_image=block_state.ip_adapter_image,
+ ip_adapter_image_embeds=None,
+ device=block_state.device,
+ num_images_per_prompt=1,
+ prepare_unconditional_embeds=block_state.prepare_unconditional_embeds,
+ )
+ if block_state.prepare_unconditional_embeds:
+ block_state.negative_ip_adapter_embeds = []
+ for i, image_embeds in enumerate(block_state.ip_adapter_embeds):
+ negative_image_embeds, image_embeds = image_embeds.chunk(2)
+ block_state.negative_ip_adapter_embeds.append(negative_image_embeds)
+ block_state.ip_adapter_embeds[i] = image_embeds
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def description(self) -> str:
+ return "Text Encoder step that generate text_embeddings to guide the image generation"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("text_encoder", CLIPTextModel),
+ ComponentSpec("text_encoder_2", CLIPTextModelWithProjection),
+ ComponentSpec("tokenizer", CLIPTokenizer),
+ ComponentSpec("tokenizer_2", CLIPTokenizer),
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 7.5}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def expected_configs(self) -> List[ConfigSpec]:
+ return [ConfigSpec("force_zeros_for_empty_prompt", True)]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("prompt"),
+ InputParam("prompt_2"),
+ InputParam("negative_prompt"),
+ InputParam("negative_prompt_2"),
+ InputParam("cross_attention_kwargs"),
+ InputParam("clip_skip"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="denoiser_input_fields",
+ description="text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "negative_prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="denoiser_input_fields",
+ description="negative text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "pooled_prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="denoiser_input_fields",
+ description="pooled text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "negative_pooled_prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="denoiser_input_fields",
+ description="negative pooled text embeddings used to guide the image generation",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(block_state):
+ if block_state.prompt is not None and (
+ not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)
+ ):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
+ elif block_state.prompt_2 is not None and (
+ not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)
+ ):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}")
+
+ @staticmethod
+ def encode_prompt(
+ components,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prepare_unconditional_embeds: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prepare_unconditional_embeds (`bool`):
+ whether to use prepare unconditional embeddings or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or components._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin):
+ components._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if components.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(components.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(components.text_encoder, lora_scale)
+
+ if components.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(components.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = (
+ [components.tokenizer, components.tokenizer_2]
+ if components.tokenizer is not None
+ else [components.tokenizer_2]
+ )
+ text_encoders = (
+ [components.text_encoder, components.text_encoder_2]
+ if components.text_encoder is not None
+ else [components.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(components, TextualInversionLoaderMixin):
+ prompt = components.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt
+ if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif prepare_unconditional_embeds and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(components, TextualInversionLoaderMixin):
+ negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if components.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if prepare_unconditional_embeds:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if components.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(
+ dtype=components.text_encoder_2.dtype, device=device
+ )
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if prepare_unconditional_embeds:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if components.text_encoder is not None:
+ if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(components.text_encoder, lora_scale)
+
+ if components.text_encoder_2 is not None:
+ if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(components.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ # Get inputs and intermediates
+ block_state = self.get_block_state(state)
+ self.check_inputs(block_state)
+
+ block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
+ block_state.device = components._execution_device
+
+ # Encode input prompt
+ block_state.text_encoder_lora_scale = (
+ block_state.cross_attention_kwargs.get("scale", None)
+ if block_state.cross_attention_kwargs is not None
+ else None
+ )
+ (
+ block_state.prompt_embeds,
+ block_state.negative_prompt_embeds,
+ block_state.pooled_prompt_embeds,
+ block_state.negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ components,
+ block_state.prompt,
+ block_state.prompt_2,
+ block_state.device,
+ 1,
+ block_state.prepare_unconditional_embeds,
+ block_state.negative_prompt,
+ block_state.negative_prompt_2,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ lora_scale=block_state.text_encoder_lora_scale,
+ clip_skip=block_state.clip_skip,
+ )
+ # Add outputs
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class StableDiffusionXLVaeEncoderStep(ModularPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def description(self) -> str:
+ return "Vae Encoder step that encode the input image into a latent representation"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("vae", AutoencoderKL),
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 8}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("image", required=True),
+ InputParam("height"),
+ InputParam("width"),
+ InputParam("generator"),
+ InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
+ InputParam(
+ "preprocess_kwargs",
+ type_hint=Optional[dict],
+ description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "image_latents",
+ type_hint=torch.Tensor,
+ description="The latents representing the reference image for image-to-image/inpainting generation",
+ )
+ ]
+
+ # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
+ # YiYi TODO: update the _encode_vae_image so that we can use #Coped from
+ def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
+ latents_mean = latents_std = None
+ if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
+ latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
+ if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
+ latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
+
+ dtype = image.dtype
+ if components.vae.config.force_upcast:
+ image = image.float()
+ components.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
+
+ if components.vae.config.force_upcast:
+ components.vae.to(dtype)
+
+ image_latents = image_latents.to(dtype)
+ if latents_mean is not None and latents_std is not None:
+ latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
+ latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
+ image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
+ else:
+ image_latents = components.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
+ block_state.device = components._execution_device
+ block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
+
+ image = components.image_processor.preprocess(
+ block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
+ )
+ image = image.to(device=block_state.device, dtype=block_state.dtype)
+ block_state.batch_size = image.shape[0]
+
+ # if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
+ if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
+ f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator)
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class StableDiffusionXLInpaintVaeEncoderStep(ModularPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("vae", AutoencoderKL),
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 8}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec(
+ "mask_processor",
+ VaeImageProcessor,
+ config=FrozenDict(
+ {"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}
+ ),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "Vae encoder step that prepares the image and mask for the inpainting process"
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("height"),
+ InputParam("width"),
+ InputParam("image", required=True),
+ InputParam("mask_image", required=True),
+ InputParam("padding_mask_crop"),
+ InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
+ InputParam("generator"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"
+ ),
+ OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"),
+ OutputParam(
+ "masked_image_latents",
+ type_hint=torch.Tensor,
+ description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)",
+ ),
+ OutputParam(
+ "crops_coords",
+ type_hint=Optional[Tuple[int, int]],
+ description="The crop coordinates to use for the preprocess/postprocess of the image and mask",
+ ),
+ ]
+
+ # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
+ # YiYi TODO: update the _encode_vae_image so that we can use #Coped from
+ def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
+ latents_mean = latents_std = None
+ if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
+ latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
+ if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
+ latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
+
+ dtype = image.dtype
+ if components.vae.config.force_upcast:
+ image = image.float()
+ components.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
+
+ if components.vae.config.force_upcast:
+ components.vae.to(dtype)
+
+ image_latents = image_latents.to(dtype)
+ if latents_mean is not None and latents_std is not None:
+ latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
+ latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
+ image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
+ else:
+ image_latents = components.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
+ # do not accept do_classifier_free_guidance
+ def prepare_mask_latents(
+ self, components, mask, masked_image, batch_size, height, width, dtype, device, generator
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(
+ mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor)
+ )
+ mask = mask.to(device=device, dtype=dtype)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
+ if masked_image is not None and masked_image.shape[1] == 4:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = None
+
+ if masked_image is not None:
+ if masked_image_latents is None:
+ masked_image = masked_image.to(device=device, dtype=dtype)
+ masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)
+
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(
+ batch_size // masked_image_latents.shape[0], 1, 1, 1
+ )
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ return mask, masked_image_latents
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
+ block_state.device = components._execution_device
+
+ if block_state.height is None:
+ block_state.height = components.default_height
+ if block_state.width is None:
+ block_state.width = components.default_width
+
+ if block_state.padding_mask_crop is not None:
+ block_state.crops_coords = components.mask_processor.get_crop_region(
+ block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop
+ )
+ block_state.resize_mode = "fill"
+ else:
+ block_state.crops_coords = None
+ block_state.resize_mode = "default"
+
+ image = components.image_processor.preprocess(
+ block_state.image,
+ height=block_state.height,
+ width=block_state.width,
+ crops_coords=block_state.crops_coords,
+ resize_mode=block_state.resize_mode,
+ )
+ image = image.to(dtype=torch.float32)
+
+ mask = components.mask_processor.preprocess(
+ block_state.mask_image,
+ height=block_state.height,
+ width=block_state.width,
+ resize_mode=block_state.resize_mode,
+ crops_coords=block_state.crops_coords,
+ )
+ block_state.masked_image = image * (mask < 0.5)
+
+ block_state.batch_size = image.shape[0]
+ image = image.to(device=block_state.device, dtype=block_state.dtype)
+ block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator)
+
+ # 7. Prepare mask latent variables
+ block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
+ components,
+ mask,
+ block_state.masked_image,
+ block_state.batch_size,
+ block_state.height,
+ block_state.width,
+ block_state.dtype,
+ block_state.device,
+ block_state.generator,
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py
new file mode 100644
index 000000000000..68b5e33755b5
--- /dev/null
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py
@@ -0,0 +1,395 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...utils import logging
+from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
+from ..modular_pipeline_utils import InsertableDict
+from .before_denoise import (
+ StableDiffusionXLControlNetInputStep,
+ StableDiffusionXLControlNetUnionInputStep,
+ StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
+ StableDiffusionXLImg2ImgPrepareLatentsStep,
+ StableDiffusionXLImg2ImgSetTimestepsStep,
+ StableDiffusionXLInpaintPrepareLatentsStep,
+ StableDiffusionXLInputStep,
+ StableDiffusionXLPrepareAdditionalConditioningStep,
+ StableDiffusionXLPrepareLatentsStep,
+ StableDiffusionXLSetTimestepsStep,
+)
+from .decoders import (
+ StableDiffusionXLDecodeStep,
+ StableDiffusionXLInpaintOverlayMaskStep,
+)
+from .denoise import (
+ StableDiffusionXLControlNetDenoiseStep,
+ StableDiffusionXLDenoiseStep,
+ StableDiffusionXLInpaintControlNetDenoiseStep,
+ StableDiffusionXLInpaintDenoiseStep,
+)
+from .encoders import (
+ StableDiffusionXLInpaintVaeEncoderStep,
+ StableDiffusionXLIPAdapterStep,
+ StableDiffusionXLTextEncoderStep,
+ StableDiffusionXLVaeEncoderStep,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# auto blocks & sequential blocks & mappings
+
+
+# vae encoder (run before before_denoise)
+class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks):
+ block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep]
+ block_names = ["inpaint", "img2img"]
+ block_trigger_inputs = ["mask_image", "image"]
+
+ @property
+ def description(self):
+ return (
+ "Vae encoder step that encode the image inputs into their latent representations.\n"
+ + "This is an auto pipeline block that works for both inpainting and img2img tasks.\n"
+ + " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n"
+ + " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided."
+ + " - if neither `mask_image` nor `image` is provided, step will be skipped."
+ )
+
+
+# optional ip-adapter (run before input step)
+class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks):
+ block_classes = [StableDiffusionXLIPAdapterStep]
+ block_names = ["ip_adapter"]
+ block_trigger_inputs = ["ip_adapter_image"]
+
+ @property
+ def description(self):
+ return "Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n"
+
+
+# before_denoise: text2img
+class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks):
+ block_classes = [
+ StableDiffusionXLSetTimestepsStep,
+ StableDiffusionXLPrepareLatentsStep,
+ StableDiffusionXLPrepareAdditionalConditioningStep,
+ ]
+ block_names = ["set_timesteps", "prepare_latents", "prepare_add_cond"]
+
+ @property
+ def description(self):
+ return (
+ "Before denoise step that prepare the inputs for the denoise step.\n"
+ + "This is a sequential pipeline blocks:\n"
+ + " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n"
+ + " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n"
+ + " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
+ )
+
+
+# before_denoise: img2img
+class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
+ block_classes = [
+ StableDiffusionXLImg2ImgSetTimestepsStep,
+ StableDiffusionXLImg2ImgPrepareLatentsStep,
+ StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
+ ]
+ block_names = ["set_timesteps", "prepare_latents", "prepare_add_cond"]
+
+ @property
+ def description(self):
+ return (
+ "Before denoise step that prepare the inputs for the denoise step for img2img task.\n"
+ + "This is a sequential pipeline blocks:\n"
+ + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n"
+ + " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n"
+ + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
+ )
+
+
+# before_denoise: inpainting
+class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
+ block_classes = [
+ StableDiffusionXLImg2ImgSetTimestepsStep,
+ StableDiffusionXLInpaintPrepareLatentsStep,
+ StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
+ ]
+ block_names = ["set_timesteps", "prepare_latents", "prepare_add_cond"]
+
+ @property
+ def description(self):
+ return (
+ "Before denoise step that prepare the inputs for the denoise step for inpainting task.\n"
+ + "This is a sequential pipeline blocks:\n"
+ + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n"
+ + " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n"
+ + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
+ )
+
+
+# before_denoise: all task (text2img, img2img, inpainting)
+class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks):
+ block_classes = [
+ StableDiffusionXLInpaintBeforeDenoiseStep,
+ StableDiffusionXLImg2ImgBeforeDenoiseStep,
+ StableDiffusionXLBeforeDenoiseStep,
+ ]
+ block_names = ["inpaint", "img2img", "text2img"]
+ block_trigger_inputs = ["mask", "image_latents", None]
+
+ @property
+ def description(self):
+ return (
+ "Before denoise step that prepare the inputs for the denoise step.\n"
+ + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks as well as controlnet, controlnet_union.\n"
+ + " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n"
+ + " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
+ + " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided.\n"
+ )
+
+
+# optional controlnet input step (after before_denoise, before denoise)
+# works for both controlnet and controlnet_union
+class StableDiffusionXLAutoControlNetInputStep(AutoPipelineBlocks):
+ block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep]
+ block_names = ["controlnet_union", "controlnet"]
+ block_trigger_inputs = ["control_mode", "control_image"]
+
+ @property
+ def description(self):
+ return (
+ "Controlnet Input step that prepare the controlnet input.\n"
+ + "This is an auto pipeline block that works for both controlnet and controlnet_union.\n"
+ + " (it should be called right before the denoise step)"
+ + " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n"
+ + " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided."
+ + " - if neither `control_mode` nor `control_image` is provided, step will be skipped."
+ )
+
+
+# denoise: controlnet (text2img, img2img, inpainting)
+class StableDiffusionXLAutoControlNetDenoiseStep(AutoPipelineBlocks):
+ block_classes = [StableDiffusionXLInpaintControlNetDenoiseStep, StableDiffusionXLControlNetDenoiseStep]
+ block_names = ["inpaint_controlnet_denoise", "controlnet_denoise"]
+ block_trigger_inputs = ["mask", "controlnet_cond"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents with controlnet. "
+ "This is a auto pipeline block that using controlnet for text2img, img2img and inpainting tasks."
+ "This block should not be used without a controlnet_cond input"
+ " - `StableDiffusionXLInpaintControlNetDenoiseStep` (inpaint_controlnet_denoise) is used when mask is provided."
+ " - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when mask is not provided but controlnet_cond is provided."
+ " - If neither mask nor controlnet_cond are provided, step will be skipped."
+ )
+
+
+# denoise: all task with or without controlnet (text2img, img2img, inpainting)
+class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
+ block_classes = [
+ StableDiffusionXLAutoControlNetDenoiseStep,
+ StableDiffusionXLInpaintDenoiseStep,
+ StableDiffusionXLDenoiseStep,
+ ]
+ block_names = ["controlnet_denoise", "inpaint_denoise", "denoise"]
+ block_trigger_inputs = ["controlnet_cond", "mask", None]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. "
+ "This is a auto pipeline block that works for text2img, img2img and inpainting tasks. And can be used with or without controlnet."
+ " - `StableDiffusionXLAutoControlNetDenoiseStep` (controlnet_denoise) is used when controlnet_cond is provided (support controlnet withtext2img, img2img and inpainting tasks)."
+ " - `StableDiffusionXLInpaintDenoiseStep` (inpaint_denoise) is used when mask is provided (support inpainting tasks)."
+ " - `StableDiffusionXLDenoiseStep` (denoise) is used when neither mask nor controlnet_cond are provided (support text2img and img2img tasks)."
+ )
+
+
+# decode: inpaint
+class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks):
+ block_classes = [StableDiffusionXLDecodeStep, StableDiffusionXLInpaintOverlayMaskStep]
+ block_names = ["decode", "mask_overlay"]
+
+ @property
+ def description(self):
+ return (
+ "Inpaint decode step that decode the denoised latents into images outputs.\n"
+ + "This is a sequential pipeline blocks:\n"
+ + " - `StableDiffusionXLDecodeStep` is used to decode the denoised latents into images\n"
+ + " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image"
+ )
+
+
+# decode: all task (text2img, img2img, inpainting)
+class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks):
+ block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep]
+ block_names = ["inpaint", "non-inpaint"]
+ block_trigger_inputs = ["padding_mask_crop", None]
+
+ @property
+ def description(self):
+ return (
+ "Decode step that decode the denoised latents into images outputs.\n"
+ + "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n"
+ + " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n"
+ + " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided."
+ )
+
+
+class StableDiffusionXLCoreDenoiseStep(SequentialPipelineBlocks):
+ block_classes = [
+ StableDiffusionXLInputStep,
+ StableDiffusionXLAutoBeforeDenoiseStep,
+ StableDiffusionXLAutoControlNetInputStep,
+ StableDiffusionXLAutoDenoiseStep,
+ ]
+ block_names = ["input", "before_denoise", "controlnet_input", "denoise"]
+
+ @property
+ def description(self):
+ return (
+ "Core step that performs the denoising process. \n"
+ + " - `StableDiffusionXLInputStep` (input) standardizes the inputs for the denoising step.\n"
+ + " - `StableDiffusionXLAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ + " - `StableDiffusionXLAutoControlNetInputStep` (controlnet_input) prepares the controlnet input.\n"
+ + " - `StableDiffusionXLAutoDenoiseStep` (denoise) iteratively denoises the latents.\n\n"
+ + "This step support text-to-image, image-to-image, inpainting, with or without controlnet/controlnet_union/ip_adapter for Stable Diffusion XL:\n"
+ + "- for image-to-image generation, you need to provide `image_latents`\n"
+ + "- for inpainting, you need to provide `mask_image` and `image_latents`\n"
+ + "- to run the controlnet workflow, you need to provide `control_image`\n"
+ + "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n"
+ + "- to run the ip_adapter workflow, you need to load ip_adapter into your unet and provide `ip_adapter_embeds`\n"
+ + "- for text-to-image generation, all you need to provide is prompt embeddings\n"
+ )
+
+
+# ip-adapter, controlnet, text2img, img2img, inpainting
+class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks):
+ block_classes = [
+ StableDiffusionXLTextEncoderStep,
+ StableDiffusionXLAutoIPAdapterStep,
+ StableDiffusionXLAutoVaeEncoderStep,
+ StableDiffusionXLCoreDenoiseStep,
+ StableDiffusionXLAutoDecodeStep,
+ ]
+ block_names = [
+ "text_encoder",
+ "ip_adapter",
+ "vae_encoder",
+ "denoise",
+ "decode",
+ ]
+
+ @property
+ def description(self):
+ return (
+ "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n"
+ + "- for image-to-image generation, you need to provide either `image` or `image_latents`\n"
+ + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
+ + "- to run the controlnet workflow, you need to provide `control_image`\n"
+ + "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n"
+ + "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n"
+ + "- for text-to-image generation, all you need to provide is `prompt`"
+ )
+
+
+# controlnet (input + denoise step)
+class StableDiffusionXLAutoControlnetStep(SequentialPipelineBlocks):
+ block_classes = [
+ StableDiffusionXLAutoControlNetInputStep,
+ StableDiffusionXLAutoControlNetDenoiseStep,
+ ]
+ block_names = ["controlnet_input", "controlnet_denoise"]
+
+ @property
+ def description(self):
+ return (
+ "Controlnet auto step that prepare the controlnet input and denoise the latents. "
+ + "It works for both controlnet and controlnet_union and supports text2img, img2img and inpainting tasks."
+ + " (it should be replace at 'denoise' step)"
+ )
+
+
+TEXT2IMAGE_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", StableDiffusionXLTextEncoderStep),
+ ("input", StableDiffusionXLInputStep),
+ ("set_timesteps", StableDiffusionXLSetTimestepsStep),
+ ("prepare_latents", StableDiffusionXLPrepareLatentsStep),
+ ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep),
+ ("denoise", StableDiffusionXLDenoiseStep),
+ ("decode", StableDiffusionXLDecodeStep),
+ ]
+)
+
+IMAGE2IMAGE_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", StableDiffusionXLTextEncoderStep),
+ ("vae_encoder", StableDiffusionXLVaeEncoderStep),
+ ("input", StableDiffusionXLInputStep),
+ ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
+ ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
+ ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
+ ("denoise", StableDiffusionXLDenoiseStep),
+ ("decode", StableDiffusionXLDecodeStep),
+ ]
+)
+
+INPAINT_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", StableDiffusionXLTextEncoderStep),
+ ("vae_encoder", StableDiffusionXLInpaintVaeEncoderStep),
+ ("input", StableDiffusionXLInputStep),
+ ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
+ ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep),
+ ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
+ ("denoise", StableDiffusionXLInpaintDenoiseStep),
+ ("decode", StableDiffusionXLInpaintDecodeStep),
+ ]
+)
+
+CONTROLNET_BLOCKS = InsertableDict(
+ [
+ ("denoise", StableDiffusionXLAutoControlnetStep),
+ ]
+)
+
+
+IP_ADAPTER_BLOCKS = InsertableDict(
+ [
+ ("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
+ ]
+)
+
+AUTO_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", StableDiffusionXLTextEncoderStep),
+ ("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
+ ("vae_encoder", StableDiffusionXLAutoVaeEncoderStep),
+ ("denoise", StableDiffusionXLCoreDenoiseStep),
+ ("decode", StableDiffusionXLAutoDecodeStep),
+ ]
+)
+
+
+ALL_BLOCKS = {
+ "text2img": TEXT2IMAGE_BLOCKS,
+ "img2img": IMAGE2IMAGE_BLOCKS,
+ "inpaint": INPAINT_BLOCKS,
+ "controlnet": CONTROLNET_BLOCKS,
+ "ip_adapter": IP_ADAPTER_BLOCKS,
+ "auto": AUTO_BLOCKS,
+}
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py
new file mode 100644
index 000000000000..f2a4c96073ea
--- /dev/null
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py
@@ -0,0 +1,364 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+
+from ...image_processor import PipelineImageInput
+from ...loaders import ModularIPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
+from ...pipelines.pipeline_utils import StableDiffusionMixin
+from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from ...utils import logging
+from ..modular_pipeline import ModularPipeline
+from ..modular_pipeline_utils import InputParam, OutputParam
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder?
+# YiYi Notes: model specific components:
+## (1) it should inherit from ModularPipeline
+## (2) acts like a container that holds components and configs
+## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents
+## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin)
+## (5) how to use together with Components_manager?
+class StableDiffusionXLModularPipeline(
+ ModularPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ ModularIPAdapterMixin,
+):
+ """
+ A ModularPipeline for Stable Diffusion XL.
+
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
+ """
+
+ default_blocks_name = "StableDiffusionXLAutoBlocks"
+
+ @property
+ def default_height(self):
+ return self.default_sample_size * self.vae_scale_factor
+
+ @property
+ def default_width(self):
+ return self.default_sample_size * self.vae_scale_factor
+
+ @property
+ def default_sample_size(self):
+ default_sample_size = 128
+ if hasattr(self, "unet") and self.unet is not None:
+ default_sample_size = self.unet.config.sample_size
+ return default_sample_size
+
+ @property
+ def vae_scale_factor(self):
+ vae_scale_factor = 8
+ if hasattr(self, "vae") and self.vae is not None:
+ vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ return vae_scale_factor
+
+ # YiYi TODO: change to num_channels_latents
+ @property
+ def num_channels_unet(self):
+ num_channels_unet = 4
+ if hasattr(self, "unet") and self.unet is not None:
+ num_channels_unet = self.unet.config.in_channels
+ return num_channels_unet
+
+ @property
+ def num_channels_latents(self):
+ num_channels_latents = 4
+ if hasattr(self, "vae") and self.vae is not None:
+ num_channels_latents = self.vae.config.latent_channels
+ return num_channels_latents
+
+
+# YiYi/Sayak TODO: not used yet, maintain a list of schema that can be used across all pipeline blocks
+# auto_docstring
+SDXL_INPUTS_SCHEMA = {
+ "prompt": InputParam(
+ "prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"
+ ),
+ "prompt_2": InputParam(
+ "prompt_2",
+ type_hint=Union[str, List[str]],
+ description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2",
+ ),
+ "negative_prompt": InputParam(
+ "negative_prompt",
+ type_hint=Union[str, List[str]],
+ description="The prompt or prompts not to guide the image generation",
+ ),
+ "negative_prompt_2": InputParam(
+ "negative_prompt_2",
+ type_hint=Union[str, List[str]],
+ description="The negative prompt or prompts for text_encoder_2",
+ ),
+ "cross_attention_kwargs": InputParam(
+ "cross_attention_kwargs",
+ type_hint=Optional[dict],
+ description="Kwargs dictionary passed to the AttentionProcessor",
+ ),
+ "clip_skip": InputParam(
+ "clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"
+ ),
+ "image": InputParam(
+ "image",
+ type_hint=PipelineImageInput,
+ required=True,
+ description="The image(s) to modify for img2img or inpainting",
+ ),
+ "mask_image": InputParam(
+ "mask_image",
+ type_hint=PipelineImageInput,
+ required=True,
+ description="Mask image for inpainting, white pixels will be repainted",
+ ),
+ "generator": InputParam(
+ "generator",
+ type_hint=Optional[Union[torch.Generator, List[torch.Generator]]],
+ description="Generator(s) for deterministic generation",
+ ),
+ "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
+ "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
+ "num_images_per_prompt": InputParam(
+ "num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"
+ ),
+ "num_inference_steps": InputParam(
+ "num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"
+ ),
+ "timesteps": InputParam(
+ "timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"
+ ),
+ "sigmas": InputParam(
+ "sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"
+ ),
+ "denoising_end": InputParam(
+ "denoising_end",
+ type_hint=Optional[float],
+ description="Fraction of denoising process to complete before termination",
+ ),
+ # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
+ "strength": InputParam(
+ "strength", type_hint=float, default=0.3, description="How much to transform the reference image"
+ ),
+ "denoising_start": InputParam(
+ "denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"
+ ),
+ "latents": InputParam(
+ "latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"
+ ),
+ "padding_mask_crop": InputParam(
+ "padding_mask_crop",
+ type_hint=Optional[Tuple[int, int]],
+ description="Size of margin in crop for image and mask",
+ ),
+ "original_size": InputParam(
+ "original_size",
+ type_hint=Optional[Tuple[int, int]],
+ description="Original size of the image for SDXL's micro-conditioning",
+ ),
+ "target_size": InputParam(
+ "target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"
+ ),
+ "negative_original_size": InputParam(
+ "negative_original_size",
+ type_hint=Optional[Tuple[int, int]],
+ description="Negative conditioning based on image resolution",
+ ),
+ "negative_target_size": InputParam(
+ "negative_target_size",
+ type_hint=Optional[Tuple[int, int]],
+ description="Negative conditioning based on target resolution",
+ ),
+ "crops_coords_top_left": InputParam(
+ "crops_coords_top_left",
+ type_hint=Tuple[int, int],
+ default=(0, 0),
+ description="Top-left coordinates for SDXL's micro-conditioning",
+ ),
+ "negative_crops_coords_top_left": InputParam(
+ "negative_crops_coords_top_left",
+ type_hint=Tuple[int, int],
+ default=(0, 0),
+ description="Negative conditioning crop coordinates",
+ ),
+ "aesthetic_score": InputParam(
+ "aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"
+ ),
+ "negative_aesthetic_score": InputParam(
+ "negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"
+ ),
+ "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
+ "output_type": InputParam(
+ "output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"
+ ),
+ "ip_adapter_image": InputParam(
+ "ip_adapter_image",
+ type_hint=PipelineImageInput,
+ required=True,
+ description="Image(s) to be used as IP adapter",
+ ),
+ "control_image": InputParam(
+ "control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"
+ ),
+ "control_guidance_start": InputParam(
+ "control_guidance_start",
+ type_hint=Union[float, List[float]],
+ default=0.0,
+ description="When ControlNet starts applying",
+ ),
+ "control_guidance_end": InputParam(
+ "control_guidance_end",
+ type_hint=Union[float, List[float]],
+ default=1.0,
+ description="When ControlNet stops applying",
+ ),
+ "controlnet_conditioning_scale": InputParam(
+ "controlnet_conditioning_scale",
+ type_hint=Union[float, List[float]],
+ default=1.0,
+ description="Scale factor for ControlNet outputs",
+ ),
+ "guess_mode": InputParam(
+ "guess_mode",
+ type_hint=bool,
+ default=False,
+ description="Enables ControlNet encoder to recognize input without prompts",
+ ),
+ "control_mode": InputParam(
+ "control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
+ ),
+ "prompt_embeds": InputParam(
+ "prompt_embeds",
+ type_hint=torch.Tensor,
+ required=True,
+ description="Text embeddings used to guide image generation",
+ ),
+ "negative_prompt_embeds": InputParam(
+ "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
+ ),
+ "pooled_prompt_embeds": InputParam(
+ "pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"
+ ),
+ "negative_pooled_prompt_embeds": InputParam(
+ "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
+ ),
+ "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
+ "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
+ "preprocess_kwargs": InputParam(
+ "preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
+ ),
+ "latent_timestep": InputParam(
+ "latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
+ ),
+ "image_latents": InputParam(
+ "image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"
+ ),
+ "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
+ "masked_image_latents": InputParam(
+ "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
+ ),
+ "add_time_ids": InputParam(
+ "add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"
+ ),
+ "negative_add_time_ids": InputParam(
+ "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
+ ),
+ "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
+ "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
+ "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
+ "ip_adapter_embeds": InputParam(
+ "ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
+ ),
+ "negative_ip_adapter_embeds": InputParam(
+ "negative_ip_adapter_embeds",
+ type_hint=List[torch.Tensor],
+ description="Negative image embeddings for IP-Adapter",
+ ),
+ "images": InputParam(
+ "images",
+ type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
+ required=True,
+ description="Generated images",
+ ),
+}
+
+
+SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = {
+ "prompt_embeds": OutputParam(
+ "prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"
+ ),
+ "negative_prompt_embeds": OutputParam(
+ "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
+ ),
+ "pooled_prompt_embeds": OutputParam(
+ "pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"
+ ),
+ "negative_pooled_prompt_embeds": OutputParam(
+ "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
+ ),
+ "batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"),
+ "dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
+ "image_latents": OutputParam(
+ "image_latents", type_hint=torch.Tensor, description="Latents representing reference image"
+ ),
+ "mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"),
+ "masked_image_latents": OutputParam(
+ "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
+ ),
+ "crops_coords": OutputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
+ "timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"),
+ "num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"),
+ "latent_timestep": OutputParam(
+ "latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"
+ ),
+ "add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"),
+ "negative_add_time_ids": OutputParam(
+ "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
+ ),
+ "timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
+ "latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"),
+ "noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
+ "ip_adapter_embeds": OutputParam(
+ "ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
+ ),
+ "negative_ip_adapter_embeds": OutputParam(
+ "negative_ip_adapter_embeds",
+ type_hint=List[torch.Tensor],
+ description="Negative image embeddings for IP-Adapter",
+ ),
+ "images": OutputParam(
+ "images",
+ type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
+ description="Generated images",
+ ),
+}
+
+
+SDXL_OUTPUTS_SCHEMA = {
+ "images": OutputParam(
+ "images",
+ type_hint=Union[
+ Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput
+ ],
+ description="The final generated images",
+ )
+}
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/node_utils.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/node_utils.py
new file mode 100644
index 000000000000..3e788bf94741
--- /dev/null
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/node_utils.py
@@ -0,0 +1,99 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+SDXL_NODE_TYPES_PARAMS_MAP = {
+ "controlnet": {
+ "inputs": [
+ "control_image",
+ "controlnet_conditioning_scale",
+ "control_guidance_start",
+ "control_guidance_end",
+ "height",
+ "width",
+ ],
+ "model_inputs": [
+ "controlnet",
+ ],
+ "outputs": [
+ "controlnet_out",
+ ],
+ "block_names": [None],
+ },
+ "denoise": {
+ "inputs": [
+ "embeddings",
+ "width",
+ "height",
+ "seed",
+ "num_inference_steps",
+ "guidance_scale",
+ "image_latents",
+ "strength",
+ # custom adapters coming in as inputs
+ "controlnet",
+ # ip_adapter is optional and custom; include if available
+ "ip_adapter",
+ ],
+ "model_inputs": [
+ "unet",
+ "guider",
+ "scheduler",
+ ],
+ "outputs": [
+ "latents",
+ "latents_preview",
+ ],
+ "block_names": ["denoise"],
+ },
+ "vae_encoder": {
+ "inputs": [
+ "image",
+ "width",
+ "height",
+ ],
+ "model_inputs": [
+ "vae",
+ ],
+ "outputs": [
+ "image_latents",
+ ],
+ "block_names": ["vae_encoder"],
+ },
+ "text_encoder": {
+ "inputs": [
+ "prompt",
+ "negative_prompt",
+ ],
+ "model_inputs": [
+ "text_encoders",
+ ],
+ "outputs": [
+ "embeddings",
+ ],
+ "block_names": ["text_encoder"],
+ },
+ "decoder": {
+ "inputs": [
+ "latents",
+ ],
+ "model_inputs": [
+ "vae",
+ ],
+ "outputs": [
+ "images",
+ ],
+ "block_names": ["decode"],
+ },
+}
diff --git a/src/diffusers/modular_pipelines/wan/__init__.py b/src/diffusers/modular_pipelines/wan/__init__.py
new file mode 100644
index 000000000000..73f67c9afed2
--- /dev/null
+++ b/src/diffusers/modular_pipelines/wan/__init__.py
@@ -0,0 +1,63 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["decoders"] = ["WanImageVaeDecoderStep"]
+ _import_structure["encoders"] = ["WanTextEncoderStep"]
+ _import_structure["modular_blocks"] = [
+ "ALL_BLOCKS",
+ "Wan22AutoBlocks",
+ "WanAutoBlocks",
+ "WanAutoImageEncoderStep",
+ "WanAutoVaeImageEncoderStep",
+ ]
+ _import_structure["modular_pipeline"] = ["WanModularPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ else:
+ from .decoders import WanImageVaeDecoderStep
+ from .encoders import WanTextEncoderStep
+ from .modular_blocks import (
+ ALL_BLOCKS,
+ Wan22AutoBlocks,
+ WanAutoBlocks,
+ WanAutoImageEncoderStep,
+ WanAutoVaeImageEncoderStep,
+ )
+ from .modular_pipeline import WanModularPipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/modular_pipelines/wan/before_denoise.py b/src/diffusers/modular_pipelines/wan/before_denoise.py
new file mode 100644
index 000000000000..e2f8d3e7d88b
--- /dev/null
+++ b/src/diffusers/modular_pipelines/wan/before_denoise.py
@@ -0,0 +1,637 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from ...models import WanTransformer3DModel
+from ...schedulers import UniPCMultistepScheduler
+from ...utils import logging
+from ...utils.torch_utils import randn_tensor
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import WanModularPipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that
+# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by
+# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the
+# configuration of guider is.
+
+
+def repeat_tensor_to_batch_size(
+ input_name: str,
+ input_tensor: torch.Tensor,
+ batch_size: int,
+ num_videos_per_prompt: int = 1,
+) -> torch.Tensor:
+ """Repeat tensor elements to match the final batch size.
+
+ This function expands a tensor's batch dimension to match the final batch size (batch_size * num_videos_per_prompt)
+ by repeating each element along dimension 0.
+
+ The input tensor must have batch size 1 or batch_size. The function will:
+ - If batch size is 1: repeat each element (batch_size * num_videos_per_prompt) times
+ - If batch size equals batch_size: repeat each element num_videos_per_prompt times
+
+ Args:
+ input_name (str): Name of the input tensor (used for error messages)
+ input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size.
+ batch_size (int): The base batch size (number of prompts)
+ num_videos_per_prompt (int, optional): Number of videos to generate per prompt. Defaults to 1.
+
+ Returns:
+ torch.Tensor: The repeated tensor with final batch size (batch_size * num_videos_per_prompt)
+
+ Raises:
+ ValueError: If input_tensor is not a torch.Tensor or has invalid batch size
+
+ Examples:
+ tensor = torch.tensor([[1, 2, 3]]) # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor,
+ batch_size=2, num_videos_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape:
+ [4, 3]
+
+ tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image",
+ tensor, batch_size=2, num_videos_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]])
+ - shape: [4, 3]
+ """
+ # make sure input is a tensor
+ if not isinstance(input_tensor, torch.Tensor):
+ raise ValueError(f"`{input_name}` must be a tensor")
+
+ # make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts
+ if input_tensor.shape[0] == 1:
+ repeat_by = batch_size * num_videos_per_prompt
+ elif input_tensor.shape[0] == batch_size:
+ repeat_by = num_videos_per_prompt
+ else:
+ raise ValueError(
+ f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}"
+ )
+
+ # expand the tensor to match the batch_size * num_videos_per_prompt
+ input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0)
+
+ return input_tensor
+
+
+def calculate_dimension_from_latents(
+ latents: torch.Tensor, vae_scale_factor_temporal: int, vae_scale_factor_spatial: int
+) -> Tuple[int, int]:
+ """Calculate image dimensions from latent tensor dimensions.
+
+ This function converts latent temporal and spatial dimensions to image temporal and spatial dimensions by
+ multiplying the latent num_frames/height/width by the VAE scale factor.
+
+ Args:
+ latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions.
+ Expected shapes: [batch, channels, height, width] or [batch, channels, frames, height, width]
+ vae_scale_factor_temporal (int): The scale factor used by the VAE to compress temporal dimension.
+ Typically 4 for most VAEs (video is 4x larger than latents in temporal dimension)
+ vae_scale_factor_spatial (int): The scale factor used by the VAE to compress spatial dimension.
+ Typically 8 for most VAEs (image is 8x larger than latents in each dimension)
+
+ Returns:
+ Tuple[int, int]: The calculated image dimensions as (height, width)
+
+ Raises:
+ ValueError: If latents tensor doesn't have 4 or 5 dimensions
+
+ """
+ if latents.ndim != 5:
+ raise ValueError(f"latents must have 5 dimensions, but got {latents.ndim}")
+
+ _, _, num_latent_frames, latent_height, latent_width = latents.shape
+
+ num_frames = (num_latent_frames - 1) * vae_scale_factor_temporal + 1
+ height = latent_height * vae_scale_factor_spatial
+ width = latent_width * vae_scale_factor_spatial
+
+ return num_frames, height, width
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class WanTextInputStep(ModularPipelineBlocks):
+ model_name = "wan"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Input processing step that:\n"
+ " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
+ " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_videos_per_prompt`\n\n"
+ "All input tensors are expected to have either batch_size=1 or match the batch_size\n"
+ "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n"
+ "have a final batch_size of batch_size * num_videos_per_prompt."
+ )
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("transformer", WanTransformer3DModel),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("num_videos_per_prompt", default=1),
+ InputParam(
+ "prompt_embeds",
+ required=True,
+ type_hint=torch.Tensor,
+ description="Pre-generated text embeddings. Can be generated from text_encoder step.",
+ ),
+ InputParam(
+ "negative_prompt_embeds",
+ type_hint=torch.Tensor,
+ description="Pre-generated negative text embeddings. Can be generated from text_encoder step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ return [
+ OutputParam(
+ "batch_size",
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_videos_per_prompt",
+ ),
+ OutputParam(
+ "dtype",
+ type_hint=torch.dtype,
+ description="Data type of model tensor inputs (determined by `transformer.dtype`)",
+ ),
+ ]
+
+ def check_inputs(self, components, block_state):
+ if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None:
+ if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {block_state.negative_prompt_embeds.shape}."
+ )
+
+ @torch.no_grad()
+ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ self.check_inputs(components, block_state)
+
+ block_state.batch_size = block_state.prompt_embeds.shape[0]
+ block_state.dtype = block_state.prompt_embeds.dtype
+
+ _, seq_len, _ = block_state.prompt_embeds.shape
+ block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_videos_per_prompt, 1)
+ block_state.prompt_embeds = block_state.prompt_embeds.view(
+ block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1
+ )
+
+ if block_state.negative_prompt_embeds is not None:
+ _, seq_len, _ = block_state.negative_prompt_embeds.shape
+ block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
+ 1, block_state.num_videos_per_prompt, 1
+ )
+ block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
+ block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class WanAdditionalInputsStep(ModularPipelineBlocks):
+ model_name = "wan"
+
+ def __init__(
+ self,
+ image_latent_inputs: List[str] = ["first_frame_latents"],
+ additional_batch_inputs: List[str] = [],
+ ):
+ """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
+
+ This step handles multiple common tasks to prepare inputs for the denoising step:
+ 1. For encoded image latents, use it update height/width if None, and expands batch size
+ 2. For additional_batch_inputs: Only expands batch dimensions to match final batch size
+
+ This is a dynamic block that allows you to configure which inputs to process.
+
+ Args:
+ image_latent_inputs (List[str], optional): Names of image latent tensors to process.
+ In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be
+ a single string or list of strings. Defaults to ["first_frame_latents"].
+ additional_batch_inputs (List[str], optional):
+ Names of additional conditional input tensors to expand batch size. These tensors will only have their
+ batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
+ Defaults to [].
+
+ Examples:
+ # Configure to process first_frame_latents (default behavior) WanAdditionalInputsStep()
+
+ # Configure to process multiple image latent inputs
+ WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents", "last_frame_latents"])
+
+ # Configure to process image latents and additional batch inputs WanAdditionalInputsStep(
+ image_latent_inputs=["first_frame_latents"], additional_batch_inputs=["image_embeds"]
+ )
+ """
+ if not isinstance(image_latent_inputs, list):
+ image_latent_inputs = [image_latent_inputs]
+ if not isinstance(additional_batch_inputs, list):
+ additional_batch_inputs = [additional_batch_inputs]
+
+ self._image_latent_inputs = image_latent_inputs
+ self._additional_batch_inputs = additional_batch_inputs
+ super().__init__()
+
+ @property
+ def description(self) -> str:
+ # Functionality section
+ summary_section = (
+ "Input processing step that:\n"
+ " 1. For image latent inputs: Updates height/width if None, and expands batch size\n"
+ " 2. For additional batch inputs: Expands batch dimensions to match final batch size"
+ )
+
+ # Inputs info
+ inputs_info = ""
+ if self._image_latent_inputs or self._additional_batch_inputs:
+ inputs_info = "\n\nConfigured inputs:"
+ if self._image_latent_inputs:
+ inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
+ if self._additional_batch_inputs:
+ inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
+
+ # Placement guidance
+ placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
+
+ return summary_section + inputs_info + placement_section
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ inputs = [
+ InputParam(name="num_videos_per_prompt", default=1),
+ InputParam(name="batch_size", required=True),
+ InputParam(name="height"),
+ InputParam(name="width"),
+ InputParam(name="num_frames"),
+ ]
+
+ # Add image latent inputs
+ for image_latent_input_name in self._image_latent_inputs:
+ inputs.append(InputParam(name=image_latent_input_name))
+
+ # Add additional batch inputs
+ for input_name in self._additional_batch_inputs:
+ inputs.append(InputParam(name=input_name))
+
+ return inputs
+
+ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ # Process image latent inputs (height/width calculation, patchify, and batch expansion)
+ for image_latent_input_name in self._image_latent_inputs:
+ image_latent_tensor = getattr(block_state, image_latent_input_name)
+ if image_latent_tensor is None:
+ continue
+
+ # 1. Calculate num_frames, height/width from latents
+ num_frames, height, width = calculate_dimension_from_latents(
+ image_latent_tensor, components.vae_scale_factor_temporal, components.vae_scale_factor_spatial
+ )
+ block_state.num_frames = block_state.num_frames or num_frames
+ block_state.height = block_state.height or height
+ block_state.width = block_state.width or width
+
+ # 3. Expand batch size
+ image_latent_tensor = repeat_tensor_to_batch_size(
+ input_name=image_latent_input_name,
+ input_tensor=image_latent_tensor,
+ num_videos_per_prompt=block_state.num_videos_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ setattr(block_state, image_latent_input_name, image_latent_tensor)
+
+ # Process additional batch inputs (only batch expansion)
+ for input_name in self._additional_batch_inputs:
+ input_tensor = getattr(block_state, input_name)
+ if input_tensor is None:
+ continue
+
+ # Only expand batch size
+ input_tensor = repeat_tensor_to_batch_size(
+ input_name=input_name,
+ input_tensor=input_tensor,
+ num_videos_per_prompt=block_state.num_videos_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ setattr(block_state, input_name, input_tensor)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class WanSetTimestepsStep(ModularPipelineBlocks):
+ model_name = "wan"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", UniPCMultistepScheduler),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "Step that sets the scheduler's timesteps for inference"
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("num_inference_steps", default=50),
+ InputParam("timesteps"),
+ InputParam("sigmas"),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ device = components._execution_device
+
+ block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
+ components.scheduler,
+ block_state.num_inference_steps,
+ device,
+ block_state.timesteps,
+ block_state.sigmas,
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class WanPrepareLatentsStep(ModularPipelineBlocks):
+ model_name = "wan"
+
+ @property
+ def description(self) -> str:
+ return "Prepare latents step that prepares the latents for the text-to-video generation process"
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("height", type_hint=int),
+ InputParam("width", type_hint=int),
+ InputParam("num_frames", type_hint=int),
+ InputParam("latents", type_hint=Optional[torch.Tensor]),
+ InputParam("num_videos_per_prompt", type_hint=int, default=1),
+ InputParam("generator"),
+ InputParam(
+ "batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be `batch_size * num_videos_per_prompt`. Can be generated in input step.",
+ ),
+ InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
+ )
+ ]
+
+ @staticmethod
+ def check_inputs(components, block_state):
+ if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
+ block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
+ ):
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
+ )
+ if block_state.num_frames is not None and (
+ block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0
+ ):
+ raise ValueError(
+ f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}."
+ )
+
+ @staticmethod
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents with self->comp
+ def prepare_latents(
+ comp,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ num_latent_frames = (num_frames - 1) // comp.vae_scale_factor_temporal + 1
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ int(height) // comp.vae_scale_factor_spatial,
+ int(width) // comp.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ @torch.no_grad()
+ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ self.check_inputs(components, block_state)
+
+ device = components._execution_device
+ dtype = torch.float32 # Wan latents should be torch.float32 for best quality
+
+ block_state.height = block_state.height or components.default_height
+ block_state.width = block_state.width or components.default_width
+ block_state.num_frames = block_state.num_frames or components.default_num_frames
+
+ block_state.latents = self.prepare_latents(
+ components,
+ batch_size=block_state.batch_size * block_state.num_videos_per_prompt,
+ num_channels_latents=components.num_channels_latents,
+ height=block_state.height,
+ width=block_state.width,
+ num_frames=block_state.num_frames,
+ dtype=dtype,
+ device=device,
+ generator=block_state.generator,
+ latents=block_state.latents,
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks):
+ model_name = "wan"
+
+ @property
+ def description(self) -> str:
+ return "step that prepares the masked first frame latents and add it to the latent condition"
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]),
+ InputParam("num_frames", type_hint=int),
+ ]
+
+ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape
+
+ mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
+ mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0
+
+ first_frame_mask = mask_lat_size[:, :, 0:1]
+ first_frame_mask = torch.repeat_interleave(
+ first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
+ )
+ mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
+ mask_lat_size = mask_lat_size.view(
+ batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
+ )
+ mask_lat_size = mask_lat_size.transpose(1, 2)
+ mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device)
+ block_state.first_frame_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks):
+ model_name = "wan"
+
+ @property
+ def description(self) -> str:
+ return "step that prepares the masked latents with first and last frames and add it to the latent condition"
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]),
+ InputParam("num_frames", type_hint=int),
+ ]
+
+ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape
+
+ mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
+ mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0
+
+ first_frame_mask = mask_lat_size[:, :, 0:1]
+ first_frame_mask = torch.repeat_interleave(
+ first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
+ )
+ mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
+ mask_lat_size = mask_lat_size.view(
+ batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
+ )
+ mask_lat_size = mask_lat_size.transpose(1, 2)
+ mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device)
+ block_state.first_last_frame_latents = torch.concat(
+ [mask_lat_size, block_state.first_last_frame_latents], dim=1
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
diff --git a/src/diffusers/modular_pipelines/wan/decoders.py b/src/diffusers/modular_pipelines/wan/decoders.py
new file mode 100644
index 000000000000..7cec318c1706
--- /dev/null
+++ b/src/diffusers/modular_pipelines/wan/decoders.py
@@ -0,0 +1,94 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, List, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+
+from ...configuration_utils import FrozenDict
+from ...models import AutoencoderKLWan
+from ...utils import logging
+from ...video_processor import VideoProcessor
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class WanImageVaeDecoderStep(ModularPipelineBlocks):
+ model_name = "wan"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("vae", AutoencoderKLWan),
+ ComponentSpec(
+ "video_processor",
+ VideoProcessor,
+ config=FrozenDict({"vae_scale_factor": 8}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "Step that decodes the denoised latents into images"
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The denoised latents from the denoising step",
+ )
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ return [
+ OutputParam(
+ "videos",
+ type_hint=Union[List[List[PIL.Image.Image]], List[torch.Tensor], List[np.ndarray]],
+ description="The generated videos, can be a PIL.Image.Image, torch.Tensor or a numpy array",
+ )
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ vae_dtype = components.vae.dtype
+
+ latents = block_state.latents
+ latents_mean = (
+ torch.tensor(components.vae.config.latents_mean)
+ .view(1, components.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
+ 1, components.vae.config.z_dim, 1, 1, 1
+ ).to(latents.device, latents.dtype)
+ latents = latents / latents_std + latents_mean
+ latents = latents.to(vae_dtype)
+ block_state.videos = components.vae.decode(latents, return_dict=False)[0]
+
+ block_state.videos = components.video_processor.postprocess_video(block_state.videos, output_type="np")
+
+ self.set_block_state(state, block_state)
+
+ return components, state
diff --git a/src/diffusers/modular_pipelines/wan/denoise.py b/src/diffusers/modular_pipelines/wan/denoise.py
new file mode 100644
index 000000000000..2da36f52da87
--- /dev/null
+++ b/src/diffusers/modular_pipelines/wan/denoise.py
@@ -0,0 +1,612 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List, Tuple
+
+import torch
+
+from ...configuration_utils import FrozenDict
+from ...guiders import ClassifierFreeGuidance
+from ...models import WanTransformer3DModel
+from ...schedulers import UniPCMultistepScheduler
+from ...utils import logging
+from ..modular_pipeline import (
+ BlockState,
+ LoopSequentialPipelineBlocks,
+ ModularPipelineBlocks,
+ PipelineState,
+)
+from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam
+from .modular_pipeline import WanModularPipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class WanLoopBeforeDenoiser(ModularPipelineBlocks):
+ model_name = "wan"
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that prepares the latent input for the denoiser. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `WanDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "dtype",
+ required=True,
+ type_hint=torch.dtype,
+ description="The dtype of the model inputs. Can be generated in input step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ block_state.latent_model_input = block_state.latents.to(block_state.dtype)
+ return components, block_state
+
+
+class WanImage2VideoLoopBeforeDenoiser(ModularPipelineBlocks):
+ model_name = "wan"
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that prepares the latent input for the denoiser. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `WanDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "first_frame_latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The first frame latents to use for the denoising process. Can be generated in prepare_first_frame_latents step.",
+ ),
+ InputParam(
+ "dtype",
+ required=True,
+ type_hint=torch.dtype,
+ description="The dtype of the model inputs. Can be generated in input step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_frame_latents], dim=1).to(
+ block_state.dtype
+ )
+ return components, block_state
+
+
+class WanFLF2VLoopBeforeDenoiser(ModularPipelineBlocks):
+ model_name = "wan"
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that prepares the latent input for the denoiser. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `WanDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "first_last_frame_latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The first and last frame latents to use for the denoising process. Can be generated in prepare_first_last_frame_latents step.",
+ ),
+ InputParam(
+ "dtype",
+ required=True,
+ type_hint=torch.dtype,
+ description="The dtype of the model inputs. Can be generated in input step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ block_state.latent_model_input = torch.cat(
+ [block_state.latents, block_state.first_last_frame_latents], dim=1
+ ).to(block_state.dtype)
+ return components, block_state
+
+
+class WanLoopDenoiser(ModularPipelineBlocks):
+ model_name = "wan"
+
+ def __init__(
+ self,
+ guider_input_fields: Dict[str, Any] = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")},
+ ):
+ """Initialize a denoiser block that calls the denoiser model. This block is used in Wan2.1.
+
+ Args:
+ guider_input_fields: A dictionary that maps each argument expected by the denoiser model
+ (for example, "encoder_hidden_states") to data stored on 'block_state'. The value can be either:
+
+ - A tuple of strings. For instance, {"encoder_hidden_states": ("prompt_embeds",
+ "negative_prompt_embeds")} tells the guider to read `block_state.prompt_embeds` and
+ `block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of
+ 'encoder_hidden_states'.
+ - A string. For example, {"encoder_hidden_image": "image_embeds"} makes the guider forward
+ `block_state.image_embeds` for both conditional and unconditional batches.
+ """
+ if not isinstance(guider_input_fields, dict):
+ raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}")
+ self._guider_input_fields = guider_input_fields
+ super().__init__()
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 5.0}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec("transformer", WanTransformer3DModel),
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Step within the denoising loop that denoise the latents with guidance. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `WanDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ inputs = [
+ InputParam("attention_kwargs"),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+ guider_input_names = []
+ for value in self._guider_input_fields.values():
+ if isinstance(value, tuple):
+ guider_input_names.extend(value)
+ else:
+ guider_input_names.append(value)
+
+ for name in guider_input_names:
+ inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor))
+ return inputs
+
+ @torch.no_grad()
+ def __call__(
+ self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
+ ) -> PipelineState:
+ components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
+
+ # The guider splits model inputs into separate batches for conditional/unconditional predictions.
+ # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
+ # you will get a guider_state with two batches:
+ # guider_state = [
+ # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
+ # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
+ # ]
+ # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
+ guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields)
+
+ # run the denoiser for each guidance batch
+ for guider_state_batch in guider_state:
+ components.guider.prepare_models(components.transformer)
+ cond_kwargs = guider_state_batch.as_dict()
+ cond_kwargs = {
+ k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v
+ for k, v in cond_kwargs.items()
+ if k in self._guider_input_fields.keys()
+ }
+
+ # Predict the noise residual
+ # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
+ guider_state_batch.noise_pred = components.transformer(
+ hidden_states=block_state.latent_model_input.to(block_state.dtype),
+ timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype),
+ attention_kwargs=block_state.attention_kwargs,
+ return_dict=False,
+ **cond_kwargs,
+ )[0]
+ components.guider.cleanup_models(components.transformer)
+
+ # Perform guidance
+ block_state.noise_pred = components.guider(guider_state)[0]
+
+ return components, block_state
+
+
+class Wan22LoopDenoiser(ModularPipelineBlocks):
+ model_name = "wan"
+
+ def __init__(
+ self,
+ guider_input_fields: Dict[str, Any] = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")},
+ ):
+ """Initialize a denoiser block that calls the denoiser model. This block is used in Wan2.2.
+
+ Args:
+ guider_input_fields: A dictionary that maps each argument expected by the denoiser model
+ (for example, "encoder_hidden_states") to data stored on `block_state`. The value can be either:
+
+ - A tuple of strings. For instance, `{"encoder_hidden_states": ("prompt_embeds",
+ "negative_prompt_embeds")}` tells the guider to read `block_state.prompt_embeds` and
+ `block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of
+ `encoder_hidden_states`.
+ - A string. For example, `{"encoder_hidden_image": "image_embeds"}` makes the guider forward
+ `block_state.image_embeds` for both conditional and unconditional batches.
+ """
+ if not isinstance(guider_input_fields, dict):
+ raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}")
+ self._guider_input_fields = guider_input_fields
+ super().__init__()
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 4.0}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec(
+ "guider_2",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 3.0}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec("transformer", WanTransformer3DModel),
+ ComponentSpec("transformer_2", WanTransformer3DModel),
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Step within the denoising loop that denoise the latents with guidance. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `WanDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def expected_configs(self) -> List[ConfigSpec]:
+ return [
+ ConfigSpec(
+ name="boundary_ratio",
+ default=0.875,
+ description="The boundary ratio to divide the denoising loop into high noise and low noise stages.",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ inputs = [
+ InputParam("attention_kwargs"),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+ guider_input_names = []
+ for value in self._guider_input_fields.values():
+ if isinstance(value, tuple):
+ guider_input_names.extend(value)
+ else:
+ guider_input_names.append(value)
+
+ for name in guider_input_names:
+ inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor))
+ return inputs
+
+ @torch.no_grad()
+ def __call__(
+ self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
+ ) -> PipelineState:
+ boundary_timestep = components.config.boundary_ratio * components.num_train_timesteps
+ if t >= boundary_timestep:
+ block_state.current_model = components.transformer
+ block_state.guider = components.guider
+ else:
+ block_state.current_model = components.transformer_2
+ block_state.guider = components.guider_2
+
+ block_state.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
+
+ # The guider splits model inputs into separate batches for conditional/unconditional predictions.
+ # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
+ # you will get a guider_state with two batches:
+ # guider_state = [
+ # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
+ # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
+ # ]
+ # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
+ guider_state = block_state.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields)
+
+ # run the denoiser for each guidance batch
+ for guider_state_batch in guider_state:
+ block_state.guider.prepare_models(block_state.current_model)
+ cond_kwargs = guider_state_batch.as_dict()
+ cond_kwargs = {
+ k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v
+ for k, v in cond_kwargs.items()
+ if k in self._guider_input_fields.keys()
+ }
+
+ # Predict the noise residual
+ # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
+ guider_state_batch.noise_pred = block_state.current_model(
+ hidden_states=block_state.latent_model_input.to(block_state.dtype),
+ timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype),
+ attention_kwargs=block_state.attention_kwargs,
+ return_dict=False,
+ **cond_kwargs,
+ )[0]
+ block_state.guider.cleanup_models(block_state.current_model)
+
+ # Perform guidance
+ block_state.noise_pred = block_state.guider(guider_state)[0]
+
+ return components, block_state
+
+
+class WanLoopAfterDenoiser(ModularPipelineBlocks):
+ model_name = "wan"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", UniPCMultistepScheduler),
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that update the latents. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `WanDenoiseLoopWrapper`)"
+ )
+
+ @torch.no_grad()
+ def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ # Perform scheduler step using the predicted output
+ latents_dtype = block_state.latents.dtype
+ block_state.latents = components.scheduler.step(
+ block_state.noise_pred.float(),
+ t,
+ block_state.latents.float(),
+ return_dict=False,
+ )[0]
+
+ if block_state.latents.dtype != latents_dtype:
+ block_state.latents = block_state.latents.to(latents_dtype)
+
+ return components, block_state
+
+
+class WanDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
+ model_name = "wan"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Pipeline block that iteratively denoise the latents over `timesteps`. "
+ "The specific steps with each iteration can be customized with `sub_blocks` attributes"
+ )
+
+ @property
+ def loop_expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", UniPCMultistepScheduler),
+ ]
+
+ @property
+ def loop_inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.num_warmup_steps = max(
+ len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
+ )
+
+ with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
+ for i, t in enumerate(block_state.timesteps):
+ components, block_state = self.loop_step(components, block_state, i=i, t=t)
+ if i == len(block_state.timesteps) - 1 or (
+ (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
+ ):
+ progress_bar.update()
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class WanDenoiseStep(WanDenoiseLoopWrapper):
+ block_classes = [
+ WanLoopBeforeDenoiser,
+ WanLoopDenoiser(
+ guider_input_fields={
+ "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
+ }
+ ),
+ WanLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
+ " - `WanLoopBeforeDenoiser`\n"
+ " - `WanLoopDenoiser`\n"
+ " - `WanLoopAfterDenoiser`\n"
+ "This block supports text-to-video tasks for wan2.1."
+ )
+
+
+class Wan22DenoiseStep(WanDenoiseLoopWrapper):
+ block_classes = [
+ WanLoopBeforeDenoiser,
+ Wan22LoopDenoiser(
+ guider_input_fields={
+ "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
+ }
+ ),
+ WanLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
+ " - `WanLoopBeforeDenoiser`\n"
+ " - `Wan22LoopDenoiser`\n"
+ " - `WanLoopAfterDenoiser`\n"
+ "This block supports text-to-video tasks for Wan2.2."
+ )
+
+
+class WanImage2VideoDenoiseStep(WanDenoiseLoopWrapper):
+ block_classes = [
+ WanImage2VideoLoopBeforeDenoiser,
+ WanLoopDenoiser(
+ guider_input_fields={
+ "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
+ "encoder_hidden_states_image": "image_embeds",
+ }
+ ),
+ WanLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
+ " - `WanImage2VideoLoopBeforeDenoiser`\n"
+ " - `WanLoopDenoiser`\n"
+ " - `WanLoopAfterDenoiser`\n"
+ "This block supports image-to-video tasks for wan2.1."
+ )
+
+
+class Wan22Image2VideoDenoiseStep(WanDenoiseLoopWrapper):
+ block_classes = [
+ WanImage2VideoLoopBeforeDenoiser,
+ Wan22LoopDenoiser(
+ guider_input_fields={
+ "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
+ }
+ ),
+ WanLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
+ " - `WanImage2VideoLoopBeforeDenoiser`\n"
+ " - `WanLoopDenoiser`\n"
+ " - `WanLoopAfterDenoiser`\n"
+ "This block supports image-to-video tasks for Wan2.2."
+ )
+
+
+class WanFLF2VDenoiseStep(WanDenoiseLoopWrapper):
+ block_classes = [
+ WanFLF2VLoopBeforeDenoiser,
+ WanLoopDenoiser(
+ guider_input_fields={
+ "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
+ "encoder_hidden_states_image": "image_embeds",
+ }
+ ),
+ WanLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
+ " - `WanFLF2VLoopBeforeDenoiser`\n"
+ " - `WanLoopDenoiser`\n"
+ " - `WanLoopAfterDenoiser`\n"
+ "This block supports FLF2V tasks for wan2.1."
+ )
diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py
new file mode 100644
index 000000000000..dc49df8eab8c
--- /dev/null
+++ b/src/diffusers/modular_pipelines/wan/encoders.py
@@ -0,0 +1,667 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from typing import List, Optional, Union
+
+import numpy as np
+import PIL
+import regex as re
+import torch
+from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
+
+from ...configuration_utils import FrozenDict
+from ...guiders import ClassifierFreeGuidance
+from ...image_processor import PipelineImageInput
+from ...models import AutoencoderKLWan
+from ...utils import is_ftfy_available, is_torchvision_available, logging
+from ...video_processor import VideoProcessor
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import WanModularPipeline
+
+
+if is_ftfy_available():
+ import ftfy
+
+if is_torchvision_available():
+ from torchvision import transforms
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+def get_t5_prompt_embeds(
+ text_encoder: UMT5EncoderModel,
+ tokenizer: AutoTokenizer,
+ prompt: Union[str, List[str]],
+ max_sequence_length: int,
+ device: torch.device,
+):
+ dtype = text_encoder.dtype
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+ prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ return prompt_embeds
+
+
+def encode_image(
+ image: PipelineImageInput,
+ image_processor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModel,
+ device: Optional[torch.device] = None,
+):
+ image = image_processor(images=image, return_tensors="pt").to(device)
+ image_embeds = image_encoder(**image, output_hidden_states=True)
+ return image_embeds.hidden_states[-2]
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+def encode_vae_image(
+ video_tensor: torch.Tensor,
+ vae: AutoencoderKLWan,
+ generator: torch.Generator,
+ device: torch.device,
+ dtype: torch.dtype,
+ latent_channels: int = 16,
+):
+ if not isinstance(video_tensor, torch.Tensor):
+ raise ValueError(f"Expected video_tensor to be a tensor, got {type(video_tensor)}.")
+
+ if isinstance(generator, list) and len(generator) != video_tensor.shape[0]:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {video_tensor.shape[0]}."
+ )
+
+ video_tensor = video_tensor.to(device=device, dtype=dtype)
+
+ if isinstance(generator, list):
+ video_latents = [
+ retrieve_latents(vae.encode(video_tensor[i : i + 1]), generator=generator[i], sample_mode="argmax")
+ for i in range(video_tensor.shape[0])
+ ]
+ video_latents = torch.cat(video_latents, dim=0)
+ else:
+ video_latents = retrieve_latents(vae.encode(video_tensor), sample_mode="argmax")
+
+ latents_mean = (
+ torch.tensor(vae.config.latents_mean)
+ .view(1, latent_channels, 1, 1, 1)
+ .to(video_latents.device, video_latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, latent_channels, 1, 1, 1).to(
+ video_latents.device, video_latents.dtype
+ )
+ video_latents = (video_latents - latents_mean) * latents_std
+
+ return video_latents
+
+
+class WanTextEncoderStep(ModularPipelineBlocks):
+ model_name = "wan"
+
+ @property
+ def description(self) -> str:
+ return "Text Encoder step that generate text_embeddings to guide the video generation"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("text_encoder", UMT5EncoderModel),
+ ComponentSpec("tokenizer", AutoTokenizer),
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 5.0}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("prompt"),
+ InputParam("negative_prompt"),
+ InputParam("max_sequence_length", default=512),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="denoiser_input_fields",
+ description="text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "negative_prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="denoiser_input_fields",
+ description="negative text embeddings used to guide the image generation",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(block_state):
+ if block_state.prompt is not None and (
+ not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)
+ ):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
+
+ @staticmethod
+ def encode_prompt(
+ components,
+ prompt: str,
+ device: Optional[torch.device] = None,
+ prepare_unconditional_embeds: bool = True,
+ negative_prompt: Optional[str] = None,
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ prepare_unconditional_embeds (`bool`):
+ whether to use prepare unconditional embeddings or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum number of text tokens to be used for the generation process.
+ """
+ device = device or components._execution_device
+ if not isinstance(prompt, list):
+ prompt = [prompt]
+ batch_size = len(prompt)
+
+ prompt_embeds = get_t5_prompt_embeds(
+ text_encoder=components.text_encoder,
+ tokenizer=components.tokenizer,
+ prompt=prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if prepare_unconditional_embeds:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = get_t5_prompt_embeds(
+ text_encoder=components.text_encoder,
+ tokenizer=components.tokenizer,
+ prompt=negative_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ @torch.no_grad()
+ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+ # Get inputs and intermediates
+ block_state = self.get_block_state(state)
+ self.check_inputs(block_state)
+
+ block_state.device = components._execution_device
+
+ # Encode input prompt
+ (
+ block_state.prompt_embeds,
+ block_state.negative_prompt_embeds,
+ ) = self.encode_prompt(
+ components=components,
+ prompt=block_state.prompt,
+ device=block_state.device,
+ prepare_unconditional_embeds=components.requires_unconditional_embeds,
+ negative_prompt=block_state.negative_prompt,
+ max_sequence_length=block_state.max_sequence_length,
+ )
+
+ # Add outputs
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class WanImageResizeStep(ModularPipelineBlocks):
+ model_name = "wan"
+
+ @property
+ def description(self) -> str:
+ return "Image Resize step that resize the image to the target area (height * width) while maintaining the aspect ratio."
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("image", type_hint=PIL.Image.Image, required=True),
+ InputParam("height", type_hint=int, default=480),
+ InputParam("width", type_hint=int, default=832),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam("resized_image", type_hint=PIL.Image.Image),
+ ]
+
+ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ max_area = block_state.height * block_state.width
+
+ image = block_state.image
+ aspect_ratio = image.height / image.width
+ mod_value = components.vae_scale_factor_spatial * components.patch_size_spatial
+ block_state.height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+ block_state.width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+ block_state.resized_image = image.resize((block_state.width, block_state.height))
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class WanImageCropResizeStep(ModularPipelineBlocks):
+ model_name = "wan"
+
+ @property
+ def description(self) -> str:
+ return "Image Resize step that resize the last_image to the same size of first frame image with center crop."
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "resized_image", type_hint=PIL.Image.Image, required=True, description="The resized first frame image"
+ ),
+ InputParam("last_image", type_hint=PIL.Image.Image, required=True, description="The last frameimage"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam("resized_last_image", type_hint=PIL.Image.Image),
+ ]
+
+ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ height = block_state.resized_image.height
+ width = block_state.resized_image.width
+ image = block_state.last_image
+
+ # Calculate resize ratio to match first frame dimensions
+ resize_ratio = max(width / image.width, height / image.height)
+
+ # Resize the image
+ width = round(image.width * resize_ratio)
+ height = round(image.height * resize_ratio)
+ size = [width, height]
+ resized_image = transforms.functional.center_crop(image, size)
+ block_state.resized_last_image = resized_image
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class WanImageEncoderStep(ModularPipelineBlocks):
+ model_name = "wan"
+
+ @property
+ def description(self) -> str:
+ return "Image Encoder step that generate image_embeds based on first frame image to guide the video generation"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("image_processor", CLIPImageProcessor),
+ ComponentSpec("image_encoder", CLIPVisionModel),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam("image_embeds", type_hint=torch.Tensor, description="The image embeddings"),
+ ]
+
+ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ device = components._execution_device
+
+ image = block_state.resized_image
+
+ image_embeds = encode_image(
+ image_processor=components.image_processor,
+ image_encoder=components.image_encoder,
+ image=image,
+ device=device,
+ )
+ block_state.image_embeds = image_embeds
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class WanFirstLastFrameImageEncoderStep(ModularPipelineBlocks):
+ model_name = "wan"
+
+ @property
+ def description(self) -> str:
+ return "Image Encoder step that generate image_embeds based on first and last frame images to guide the video generation"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("image_processor", CLIPImageProcessor),
+ ComponentSpec("image_encoder", CLIPVisionModel),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
+ InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam("image_embeds", type_hint=torch.Tensor, description="The image embeddings"),
+ ]
+
+ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ device = components._execution_device
+
+ first_frame_image = block_state.resized_image
+ last_frame_image = block_state.resized_last_image
+
+ image_embeds = encode_image(
+ image_processor=components.image_processor,
+ image_encoder=components.image_encoder,
+ image=[first_frame_image, last_frame_image],
+ device=device,
+ )
+ block_state.image_embeds = image_embeds
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class WanVaeImageEncoderStep(ModularPipelineBlocks):
+ model_name = "wan"
+
+ @property
+ def description(self) -> str:
+ return "Vae Image Encoder step that generate condition_latents based on first frame image to guide the video generation"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("vae", AutoencoderKLWan),
+ ComponentSpec(
+ "video_processor",
+ VideoProcessor,
+ config=FrozenDict({"vae_scale_factor": 8}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
+ InputParam("height"),
+ InputParam("width"),
+ InputParam("num_frames"),
+ InputParam("generator"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "first_frame_latents",
+ type_hint=torch.Tensor,
+ description="video latent representation with the first frame image condition",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(components, block_state):
+ if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
+ block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
+ ):
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
+ )
+ if block_state.num_frames is not None and (
+ block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0
+ ):
+ raise ValueError(
+ f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}."
+ )
+
+ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ self.check_inputs(components, block_state)
+
+ image = block_state.resized_image
+
+ device = components._execution_device
+ dtype = torch.float32
+
+ height = block_state.height or components.default_height
+ width = block_state.width or components.default_width
+ num_frames = block_state.num_frames or components.default_num_frames
+
+ image_tensor = components.video_processor.preprocess(image, height=height, width=width).to(
+ device=device, dtype=dtype
+ )
+
+ if image_tensor.dim() == 4:
+ image_tensor = image_tensor.unsqueeze(2)
+
+ video_tensor = torch.cat(
+ [
+ image_tensor,
+ image_tensor.new_zeros(image_tensor.shape[0], image_tensor.shape[1], num_frames - 1, height, width),
+ ],
+ dim=2,
+ ).to(device=device, dtype=dtype)
+
+ block_state.first_frame_latents = encode_vae_image(
+ video_tensor=video_tensor,
+ vae=components.vae,
+ generator=block_state.generator,
+ device=device,
+ dtype=dtype,
+ latent_channels=components.num_channels_latents,
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks):
+ model_name = "wan"
+
+ @property
+ def description(self) -> str:
+ return "Vae Image Encoder step that generate condition_latents based on first and last frame images to guide the video generation"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("vae", AutoencoderKLWan),
+ ComponentSpec(
+ "video_processor",
+ VideoProcessor,
+ config=FrozenDict({"vae_scale_factor": 8}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
+ InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True),
+ InputParam("height"),
+ InputParam("width"),
+ InputParam("num_frames"),
+ InputParam("generator"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "first_last_frame_latents",
+ type_hint=torch.Tensor,
+ description="video latent representation with the first and last frame images condition",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(components, block_state):
+ if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
+ block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
+ ):
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
+ )
+ if block_state.num_frames is not None and (
+ block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0
+ ):
+ raise ValueError(
+ f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}."
+ )
+
+ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ self.check_inputs(components, block_state)
+
+ first_frame_image = block_state.resized_image
+ last_frame_image = block_state.resized_last_image
+
+ device = components._execution_device
+ dtype = torch.float32
+
+ height = block_state.height or components.default_height
+ width = block_state.width or components.default_width
+ num_frames = block_state.num_frames or components.default_num_frames
+
+ first_image_tensor = components.video_processor.preprocess(first_frame_image, height=height, width=width).to(
+ device=device, dtype=dtype
+ )
+ first_image_tensor = first_image_tensor.unsqueeze(2)
+
+ last_image_tensor = components.video_processor.preprocess(last_frame_image, height=height, width=width).to(
+ device=device, dtype=dtype
+ )
+
+ last_image_tensor = last_image_tensor.unsqueeze(2)
+
+ video_tensor = torch.cat(
+ [
+ first_image_tensor,
+ first_image_tensor.new_zeros(
+ first_image_tensor.shape[0], first_image_tensor.shape[1], num_frames - 2, height, width
+ ),
+ last_image_tensor,
+ ],
+ dim=2,
+ ).to(device=device, dtype=dtype)
+
+ block_state.first_last_frame_latents = encode_vae_image(
+ video_tensor=video_tensor,
+ vae=components.vae,
+ generator=block_state.generator,
+ device=device,
+ dtype=dtype,
+ latent_channels=components.num_channels_latents,
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks.py b/src/diffusers/modular_pipelines/wan/modular_blocks.py
new file mode 100644
index 000000000000..b3b70b2f9be1
--- /dev/null
+++ b/src/diffusers/modular_pipelines/wan/modular_blocks.py
@@ -0,0 +1,474 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...utils import logging
+from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
+from ..modular_pipeline_utils import InsertableDict
+from .before_denoise import (
+ WanAdditionalInputsStep,
+ WanPrepareFirstFrameLatentsStep,
+ WanPrepareFirstLastFrameLatentsStep,
+ WanPrepareLatentsStep,
+ WanSetTimestepsStep,
+ WanTextInputStep,
+)
+from .decoders import WanImageVaeDecoderStep
+from .denoise import (
+ Wan22DenoiseStep,
+ Wan22Image2VideoDenoiseStep,
+ WanDenoiseStep,
+ WanFLF2VDenoiseStep,
+ WanImage2VideoDenoiseStep,
+)
+from .encoders import (
+ WanFirstLastFrameImageEncoderStep,
+ WanFirstLastFrameVaeImageEncoderStep,
+ WanImageCropResizeStep,
+ WanImageEncoderStep,
+ WanImageResizeStep,
+ WanTextEncoderStep,
+ WanVaeImageEncoderStep,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# wan2.1
+# wan2.1: text2vid
+class WanCoreDenoiseStep(SequentialPipelineBlocks):
+ block_classes = [
+ WanTextInputStep,
+ WanSetTimestepsStep,
+ WanPrepareLatentsStep,
+ WanDenoiseStep,
+ ]
+ block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
+
+ @property
+ def description(self):
+ return (
+ "denoise block that takes encoded conditions and runs the denoising process.\n"
+ + "This is a sequential pipeline blocks:\n"
+ + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ + " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ + " - `WanDenoiseStep` is used to denoise the latents\n"
+ )
+
+
+# wan2.1: image2video
+## image encoder
+class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
+ model_name = "wan"
+ block_classes = [WanImageResizeStep, WanImageEncoderStep]
+ block_names = ["image_resize", "image_encoder"]
+
+ @property
+ def description(self):
+ return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings"
+
+
+## vae encoder
+class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks):
+ model_name = "wan"
+ block_classes = [WanImageResizeStep, WanVaeImageEncoderStep]
+ block_names = ["image_resize", "vae_image_encoder"]
+
+ @property
+ def description(self):
+ return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation"
+
+
+## denoise
+class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
+ block_classes = [
+ WanTextInputStep,
+ WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
+ WanSetTimestepsStep,
+ WanPrepareLatentsStep,
+ WanPrepareFirstFrameLatentsStep,
+ WanImage2VideoDenoiseStep,
+ ]
+ block_names = [
+ "input",
+ "additional_inputs",
+ "set_timesteps",
+ "prepare_latents",
+ "prepare_first_frame_latents",
+ "denoise",
+ ]
+
+ @property
+ def description(self):
+ return (
+ "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
+ + "This is a sequential pipeline blocks:\n"
+ + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
+ + " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ + " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
+ + " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
+ )
+
+
+# wan2.1: FLF2v
+
+
+## image encoder
+class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
+ model_name = "wan"
+ block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep]
+ block_names = ["image_resize", "last_image_resize", "image_encoder"]
+
+ @property
+ def description(self):
+ return "FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image embeddings"
+
+
+## vae encoder
+class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks):
+ model_name = "wan"
+ block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep]
+ block_names = ["image_resize", "last_image_resize", "vae_image_encoder"]
+
+ @property
+ def description(self):
+ return "FLF2V Vae Image Encoder step that resize and encode and encode the first and last frame images to generate the latent conditions"
+
+
+## denoise
+class WanFLF2VCoreDenoiseStep(SequentialPipelineBlocks):
+ block_classes = [
+ WanTextInputStep,
+ WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"]),
+ WanSetTimestepsStep,
+ WanPrepareLatentsStep,
+ WanPrepareFirstLastFrameLatentsStep,
+ WanFLF2VDenoiseStep,
+ ]
+ block_names = [
+ "input",
+ "additional_inputs",
+ "set_timesteps",
+ "prepare_latents",
+ "prepare_first_last_frame_latents",
+ "denoise",
+ ]
+
+ @property
+ def description(self):
+ return (
+ "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
+ + "This is a sequential pipeline blocks:\n"
+ + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
+ + " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ + " - `WanPrepareFirstLastFrameLatentsStep` is used to prepare the latent conditions\n"
+ + " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
+ )
+
+
+# wan2.1: auto blocks
+## image encoder
+class WanAutoImageEncoderStep(AutoPipelineBlocks):
+ block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep]
+ block_names = ["flf2v_image_encoder", "image2video_image_encoder"]
+ block_trigger_inputs = ["last_image", "image"]
+
+ @property
+ def description(self):
+ return (
+ "Image Encoder step that encode the image to generate the image embeddings"
+ + "This is an auto pipeline block that works for image2video tasks."
+ + " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided."
+ + " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided."
+ + " - if `last_image` or `image` is not provided, step will be skipped."
+ )
+
+
+## vae encoder
+class WanAutoVaeImageEncoderStep(AutoPipelineBlocks):
+ block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep]
+ block_names = ["flf2v_vae_image_encoder", "image2video_vae_image_encoder"]
+ block_trigger_inputs = ["last_image", "image"]
+
+ @property
+ def description(self):
+ return (
+ "Vae Image Encoder step that encode the image to generate the image latents"
+ + "This is an auto pipeline block that works for image2video tasks."
+ + " - `WanFLF2VVaeImageEncoderStep` (flf2v) is used when `last_image` is provided."
+ + " - `WanImage2VideoVaeImageEncoderStep` (image2video) is used when `image` is provided."
+ + " - if `last_image` or `image` is not provided, step will be skipped."
+ )
+
+
+## denoise
+class WanAutoDenoiseStep(AutoPipelineBlocks):
+ block_classes = [
+ WanFLF2VCoreDenoiseStep,
+ WanImage2VideoCoreDenoiseStep,
+ WanCoreDenoiseStep,
+ ]
+ block_names = ["flf2v", "image2video", "text2video"]
+ block_trigger_inputs = ["first_last_frame_latents", "first_frame_latents", None]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. "
+ "This is a auto pipeline block that works for text2video and image2video tasks."
+ " - `WanCoreDenoiseStep` (text2video) for text2vid tasks."
+ " - `WanCoreImage2VideoCoreDenoiseStep` (image2video) for image2video tasks."
+ + " - if `first_frame_latents` is provided, `WanCoreImage2VideoDenoiseStep` will be used.\n"
+ + " - if `first_frame_latents` is not provided, `WanCoreDenoiseStep` will be used.\n"
+ )
+
+
+# auto pipeline blocks
+class WanAutoBlocks(SequentialPipelineBlocks):
+ block_classes = [
+ WanTextEncoderStep,
+ WanAutoImageEncoderStep,
+ WanAutoVaeImageEncoderStep,
+ WanAutoDenoiseStep,
+ WanImageVaeDecoderStep,
+ ]
+ block_names = [
+ "text_encoder",
+ "image_encoder",
+ "vae_image_encoder",
+ "denoise",
+ "decode",
+ ]
+
+ @property
+ def description(self):
+ return (
+ "Auto Modular pipeline for text-to-video using Wan.\n"
+ + "- for text-to-video generation, all you need to provide is `prompt`"
+ )
+
+
+# wan22
+# wan2.2: text2vid
+
+
+## denoise
+class Wan22CoreDenoiseStep(SequentialPipelineBlocks):
+ block_classes = [
+ WanTextInputStep,
+ WanSetTimestepsStep,
+ WanPrepareLatentsStep,
+ Wan22DenoiseStep,
+ ]
+ block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
+
+ @property
+ def description(self):
+ return (
+ "denoise block that takes encoded conditions and runs the denoising process.\n"
+ + "This is a sequential pipeline blocks:\n"
+ + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ + " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ + " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n"
+ )
+
+
+# wan2.2: image2video
+## denoise
+class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
+ block_classes = [
+ WanTextInputStep,
+ WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
+ WanSetTimestepsStep,
+ WanPrepareLatentsStep,
+ WanPrepareFirstFrameLatentsStep,
+ Wan22Image2VideoDenoiseStep,
+ ]
+ block_names = [
+ "input",
+ "additional_inputs",
+ "set_timesteps",
+ "prepare_latents",
+ "prepare_first_frame_latents",
+ "denoise",
+ ]
+
+ @property
+ def description(self):
+ return (
+ "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
+ + "This is a sequential pipeline blocks:\n"
+ + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
+ + " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ + " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
+ + " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n"
+ )
+
+
+class Wan22AutoDenoiseStep(AutoPipelineBlocks):
+ block_classes = [
+ Wan22Image2VideoCoreDenoiseStep,
+ Wan22CoreDenoiseStep,
+ ]
+ block_names = ["image2video", "text2video"]
+ block_trigger_inputs = ["first_frame_latents", None]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. "
+ "This is a auto pipeline block that works for text2video and image2video tasks."
+ " - `Wan22Image2VideoCoreDenoiseStep` (image2video) for image2video tasks."
+ " - `Wan22CoreDenoiseStep` (text2video) for text2vid tasks."
+ + " - if `first_frame_latents` is provided, `Wan22Image2VideoCoreDenoiseStep` will be used.\n"
+ + " - if `first_frame_latents` is not provided, `Wan22CoreDenoiseStep` will be used.\n"
+ )
+
+
+class Wan22AutoBlocks(SequentialPipelineBlocks):
+ block_classes = [
+ WanTextEncoderStep,
+ WanAutoVaeImageEncoderStep,
+ Wan22AutoDenoiseStep,
+ WanImageVaeDecoderStep,
+ ]
+ block_names = [
+ "text_encoder",
+ "vae_image_encoder",
+ "denoise",
+ "decode",
+ ]
+
+ @property
+ def description(self):
+ return (
+ "Auto Modular pipeline for text-to-video using Wan2.2.\n"
+ + "- for text-to-video generation, all you need to provide is `prompt`"
+ )
+
+
+# presets for wan2.1 and wan2.2
+# YiYi Notes: should we move these to doc?
+# wan2.1
+TEXT2VIDEO_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", WanTextEncoderStep),
+ ("input", WanTextInputStep),
+ ("set_timesteps", WanSetTimestepsStep),
+ ("prepare_latents", WanPrepareLatentsStep),
+ ("denoise", WanDenoiseStep),
+ ("decode", WanImageVaeDecoderStep),
+ ]
+)
+
+IMAGE2VIDEO_BLOCKS = InsertableDict(
+ [
+ ("image_resize", WanImageResizeStep),
+ ("image_encoder", WanImage2VideoImageEncoderStep),
+ ("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
+ ("input", WanTextInputStep),
+ ("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])),
+ ("set_timesteps", WanSetTimestepsStep),
+ ("prepare_latents", WanPrepareLatentsStep),
+ ("prepare_first_frame_latents", WanPrepareFirstFrameLatentsStep),
+ ("denoise", WanImage2VideoDenoiseStep),
+ ("decode", WanImageVaeDecoderStep),
+ ]
+)
+
+
+FLF2V_BLOCKS = InsertableDict(
+ [
+ ("image_resize", WanImageResizeStep),
+ ("last_image_resize", WanImageCropResizeStep),
+ ("image_encoder", WanFLF2VImageEncoderStep),
+ ("vae_image_encoder", WanFLF2VVaeImageEncoderStep),
+ ("input", WanTextInputStep),
+ ("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"])),
+ ("set_timesteps", WanSetTimestepsStep),
+ ("prepare_latents", WanPrepareLatentsStep),
+ ("prepare_first_last_frame_latents", WanPrepareFirstLastFrameLatentsStep),
+ ("denoise", WanFLF2VDenoiseStep),
+ ("decode", WanImageVaeDecoderStep),
+ ]
+)
+
+AUTO_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", WanTextEncoderStep),
+ ("image_encoder", WanAutoImageEncoderStep),
+ ("vae_image_encoder", WanAutoVaeImageEncoderStep),
+ ("denoise", WanAutoDenoiseStep),
+ ("decode", WanImageVaeDecoderStep),
+ ]
+)
+
+# wan2.2 presets
+
+TEXT2VIDEO_BLOCKS_WAN22 = InsertableDict(
+ [
+ ("text_encoder", WanTextEncoderStep),
+ ("input", WanTextInputStep),
+ ("set_timesteps", WanSetTimestepsStep),
+ ("prepare_latents", WanPrepareLatentsStep),
+ ("denoise", Wan22DenoiseStep),
+ ("decode", WanImageVaeDecoderStep),
+ ]
+)
+
+IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict(
+ [
+ ("image_resize", WanImageResizeStep),
+ ("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
+ ("input", WanTextInputStep),
+ ("set_timesteps", WanSetTimestepsStep),
+ ("prepare_latents", WanPrepareLatentsStep),
+ ("denoise", Wan22DenoiseStep),
+ ("decode", WanImageVaeDecoderStep),
+ ]
+)
+
+AUTO_BLOCKS_WAN22 = InsertableDict(
+ [
+ ("text_encoder", WanTextEncoderStep),
+ ("vae_image_encoder", WanAutoVaeImageEncoderStep),
+ ("denoise", Wan22AutoDenoiseStep),
+ ("decode", WanImageVaeDecoderStep),
+ ]
+)
+
+# presets all blocks (wan and wan22)
+
+
+ALL_BLOCKS = {
+ "wan2.1": {
+ "text2video": TEXT2VIDEO_BLOCKS,
+ "image2video": IMAGE2VIDEO_BLOCKS,
+ "flf2v": FLF2V_BLOCKS,
+ "auto": AUTO_BLOCKS,
+ },
+ "wan2.2": {
+ "text2video": TEXT2VIDEO_BLOCKS_WAN22,
+ "image2video": IMAGE2VIDEO_BLOCKS_WAN22,
+ "auto": AUTO_BLOCKS_WAN22,
+ },
+}
diff --git a/src/diffusers/modular_pipelines/wan/modular_pipeline.py b/src/diffusers/modular_pipelines/wan/modular_pipeline.py
new file mode 100644
index 000000000000..930b25e4b905
--- /dev/null
+++ b/src/diffusers/modular_pipelines/wan/modular_pipeline.py
@@ -0,0 +1,120 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Any, Dict, Optional
+
+from ...loaders import WanLoraLoaderMixin
+from ...pipelines.pipeline_utils import StableDiffusionMixin
+from ...utils import logging
+from ..modular_pipeline import ModularPipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class WanModularPipeline(
+ ModularPipeline,
+ StableDiffusionMixin,
+ WanLoraLoaderMixin,
+):
+ """
+ A ModularPipeline for Wan.
+
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
+ """
+
+ default_blocks_name = "WanAutoBlocks"
+
+ # override the default_blocks_name in base class, which is just return self.default_blocks_name
+ def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
+ if config_dict is not None and "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
+ return "Wan22AutoBlocks"
+ else:
+ return "WanAutoBlocks"
+
+ @property
+ def default_height(self):
+ return self.default_sample_height * self.vae_scale_factor_spatial
+
+ @property
+ def default_width(self):
+ return self.default_sample_width * self.vae_scale_factor_spatial
+
+ @property
+ def default_num_frames(self):
+ return (self.default_sample_num_frames - 1) * self.vae_scale_factor_temporal + 1
+
+ @property
+ def default_sample_height(self):
+ return 60
+
+ @property
+ def default_sample_width(self):
+ return 104
+
+ @property
+ def default_sample_num_frames(self):
+ return 21
+
+ @property
+ def patch_size_spatial(self):
+ patch_size_spatial = 2
+ if hasattr(self, "transformer") and self.transformer is not None:
+ patch_size_spatial = self.transformer.config.patch_size[1]
+ return patch_size_spatial
+
+ @property
+ def vae_scale_factor_spatial(self):
+ vae_scale_factor = 8
+ if hasattr(self, "vae") and self.vae is not None:
+ vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
+ return vae_scale_factor
+
+ @property
+ def vae_scale_factor_temporal(self):
+ vae_scale_factor = 4
+ if hasattr(self, "vae") and self.vae is not None:
+ vae_scale_factor = 2 ** sum(self.vae.temperal_downsample)
+ return vae_scale_factor
+
+ @property
+ def num_channels_transformer(self):
+ num_channels_transformer = 16
+ if hasattr(self, "transformer") and self.transformer is not None:
+ num_channels_transformer = self.transformer.config.in_channels
+ return num_channels_transformer
+
+ @property
+ def num_channels_latents(self):
+ num_channels_latents = 16
+ if hasattr(self, "vae") and self.vae is not None:
+ num_channels_latents = self.vae.config.z_dim
+ return num_channels_latents
+
+ @property
+ def requires_unconditional_embeds(self):
+ requires_unconditional_embeds = False
+
+ if hasattr(self, "guider") and self.guider is not None:
+ requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
+
+ return requires_unconditional_embeds
+
+ @property
+ def num_train_timesteps(self):
+ num_train_timesteps = 1000
+ if hasattr(self, "scheduler") and self.scheduler is not None:
+ num_train_timesteps = self.scheduler.config.num_train_timesteps
+ return num_train_timesteps
diff --git a/src/diffusers/pipelines/README.md b/src/diffusers/pipelines/README.md
index b2954c07438b..6f9ab7b291ad 100644
--- a/src/diffusers/pipelines/README.md
+++ b/src/diffusers/pipelines/README.md
@@ -16,7 +16,7 @@ or created independently from each other.
To that end, we strive to offer all open-sourced, state-of-the-art diffusion system under a unified API.
More specifically, we strive to provide pipelines that
-- 1. can load the officially published weights and yield 1-to-1 the same outputs as the original implementation according to the corresponding paper (*e.g.* [LDMTextToImagePipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/latent_diffusion), uses the officially released weights of [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)),
+- 1. can load the officially published weights and yield 1-to-1 the same outputs as the original implementation according to the corresponding paper (*e.g.* [LDMTextToImagePipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/latent_diffusion), uses the officially released weights of [High-Resolution Image Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752)),
- 2. have a simple user interface to run the model in inference (see the [Pipelines API](#pipelines-api) section),
- 3. are easy to understand with code that is self-explanatory and can be read along-side the official paper (see [Pipelines summary](#pipelines-summary)),
- 4. can easily be contributed by the community (see the [Contribution](#contribution) section).
@@ -33,17 +33,17 @@ available a colab notebook to directly try them out.
| Pipeline | Source | Tasks | Colab
|-------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------|:---:|:---:|
| [dance diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/Harmonai-org/sample-generator) | *Unconditional Audio Generation* |
-| [ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | *Unconditional Image Generation* |
-| [ddim](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | *Unconditional Image Generation* | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
-| [latent_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | *Text-to-Image Generation* |
-| [latent_diffusion_uncond](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | *Unconditional Image Generation* |
-| [pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | *Unconditional Image Generation* |
+| [ddpm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://huggingface.co/papers/2006.11239) | *Unconditional Image Generation* |
+| [ddim](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://huggingface.co/papers/2010.02502) | *Unconditional Image Generation* | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
+| [latent_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://huggingface.co/papers/2112.10752) | *Text-to-Image Generation* |
+| [latent_diffusion_uncond](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://huggingface.co/papers/2112.10752) | *Unconditional Image Generation* |
+| [pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://huggingface.co/papers/2202.09778) | *Unconditional Image Generation* |
| [score_sde_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* |
| [score_sde_vp](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* |
| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-to-Image Generation* | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb)
| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Image-to-Image Text-Guided Generation* | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-Guided Image Inpainting* | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
-| [stochastic_karras_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | *Unconditional Image Generation* |
+| [stochastic_karras_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://huggingface.co/papers/2206.00364) | *Unconditional Image Generation* |
**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.
However, most of them can be adapted to use different scheduler components or even different model components. Some pipeline examples are shown in the [Examples](#examples) below.
@@ -86,7 +86,7 @@ logic including pre-processing, an unrolled diffusion loop, and post-processing
### Text-to-Image generation with Stable Diffusion
```python
-# make sure you're logged in with `huggingface-cli login`
+# make sure you're logged in with `hf auth login`
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
@@ -159,7 +159,7 @@ init_image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))
pipe = StableDiffusionInpaintPipeline.from_pretrained(
- "runwayml/stable-diffusion-inpainting",
+ "stable-diffusion-v1-5/stable-diffusion-inpainting",
torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 90c85547d652..df9e698a961f 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -10,6 +10,7 @@
is_librosa_available,
is_note_seq_available,
is_onnx_available,
+ is_opencv_available,
is_sentencepiece_available,
is_torch_available,
is_torch_npu_available,
@@ -133,6 +134,9 @@
"AnimateDiffVideoToVideoPipeline",
"AnimateDiffVideoToVideoControlNetPipeline",
]
+ _import_structure["bria"] = ["BriaPipeline"]
+ _import_structure["bria_fibo"] = ["BriaFiboPipeline"]
+ _import_structure["flux2"] = ["Flux2Pipeline"]
_import_structure["flux"] = [
"FluxControlPipeline",
"FluxControlInpaintPipeline",
@@ -146,7 +150,10 @@
"FluxFillPipeline",
"FluxPriorReduxPipeline",
"ReduxImageEncoder",
+ "FluxKontextPipeline",
+ "FluxKontextInpaintPipeline",
]
+ _import_structure["prx"] = ["PRXPipeline"]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
"AudioLDM2Pipeline",
@@ -154,6 +161,7 @@
"AudioLDM2UNet2DConditionModel",
]
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
+ _import_structure["chroma"] = ["ChromaPipeline", "ChromaImg2ImgPipeline"]
_import_structure["cogvideo"] = [
"CogVideoXPipeline",
"CogVideoXImageToVideoPipeline",
@@ -163,6 +171,12 @@
_import_structure["cogview3"] = ["CogView3PlusPipeline"]
_import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
_import_structure["consisid"] = ["ConsisIDPipeline"]
+ _import_structure["cosmos"] = [
+ "Cosmos2TextToImagePipeline",
+ "CosmosTextToWorldPipeline",
+ "CosmosVideoToWorldPipeline",
+ "Cosmos2VideoToWorldPipeline",
+ ]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
@@ -228,12 +242,16 @@
"EasyAnimateInpaintPipeline",
"EasyAnimateControlPipeline",
]
+ _import_structure["hidream_image"] = ["HiDreamImagePipeline"]
_import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
_import_structure["hunyuan_video"] = [
"HunyuanVideoPipeline",
"HunyuanSkyreelsImageToVideoPipeline",
"HunyuanVideoImageToVideoPipeline",
+ "HunyuanVideoFramepackPipeline",
]
+ _import_structure["hunyuan_video1_5"] = ["HunyuanVideo15Pipeline", "HunyuanVideo15ImageToVideoPipeline"]
+ _import_structure["hunyuan_image"] = ["HunyuanImagePipeline", "HunyuanImageRefinerPipeline"]
_import_structure["kandinsky"] = [
"KandinskyCombinedPipeline",
"KandinskyImg2ImgCombinedPipeline",
@@ -275,9 +293,11 @@
"LTXPipeline",
"LTXImageToVideoPipeline",
"LTXConditionPipeline",
+ "LTXLatentUpsamplePipeline",
]
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
+ _import_structure["lucy"] = ["LucyEditPipeline"]
_import_structure["marigold"].extend(
[
"MarigoldDepthPipeline",
@@ -288,6 +308,8 @@
_import_structure["mochi"] = ["MochiPipeline"]
_import_structure["musicldm"] = ["MusicLDMPipeline"]
_import_structure["omnigen"] = ["OmniGenPipeline"]
+ _import_structure["ovis_image"] = ["OvisImagePipeline"]
+ _import_structure["visualcloze"] = ["VisualClozePipeline", "VisualClozeGenerationPipeline"]
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
_import_structure["pia"] = ["PIAPipeline"]
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
@@ -295,6 +317,11 @@
"SanaPipeline",
"SanaSprintPipeline",
"SanaControlNetPipeline",
+ "SanaSprintImg2ImgPipeline",
+ ]
+ _import_structure["sana_video"] = [
+ "SanaVideoPipeline",
+ "SanaImageToVideoPipeline",
]
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
@@ -377,7 +404,34 @@
"WanPipeline",
"WanImageToVideoPipeline",
"WanVideoToVideoPipeline",
+ "WanVACEPipeline",
+ "WanAnimatePipeline",
+ ]
+ _import_structure["kandinsky5"] = [
+ "Kandinsky5T2VPipeline",
+ "Kandinsky5I2VPipeline",
+ "Kandinsky5T2IPipeline",
+ "Kandinsky5I2IPipeline",
+ ]
+ _import_structure["z_image"] = ["ZImageImg2ImgPipeline", "ZImagePipeline"]
+ _import_structure["skyreels_v2"] = [
+ "SkyReelsV2DiffusionForcingPipeline",
+ "SkyReelsV2DiffusionForcingImageToVideoPipeline",
+ "SkyReelsV2DiffusionForcingVideoToVideoPipeline",
+ "SkyReelsV2ImageToVideoPipeline",
+ "SkyReelsV2Pipeline",
]
+ _import_structure["qwenimage"] = [
+ "QwenImagePipeline",
+ "QwenImageImg2ImgPipeline",
+ "QwenImageInpaintPipeline",
+ "QwenImageEditPipeline",
+ "QwenImageEditPlusPipeline",
+ "QwenImageEditInpaintPipeline",
+ "QwenImageControlNetInpaintPipeline",
+ "QwenImageControlNetPipeline",
+ ]
+ _import_structure["chronoedit"] = ["ChronoEditPipeline"]
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -447,6 +501,18 @@
"KolorsImg2ImgPipeline",
]
+try:
+ if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ..utils import (
+ dummy_torch_and_transformers_and_opencv_objects,
+ )
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_opencv_objects))
+else:
+ _import_structure["consisid"] = ["ConsisIDPipeline"]
+
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
@@ -543,6 +609,10 @@
)
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
+ from .bria import BriaPipeline
+ from .bria_fibo import BriaFiboPipeline
+ from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
+ from .chronoedit import ChronoEditPipeline
from .cogvideo import (
CogVideoXFunControlPipeline,
CogVideoXImageToVideoPipeline,
@@ -551,7 +621,6 @@
)
from .cogview3 import CogView3PlusPipeline
from .cogview4 import CogView4ControlPipeline, CogView4Pipeline
- from .consisid import ConsisIDPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
StableDiffusionControlNetImg2ImgPipeline,
@@ -573,6 +642,12 @@
StableDiffusionControlNetXSPipeline,
StableDiffusionXLControlNetXSPipeline,
)
+ from .cosmos import (
+ Cosmos2TextToImagePipeline,
+ Cosmos2VideoToWorldPipeline,
+ CosmosTextToWorldPipeline,
+ CosmosVideoToWorldPipeline,
+ )
from .deepfloyd_if import (
IFImg2ImgPipeline,
IFImg2ImgSuperResolutionPipeline,
@@ -610,15 +685,22 @@
FluxFillPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
+ FluxKontextInpaintPipeline,
+ FluxKontextPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
ReduxImageEncoder,
)
+ from .flux2 import Flux2Pipeline
+ from .hidream_image import HiDreamImagePipeline
+ from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline
from .hunyuan_video import (
HunyuanSkyreelsImageToVideoPipeline,
+ HunyuanVideoFramepackPipeline,
HunyuanVideoImageToVideoPipeline,
HunyuanVideoPipeline,
)
+ from .hunyuan_video1_5 import HunyuanVideo15ImageToVideoPipeline, HunyuanVideo15Pipeline
from .hunyuandit import HunyuanDiTPipeline
from .i2vgen_xl import I2VGenXLPipeline
from .kandinsky import (
@@ -642,7 +724,16 @@
KandinskyV22PriorEmb2EmbPipeline,
KandinskyV22PriorPipeline,
)
- from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
+ from .kandinsky3 import (
+ Kandinsky3Img2ImgPipeline,
+ Kandinsky3Pipeline,
+ )
+ from .kandinsky5 import (
+ Kandinsky5I2IPipeline,
+ Kandinsky5I2VPipeline,
+ Kandinsky5T2IPipeline,
+ Kandinsky5T2VPipeline,
+ )
from .latent_consistency_models import (
LatentConsistencyModelImg2ImgPipeline,
LatentConsistencyModelPipeline,
@@ -655,7 +746,8 @@
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
)
- from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline
+ from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
+ from .lucy import LucyEditPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
from .marigold import (
@@ -666,6 +758,7 @@
from .mochi import MochiPipeline
from .musicldm import MusicLDMPipeline
from .omnigen import OmniGenPipeline
+ from .ovis_image import OvisImagePipeline
from .pag import (
AnimateDiffPAGPipeline,
HunyuanDiTPAGPipeline,
@@ -688,7 +781,24 @@
from .paint_by_example import PaintByExamplePipeline
from .pia import PIAPipeline
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
- from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintPipeline
+ from .prx import PRXPipeline
+ from .qwenimage import (
+ QwenImageControlNetInpaintPipeline,
+ QwenImageControlNetPipeline,
+ QwenImageEditInpaintPipeline,
+ QwenImageEditPipeline,
+ QwenImageEditPlusPipeline,
+ QwenImageImg2ImgPipeline,
+ QwenImageInpaintPipeline,
+ QwenImagePipeline,
+ )
+ from .sana import (
+ SanaControlNetPipeline,
+ SanaPipeline,
+ SanaSprintImg2ImgPipeline,
+ SanaSprintPipeline,
+ )
+ from .sana_video import SanaImageToVideoPipeline, SanaVideoPipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
@@ -751,12 +861,20 @@
UniDiffuserPipeline,
UniDiffuserTextDecoder,
)
- from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
+ from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline
+ from .wan import (
+ WanAnimatePipeline,
+ WanImageToVideoPipeline,
+ WanPipeline,
+ WanVACEPipeline,
+ WanVideoToVideoPipeline,
+ )
from .wuerstchen import (
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
)
+ from .z_image import ZImageImg2ImgPipeline, ZImagePipeline
try:
if not is_onnx_available():
@@ -812,6 +930,14 @@
else:
from .kolors import KolorsImg2ImgPipeline, KolorsPipeline
+ try:
+ if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ..utils.dummy_torch_and_transformers_and_opencv_objects import *
+ else:
+ from .consisid import ConsisIDPipeline
+
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
@@ -847,6 +973,14 @@
else:
from .deprecated import MidiProcessor, SpectrogramDiffusionPipeline
+ from .skyreels_v2 import (
+ SkyReelsV2DiffusionForcingImageToVideoPipeline,
+ SkyReelsV2DiffusionForcingPipeline,
+ SkyReelsV2DiffusionForcingVideoToVideoPipeline,
+ SkyReelsV2ImageToVideoPipeline,
+ SkyReelsV2Pipeline,
+ )
+
else:
import sys
diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py
index cb36a7a672de..3be0129088fb 100644
--- a/src/diffusers/pipelines/allegro/pipeline_allegro.py
+++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The RhymesAI and The HuggingFace Team.
+# Copyright 2025 The RhymesAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -351,7 +351,7 @@ def encode_prompt(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -514,7 +514,7 @@ def _clean_caption(self, caption):
# &
caption = re.sub(r"&", "", caption)
- # ip adresses:
+ # ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -651,6 +651,12 @@ def enable_vae_slicing(self):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -658,6 +664,12 @@ def disable_vae_slicing(self):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -666,6 +678,12 @@ def enable_vae_tiling(self):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -673,6 +691,12 @@ def disable_vae_tiling(self):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
@property
@@ -738,11 +762,11 @@ def __call__(
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`,
- usually at the expense of lower video quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to
+ the text `prompt`, usually at the expense of lower video quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
num_frames: (`int`, *optional*, defaults to 88):
@@ -752,15 +776,15 @@ def __call__(
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated video.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -833,7 +857,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/amused/pipeline_amused.py b/src/diffusers/pipelines/amused/pipeline_amused.py
index 12f7dc7c59d4..131e34d1a4a1 100644
--- a/src/diffusers/pipelines/amused/pipeline_amused.py
+++ b/src/diffusers/pipelines/amused/pipeline_amused.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,7 +21,7 @@
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
from ...utils import is_torch_xla_available, replace_example_docstring
-from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
@@ -47,7 +47,8 @@
"""
-class AmusedPipeline(DiffusionPipeline):
+class AmusedPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
+ _last_supported_version = "0.33.1"
image_processor: VaeImageProcessor
vqvae: VQModel
tokenizer: CLIPTokenizer
@@ -131,7 +132,7 @@ def __call__(
generation deterministic.
latents (`torch.IntTensor`, *optional*):
Pre-generated tokens representing latent vectors in `self.vqvae`, to be used as inputs for image
- gneration. If not provided, the starting latents will be completely masked.
+ generation. If not provided, the starting latents will be completely masked.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument. A single vector from the
@@ -160,10 +161,10 @@ def __call__(
micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
The targeted aesthetic score according to the laion aesthetic classifier. See
https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of
- https://arxiv.org/abs/2307.01952.
+ https://huggingface.co/papers/2307.01952.
micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
The targeted height, width crop coordinates. See the micro-conditioning section of
- https://arxiv.org/abs/2307.01952.
+ https://huggingface.co/papers/2307.01952.
temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`.
diff --git a/src/diffusers/pipelines/amused/pipeline_amused_img2img.py b/src/diffusers/pipelines/amused/pipeline_amused_img2img.py
index 7ac05b39c3a8..a122c12236dd 100644
--- a/src/diffusers/pipelines/amused/pipeline_amused_img2img.py
+++ b/src/diffusers/pipelines/amused/pipeline_amused_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,7 +21,7 @@
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
from ...utils import is_torch_xla_available, replace_example_docstring
-from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
@@ -57,7 +57,8 @@
"""
-class AmusedImg2ImgPipeline(DiffusionPipeline):
+class AmusedImg2ImgPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
+ _last_supported_version = "0.33.1"
image_processor: VaeImageProcessor
vqvae: VQModel
tokenizer: CLIPTokenizer
@@ -179,10 +180,10 @@ def __call__(
micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
The targeted aesthetic score according to the laion aesthetic classifier. See
https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of
- https://arxiv.org/abs/2307.01952.
+ https://huggingface.co/papers/2307.01952.
micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
The targeted height, width crop coordinates. See the micro-conditioning section of
- https://arxiv.org/abs/2307.01952.
+ https://huggingface.co/papers/2307.01952.
temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`.
diff --git a/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py b/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py
index d908c32745c2..f4bd4944ff9a 100644
--- a/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py
+++ b/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
from ...utils import is_torch_xla_available, replace_example_docstring
-from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
@@ -65,7 +65,8 @@
"""
-class AmusedInpaintPipeline(DiffusionPipeline):
+class AmusedInpaintPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
+ _last_supported_version = "0.33.1"
image_processor: VaeImageProcessor
vqvae: VQModel
tokenizer: CLIPTokenizer
@@ -203,10 +204,10 @@ def __call__(
micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
The targeted aesthetic score according to the laion aesthetic classifier. See
https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of
- https://arxiv.org/abs/2307.01952.
+ https://huggingface.co/papers/2307.01952.
micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
The targeted height, width crop coordinates. See the micro-conditioning section of
- https://arxiv.org/abs/2307.01952.
+ https://huggingface.co/papers/2307.01952.
temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`.
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
index d3ad5cc13ce3..091b6db713ba 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -428,7 +428,7 @@ def decode_latents(self, latents, decode_chunk_size: int = 16):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -514,7 +514,7 @@ def check_inputs(
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
- # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
+ # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://huggingface.co/papers/2310.15169)
if self.free_noise_enabled:
latents = self._prepare_latents_free_noise(
batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
@@ -552,7 +552,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -621,8 +621,8 @@ def __call__(
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py
index db546398643b..70180ccf0650 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -472,7 +472,7 @@ def decode_latents(self, latents, decode_chunk_size: int = 16):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -630,7 +630,7 @@ def check_inputs(
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
- # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
+ # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://huggingface.co/papers/2310.15169)
if self.free_noise_enabled:
latents = self._prepare_latents_free_noise(
batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
@@ -700,7 +700,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -772,8 +772,8 @@ def __call__(
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
index 958eb5fb5134..56d319027595 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -125,7 +125,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -652,7 +652,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -844,7 +844,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -948,11 +948,11 @@ def __call__(
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
guidance_scale (`float`, *optional*, defaults to 5.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower video quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower video quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the video generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -963,15 +963,15 @@ def __call__(
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -1003,9 +1003,10 @@ def __call__(
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
- [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
- Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -1252,7 +1253,7 @@ def __call__(
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(
noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
)
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py
index 8c51fddcd5fc..46d650efe8b6 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -140,7 +140,7 @@ class AnimateDiffSparseControlNetPipeline(
):
r"""
Pipeline for controlled text-to-video generation using the method described in [SparseCtrl: Adding Sparse Controls
- to Text-to-Video Diffusion Models](https://arxiv.org/abs/2311.16933).
+ to Text-to-Video Diffusion Models](https://huggingface.co/papers/2311.16933).
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
@@ -475,7 +475,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -695,7 +695,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -762,8 +762,8 @@ def __call__(
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
index 116397055272..6f3a609aba4a 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -539,7 +539,7 @@ def decode_latents(self, latents, decode_chunk_size: int = 16):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -725,7 +725,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -805,8 +805,8 @@ def __call__(
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py
index ce974094936a..b00f344598ad 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -571,7 +571,7 @@ def decode_latents(self, latents, decode_chunk_size: int = 16):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -890,7 +890,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -975,8 +975,8 @@ def __call__(
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
diff --git a/src/diffusers/pipelines/audioldm/pipeline_audioldm.py b/src/diffusers/pipelines/audioldm/pipeline_audioldm.py
index 14c6d44fc586..6a70f00c76c7 100644
--- a/src/diffusers/pipelines/audioldm/pipeline_audioldm.py
+++ b/src/diffusers/pipelines/audioldm/pipeline_audioldm.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline, StableDiffusionMixin
+from ..pipeline_utils import AudioPipelineOutput, DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
if is_torch_xla_available():
@@ -57,7 +57,7 @@
"""
-class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
+class AudioLDMPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
r"""
Pipeline for text-to-audio generation using AudioLDM.
@@ -81,6 +81,7 @@ class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
Vocoder of class `SpeechT5HifiGan`.
"""
+ _last_supported_version = "0.33.1"
model_cpu_offload_seq = "text_encoder->unet->vae"
def __init__(
@@ -261,7 +262,7 @@ def mel_spectrogram_to_waveform(self, mel_spectrogram):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -397,8 +398,8 @@ def __call__(
num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
The number of waveforms to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -472,7 +473,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
index 00bed864ba34..878f6f08db42 100644
--- a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
+++ b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,15 +17,14 @@
import torch
import torch.nn as nn
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
from ...models.activations import get_activation
+from ...models.attention import AttentionMixin
from ...models.attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
- AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
@@ -164,7 +163,7 @@ def forward(
)
-class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+class AudioLDM2UNet2DConditionModel(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r"""
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
shaped output. Compared to the vanilla [`UNet2DConditionModel`], this variant optionally includes an additional
@@ -246,16 +245,21 @@ def __init__(
out_channels: int = 4,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
- down_block_types: Tuple[str] = (
+ down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
- up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ up_block_types: Tuple[str, ...] = (
+ "UpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D",
+ ),
only_cross_attention: Union[bool, Tuple[bool]] = False,
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
@@ -531,66 +535,6 @@ def __init__(
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
)
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
diff --git a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
index b8b5d07af529..452fc3c01b27 100644
--- a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
+++ b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
@@ -1,4 +1,4 @@
-# Copyright 2024 CVSSP, ByteDance and The HuggingFace Team. All rights reserved.
+# Copyright 2025 CVSSP, ByteDance and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,7 +20,7 @@
from transformers import (
ClapFeatureExtractor,
ClapModel,
- GPT2Model,
+ GPT2LMHeadModel,
RobertaTokenizer,
RobertaTokenizerFast,
SpeechT5HifiGan,
@@ -34,13 +34,15 @@
from ...models import AutoencoderKL
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
+ deprecate,
is_accelerate_available,
is_accelerate_version,
is_librosa_available,
logging,
replace_example_docstring,
)
-from ...utils.torch_utils import randn_tensor
+from ...utils.import_utils import is_transformers_version
+from ...utils.torch_utils import empty_device_cache, randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
@@ -196,7 +198,7 @@ def __init__(
text_encoder: ClapModel,
text_encoder_2: Union[T5EncoderModel, VitsModel],
projection_model: AudioLDM2ProjectionModel,
- language_model: GPT2Model,
+ language_model: GPT2LMHeadModel,
tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
tokenizer_2: Union[T5Tokenizer, T5TokenizerFast, VitsTokenizer],
feature_extractor: ClapFeatureExtractor,
@@ -227,6 +229,12 @@ def enable_vae_slicing(self):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
# Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing
@@ -235,6 +243,12 @@ def disable_vae_slicing(self):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
@@ -259,13 +273,14 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
)
device_type = torch_device.type
- device = torch.device(f"{device_type}:{gpu_id or torch_device.index}")
+ device_str = device_type
+ if gpu_id or torch_device.index:
+ device_str = f"{device_str}:{gpu_id or torch_device.index}"
+ device = torch.device(device_str)
if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
- device_mod = getattr(torch, device.type, None)
- if hasattr(device_mod, "empty_cache") and device_mod.is_available():
- device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+ empty_device_cache(device.type)
model_sequence = [
self.text_encoder.text_model,
@@ -309,16 +324,26 @@ def generate_language_model(
`inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
The sequence of generated hidden-states.
"""
+ cache_position_kwargs = {}
+ if is_transformers_version("<", "4.52.1"):
+ cache_position_kwargs["input_ids"] = inputs_embeds
+ else:
+ cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
+ cache_position_kwargs["device"] = (
+ self.language_model.device if getattr(self, "language_model", None) is not None else self.device
+ )
+ cache_position_kwargs["model_kwargs"] = model_kwargs
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
- model_kwargs = self.language_model._get_initial_cache_position(inputs_embeds, model_kwargs)
+ model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
+
for _ in range(max_new_tokens):
# prepare model inputs
model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)
# forward pass to get next hidden states
- output = self.language_model(**model_inputs, return_dict=True)
+ output = self.language_model(**model_inputs, output_hidden_states=True, return_dict=True)
- next_hidden_states = output.last_hidden_state
+ next_hidden_states = output.hidden_states[-1]
# Update the model input
inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)
@@ -370,7 +395,7 @@ def encode_prompt(
*e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
`negative_prompt` input argument.
generated_prompt_embeds (`torch.Tensor`, *optional*):
- Pre-generated text embeddings from the GPT2 langauge model. Can be used to easily tweak text inputs,
+ Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs,
*e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input
argument.
negative_generated_prompt_embeds (`torch.Tensor`, *optional*):
@@ -391,7 +416,7 @@ def encode_prompt(
attention_mask (`torch.LongTensor`):
Attention mask to be applied to the `prompt_embeds`.
generated_prompt_embeds (`torch.Tensor`):
- Text embeddings generated from the GPT2 langauge model.
+ Text embeddings generated from the GPT2 language model.
Example:
@@ -698,7 +723,7 @@ def score_waveforms(self, text, audio, num_waveforms_per_prompt, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -788,7 +813,7 @@ def check_inputs(
if transcription is None:
if self.text_encoder_2.config.model_type == "vits":
- raise ValueError("Cannot forward without transcription. Please make sure to" " have transcription")
+ raise ValueError("Cannot forward without transcription. Please make sure to have transcription")
elif transcription is not None and (
not isinstance(transcription, str) and not isinstance(transcription, list)
):
@@ -885,8 +910,8 @@ def __call__(
generated waveforms based on their cosine similarity with the text input in the joint text-audio
embedding space.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -901,7 +926,7 @@ def __call__(
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
generated_prompt_embeds (`torch.Tensor`, *optional*):
- Pre-generated text embeddings from the GPT2 langauge model. Can be used to easily tweak text inputs,
+ Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs,
*e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input
argument.
negative_generated_prompt_embeds (`torch.Tensor`, *optional*):
@@ -984,7 +1009,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
index ea60e66d2db9..bb9884e41381 100644
--- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
+++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
@@ -1,4 +1,4 @@
-# Copyright 2024 AuraFlow Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 AuraFlow Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,17 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import T5Tokenizer, UMT5EncoderModel
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
+from ...loaders import AuraFlowLoraLoaderMixin
from ...models import AuraFlowTransformer2DModel, AutoencoderKL
-from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -112,7 +120,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class AuraFlowPipeline(DiffusionPipeline):
+class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
r"""
Args:
tokenizer (`T5TokenizerFast`):
@@ -233,6 +241,7 @@ def encode_prompt(
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 256,
+ lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -259,10 +268,20 @@ def encode_prompt(
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, AuraFlowLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
if device is None:
device = self._execution_device
-
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -346,6 +365,11 @@ def encode_prompt(
negative_prompt_embeds = None
negative_prompt_attention_mask = None
+ if self.text_encoder is not None:
+ if isinstance(self, AuraFlowLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents
@@ -382,27 +406,21 @@ def prepare_latents(
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- FusedAttnProcessor2_0,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
@property
def guidance_scale(self):
return self._guidance_scale
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
@property
def num_timesteps(self):
return self._num_timesteps
@@ -428,6 +446,7 @@ def __call__(
max_sequence_length: int = 256,
output_type: Optional[str] = "pil",
return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
@@ -455,11 +474,11 @@ def __call__(
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
guidance_scale (`float`, *optional*, defaults to 5.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -468,7 +487,7 @@ def __call__(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -486,6 +505,10 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -520,6 +543,7 @@ def __call__(
)
self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
# 2. Determine batch size.
if prompt is not None and isinstance(prompt, str):
@@ -530,9 +554,10 @@ def __call__(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
+ lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
@@ -553,6 +578,7 @@ def __call__(
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
@@ -594,6 +620,7 @@ def __call__(
encoder_hidden_states=prompt_embeds,
timestep=timestep,
return_dict=False,
+ attention_kwargs=self.attention_kwargs,
)[0]
# perform guidance
diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index 6a5f6098b6fb..db0268a2a73d 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -21,6 +21,7 @@
from ..models.controlnets import ControlNetUnionModel
from ..utils import is_sentencepiece_available
from .aura_flow import AuraFlowPipeline
+from .chroma import ChromaPipeline
from .cogview3 import CogView3PlusPipeline
from .cogview4 import CogView4ControlPipeline, CogView4Pipeline
from .controlnet import (
@@ -48,6 +49,7 @@
FluxControlPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
+ FluxKontextPipeline,
FluxPipeline,
)
from .hunyuandit import HunyuanDiTPipeline
@@ -89,6 +91,15 @@
StableDiffusionXLPAGPipeline,
)
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
+from .qwenimage import (
+ QwenImageControlNetPipeline,
+ QwenImageEditInpaintPipeline,
+ QwenImageEditPipeline,
+ QwenImageEditPlusPipeline,
+ QwenImageImg2ImgPipeline,
+ QwenImageInpaintPipeline,
+ QwenImagePipeline,
+)
from .sana import SanaPipeline
from .stable_cascade import StableCascadeCombinedPipeline, StableCascadeDecoderPipeline
from .stable_diffusion import (
@@ -106,7 +117,9 @@
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline,
)
+from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline
+from .z_image import ZImageImg2ImgPipeline, ZImagePipeline
AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
@@ -141,11 +154,16 @@
("flux", FluxPipeline),
("flux-control", FluxControlPipeline),
("flux-controlnet", FluxControlNetPipeline),
+ ("flux-kontext", FluxKontextPipeline),
("lumina", LuminaPipeline),
("lumina2", Lumina2Pipeline),
+ ("chroma", ChromaPipeline),
("cogview3", CogView3PlusPipeline),
("cogview4", CogView4Pipeline),
("cogview4-control", CogView4ControlPipeline),
+ ("qwenimage", QwenImagePipeline),
+ ("qwenimage-controlnet", QwenImageControlNetPipeline),
+ ("z-image", ZImagePipeline),
]
)
@@ -169,6 +187,11 @@
("flux", FluxImg2ImgPipeline),
("flux-controlnet", FluxControlNetImg2ImgPipeline),
("flux-control", FluxControlImg2ImgPipeline),
+ ("flux-kontext", FluxKontextPipeline),
+ ("qwenimage", QwenImageImg2ImgPipeline),
+ ("qwenimage-edit", QwenImageEditPipeline),
+ ("qwenimage-edit-plus", QwenImageEditPlusPipeline),
+ ("z-image", ZImageImg2ImgPipeline),
]
)
@@ -190,6 +213,26 @@
("flux-controlnet", FluxControlNetInpaintPipeline),
("flux-control", FluxControlInpaintPipeline),
("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
+ ("qwenimage", QwenImageInpaintPipeline),
+ ("qwenimage-edit", QwenImageEditInpaintPipeline),
+ ]
+)
+
+AUTO_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict(
+ [
+ ("wan", WanPipeline),
+ ]
+)
+
+AUTO_IMAGE2VIDEO_PIPELINES_MAPPING = OrderedDict(
+ [
+ ("wan", WanImageToVideoPipeline),
+ ]
+)
+
+AUTO_VIDEO2VIDEO_PIPELINES_MAPPING = OrderedDict(
+ [
+ ("wan", WanVideoToVideoPipeline),
]
)
@@ -226,6 +269,9 @@
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
AUTO_INPAINT_PIPELINES_MAPPING,
+ AUTO_TEXT2VIDEO_PIPELINES_MAPPING,
+ AUTO_IMAGE2VIDEO_PIPELINES_MAPPING,
+ AUTO_VIDEO2VIDEO_PIPELINES_MAPPING,
_AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING,
_AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING,
_AUTO_INPAINT_DECODER_PIPELINES_MAPPING,
@@ -246,14 +292,15 @@ def _get_connected_pipeline(pipeline_cls):
return _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, pipeline_cls.__name__, throw_error_if_not_exist=False)
-def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True):
- def get_model(pipeline_class_name):
- for task_mapping in SUPPORTED_TASKS_MAPPINGS:
- for model_name, pipeline in task_mapping.items():
- if pipeline.__name__ == pipeline_class_name:
- return model_name
+def _get_model(pipeline_class_name):
+ for task_mapping in SUPPORTED_TASKS_MAPPINGS:
+ for model_name, pipeline in task_mapping.items():
+ if pipeline.__name__ == pipeline_class_name:
+ return model_name
+
- model_name = get_model(pipeline_class_name)
+def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True):
+ model_name = _get_model(pipeline_class_name)
if model_name is not None:
task_class = mapping.get(model_name, None)
@@ -322,9 +369,8 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
- A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
saved using
[`~DiffusionPipeline.save_pretrained`].
- torch_dtype (`str` or `torch.dtype`, *optional*):
- Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
- dtype is automatically derived from the model's weights.
+ torch_dtype (`torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@@ -388,12 +434,8 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `huggingface-cli login`.
-
-
+ > [!TIP] > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in
+ with `hf > auth login`.
Examples:
@@ -619,8 +661,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
saved using
[`~DiffusionPipeline.save_pretrained`].
torch_dtype (`str` or `torch.dtype`, *optional*):
- Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
- dtype is automatically derived from the model's weights.
+ Override the default `torch.dtype` and load the model with another dtype.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@@ -684,12 +725,8 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `huggingface-cli login`.
-
-
+ > [!TIP] > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in
+ with `hf > auth login`.
Examples:
@@ -930,8 +967,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
saved using
[`~DiffusionPipeline.save_pretrained`].
torch_dtype (`str` or `torch.dtype`, *optional*):
- Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
- dtype is automatically derived from the model's weights.
+ Override the default `torch.dtype` and load the model with another dtype.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@@ -995,12 +1031,8 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `huggingface-cli login`.
-
-
+ > [!TIP] > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in
+ with `hf > auth login`.
Examples:
diff --git a/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py b/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py
index d2408417f590..b061ac2636a5 100644
--- a/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py
+++ b/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,6 @@
from typing import Optional, Tuple, Union
import torch
-import torch.utils.checkpoint
from torch import nn
from transformers import BertTokenizer
from transformers.activations import QuickGELUActivation as QuickGELU
diff --git a/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py b/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py
index d29dddf64b01..1b0342ce7a56 100644
--- a/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py
+++ b/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py
@@ -1,5 +1,5 @@
-# Copyright 2024 Salesforce.com, inc.
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 Salesforce.com, inc.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py b/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py
index cbd8bef67945..705d930b59fe 100644
--- a/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py
+++ b/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py
@@ -1,5 +1,5 @@
-# Copyright 2024 Salesforce.com, inc.
-# Copyright 2024 The HuggingFace Team. All rights reserved.#
+# Copyright 2025 Salesforce.com, inc.
+# Copyright 2025 The HuggingFace Team. All rights reserved.#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -19,13 +19,9 @@
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import PNDMScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
from .blip_image_processing import BlipImageProcessor
from .modeling_blip2 import Blip2QFormerModel
from .modeling_ctx_clip import ContextCLIPTextModel
@@ -81,7 +77,7 @@
"""
-class BlipDiffusionPipeline(DiffusionPipeline):
+class BlipDiffusionPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
"""
Pipeline for Zero-Shot Subject Driven Generation using Blip Diffusion.
@@ -107,6 +103,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
Position of the context token in the text encoder.
"""
+ _last_supported_version = "0.33.1"
model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
def __init__(
@@ -138,7 +135,7 @@ def __init__(
def get_query_embeddings(self, input_image, src_subject):
return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)
- # from the original Blip Diffusion code, speciefies the target subject and augments the prompt by repeating it
+ # from the original Blip Diffusion code, specifies the target subject and augments the prompt by repeating it
def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
rv = []
for prompt, tgt_subject in zip(prompts, tgt_subjects):
@@ -227,13 +224,13 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by random sampling.
+ tensor will be generated by random sampling.
guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
height (`int`, *optional*, defaults to 512):
The height of the generated image.
width (`int`, *optional*, defaults to 512):
diff --git a/src/diffusers/pipelines/bria/__init__.py b/src/diffusers/pipelines/bria/__init__.py
new file mode 100644
index 000000000000..60e319ac7910
--- /dev/null
+++ b/src/diffusers/pipelines/bria/__init__.py
@@ -0,0 +1,48 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_bria"] = ["BriaPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_bria import BriaPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/bria/pipeline_bria.py b/src/diffusers/pipelines/bria/pipeline_bria.py
new file mode 100644
index 000000000000..a22a756005ac
--- /dev/null
+++ b/src/diffusers/pipelines/bria/pipeline_bria.py
@@ -0,0 +1,729 @@
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPVisionModelWithProjection,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import FluxLoraLoaderMixin
+from ...models import AutoencoderKL
+from ...models.transformers.transformer_bria import BriaTransformer2DModel
+from ...pipelines import DiffusionPipeline
+from ...pipelines.bria.pipeline_output import BriaPipelineOutput
+from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
+from ...schedulers import (
+ DDIMScheduler,
+ EulerAncestralDiscreteScheduler,
+ FlowMatchEulerDiscreteScheduler,
+ KarrasDiffusionSchedulers,
+)
+from ...utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import BriaPipeline
+
+ >>> pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ # BRIA's T5 text encoder is sensitive to precision. We need to cast it to bfloat16 and keep the final layer in float32.
+
+ >>> pipe.text_encoder = pipe.text_encoder.to(dtype=torch.bfloat16)
+ >>> for block in pipe.text_encoder.encoder.block:
+ ... block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)
+ # BRIA's VAE is not supported in mixed precision, so we use float32.
+
+ >>> if pipe.vae.config.shift_factor == 0:
+ ... pipe.vae.to(dtype=torch.float32)
+
+ >>> prompt = "Photorealistic food photography of a stack of fluffy pancakes on a white plate, with maple syrup being poured over them. On top of the pancakes are the words 'BRIA 3.2' in bold, yellow, 3D letters. The background is dark and out of focus."
+ >>> image = pipe(prompt).images[0]
+ >>> image.save("bria.png")
+ ```
+"""
+
+
+def is_ng_none(negative_prompt):
+ return (
+ negative_prompt is None
+ or negative_prompt == ""
+ or (isinstance(negative_prompt, list) and negative_prompt[0] is None)
+ or (type(negative_prompt) == list and negative_prompt[0] == "")
+ )
+
+
+def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000):
+ timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
+ sigmas = timesteps / num_train_timesteps
+
+ inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)]
+ new_sigmas = sigmas[inds]
+ return new_sigmas
+
+
+class BriaPipeline(DiffusionPipeline):
+ r"""
+ Based on FluxPipeline with several changes:
+ - no pooled embeddings
+ - We use zero padding for prompts
+ - No guidance embedding since this is not a distilled version
+
+ Args:
+ transformer ([`BriaTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. Bria uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`T5TokenizerFast`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ transformer: BriaTransformer2DModel,
+ scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
+ vae: AutoencoderKL,
+ text_encoder: T5EncoderModel,
+ tokenizer: T5TokenizerFast,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ ):
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
+
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+ )
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.default_sample_size = 64 # due to patchify=> 128,128 => res of 1k,1k
+
+ if self.vae.config.shift_factor is None:
+ self.vae.config.shift_factor = 0
+ self.vae.to(dtype=torch.float32)
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 128,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ if not is_ng_none(negative_prompt):
+ negative_prompt = (
+ batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ )
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ else:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device)
+ text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
+
+ return prompt_embeds, negative_prompt_embeds, text_ids
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @attention_kwargs.setter
+ def attention_kwargs(self, value):
+ self._attention_kwargs = value
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 128,
+ device: Optional[torch.device] = None,
+ ):
+ tokenizer = self.tokenizer
+ text_encoder = self.text_encoder
+ device = device or text_encoder.device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+ prompt_embeds_list = []
+ for p in prompt:
+ text_inputs = tokenizer(
+ p,
+ # padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
+
+ # Concat zeros to max_sequence
+ b, seq_len, dim = prompt_embeds.shape
+ if seq_len < max_sequence_length:
+ padding = torch.zeros(
+ (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
+ )
+ prompt_embeds = torch.concat([prompt_embeds, padding], dim=1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=0)
+ prompt_embeds = prompt_embeds.to(device=device)
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, max_sequence_length, -1)
+ prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
+ return prompt_embeds
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // self.vae_scale_factor)
+ width = 2 * (int(width) // self.vae_scale_factor)
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+ return latents, latent_image_ids
+
+ @staticmethod
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ height = height // vae_scale_factor
+ width = width // vae_scale_factor
+
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+
+ return latents
+
+ @staticmethod
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1)
+ latent_image_ids = latent_image_ids.reshape(
+ batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 30,
+ timesteps: List[int] = None,
+ guidance_scale: float = 5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 128,
+ clip_value: Union[None, float] = None,
+ normalize: bool = False,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.bria.BriaPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.bria.BriaPipelineOutput`] or `tuple`: [`~pipelines.bria.BriaPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ prompt_embeds=prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self.attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
+
+ (prompt_embeds, negative_prompt_embeds, text_ids) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4 # due to patch=2, we devide by 4
+ latents, latent_image_ids = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ if (
+ isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler)
+ and self.scheduler.config["use_dynamic_shifting"]
+ ):
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ image_seq_len = latents.shape[1]
+
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas,
+ mu=mu,
+ )
+ else:
+ # 4. Prepare timesteps
+ # Sample from training sigmas
+ if isinstance(self.scheduler, DDIMScheduler) or isinstance(
+ self.scheduler, EulerAncestralDiscreteScheduler
+ ):
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, None, None
+ )
+ else:
+ sigmas = get_original_sigmas(
+ num_train_timesteps=self.scheduler.config.num_train_timesteps,
+ num_inference_steps=num_inference_steps,
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
+ )
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ if len(latent_image_ids.shape) == 3:
+ latent_image_ids = latent_image_ids[0]
+ if len(text_ids.shape) == 3:
+ text_ids = text_ids[0]
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ if type(self.scheduler) != FlowMatchEulerDiscreteScheduler:
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ # This is predicts "v" from flow-matching or eps from diffusion
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ cfg_noise_pred_text = noise_pred_text.std()
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if normalize:
+ noise_pred = noise_pred * (0.7 * (cfg_noise_pred_text / noise_pred.std())) + 0.3 * noise_pred
+
+ if clip_value:
+ assert clip_value > 0
+ noise_pred = noise_pred.clip(-clip_value, clip_value)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents.to(dtype=torch.float32) / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return BriaPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/bria/pipeline_output.py b/src/diffusers/pipelines/bria/pipeline_output.py
new file mode 100644
index 000000000000..54eed0623371
--- /dev/null
+++ b/src/diffusers/pipelines/bria/pipeline_output.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class BriaPipelineOutput(BaseOutput):
+ """
+ Output class for Bria pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/pipelines/bria_fibo/__init__.py b/src/diffusers/pipelines/bria_fibo/__init__.py
new file mode 100644
index 000000000000..206a463b394b
--- /dev/null
+++ b/src/diffusers/pipelines/bria_fibo/__init__.py
@@ -0,0 +1,48 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_bria_fibo import BriaFiboPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py
new file mode 100644
index 000000000000..8fd29756b290
--- /dev/null
+++ b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py
@@ -0,0 +1,838 @@
+# Copyright (c) Bria.ai. All rights reserved.
+#
+# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
+# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
+#
+# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
+# indicate if changes were made, and do not use the material for commercial purposes.
+#
+# See the license for further details.
+
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer
+from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import FluxLoraLoaderMixin
+from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan
+from ...models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
+from ...pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput
+from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
+from ...pipelines.pipeline_utils import DiffusionPipeline
+from ...schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
+from ...utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Example:
+ ```python
+ import torch
+ from diffusers import BriaFiboPipeline
+ from diffusers.modular_pipelines import ModularPipeline
+
+ torch.set_grad_enabled(False)
+ vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)
+
+ pipe = BriaFiboPipeline.from_pretrained(
+ "briaai/FIBO",
+ trust_remote_code=True,
+ torch_dtype=torch.bfloat16,
+ )
+ pipe.enable_model_cpu_offload()
+
+ with torch.inference_mode():
+ # 1. Create a prompt to generate an initial image
+ output = vlm_pipe(prompt="a beautiful dog")
+ json_prompt_generate = output.values["json_prompt"]
+
+ # Generate the image from the structured json prompt
+ results_generate = pipe(prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=5)
+ results_generate.images[0].save("image_generate.png")
+ ```
+"""
+
+
+class BriaFiboPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
+ r"""
+ Args:
+ transformer (`BriaFiboTransformer2DModel`):
+ The transformer model for 2D diffusion modeling.
+ scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`):
+ Scheduler to be used with `transformer` to denoise the encoded latents.
+ vae (`AutoencoderKLWan`):
+ Variational Auto-Encoder for encoding and decoding images to and from latent representations.
+ text_encoder (`SmolLM3ForCausalLM`):
+ Text encoder for processing input prompts.
+ tokenizer (`AutoTokenizer`):
+ Tokenizer used for processing the input text prompts for the text_encoder.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ transformer: BriaFiboTransformer2DModel,
+ scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
+ vae: AutoencoderKLWan,
+ text_encoder: SmolLM3ForCausalLM,
+ tokenizer: AutoTokenizer,
+ ):
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_scale_factor = 16
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.default_sample_size = 64
+
+ def get_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 2048,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if not prompt:
+ raise ValueError("`prompt` must be a non-empty string or list of strings.")
+
+ batch_size = len(prompt)
+ bot_token_id = 128000
+
+ text_encoder_device = device if device is not None else torch.device("cpu")
+ if not isinstance(text_encoder_device, torch.device):
+ text_encoder_device = torch.device(text_encoder_device)
+
+ if all(p == "" for p in prompt):
+ input_ids = torch.full((batch_size, 1), bot_token_id, dtype=torch.long, device=text_encoder_device)
+ attention_mask = torch.ones_like(input_ids)
+ else:
+ tokenized = self.tokenizer(
+ prompt,
+ padding="longest",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ input_ids = tokenized.input_ids.to(text_encoder_device)
+ attention_mask = tokenized.attention_mask.to(text_encoder_device)
+
+ if any(p == "" for p in prompt):
+ empty_rows = torch.tensor([p == "" for p in prompt], dtype=torch.bool, device=text_encoder_device)
+ input_ids[empty_rows] = bot_token_id
+ attention_mask[empty_rows] = 1
+
+ encoder_outputs = self.text_encoder(
+ input_ids,
+ attention_mask=attention_mask,
+ output_hidden_states=True,
+ )
+ hidden_states = encoder_outputs.hidden_states
+
+ prompt_embeds = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1)
+ prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
+
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ hidden_states = tuple(
+ layer.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) for layer in hidden_states
+ )
+ attention_mask = attention_mask.repeat_interleave(num_images_per_prompt, dim=0).to(device=device)
+
+ return prompt_embeds, hidden_states, attention_mask
+
+ @staticmethod
+ def pad_embedding(prompt_embeds, max_tokens, attention_mask=None):
+ # Pad embeddings to `max_tokens` while preserving the mask of real tokens.
+ batch_size, seq_len, dim = prompt_embeds.shape
+
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_len), dtype=prompt_embeds.dtype, device=prompt_embeds.device)
+ else:
+ attention_mask = attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
+
+ if max_tokens < seq_len:
+ raise ValueError("`max_tokens` must be greater or equal to the current sequence length.")
+
+ if max_tokens > seq_len:
+ pad_length = max_tokens - seq_len
+ padding = torch.zeros(
+ (batch_size, pad_length, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
+ )
+ prompt_embeds = torch.cat([prompt_embeds, padding], dim=1)
+
+ mask_padding = torch.zeros(
+ (batch_size, pad_length), dtype=prompt_embeds.dtype, device=prompt_embeds.device
+ )
+ attention_mask = torch.cat([attention_mask, mask_padding], dim=1)
+
+ return prompt_embeds, attention_mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ guidance_scale: float = 5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 3000,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ guidance_scale (`float`):
+ Guidance scale for classifier free guidance.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ prompt_attention_mask = None
+ negative_prompt_attention_mask = None
+ if prompt_embeds is None:
+ prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds(
+ prompt=prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
+ prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers]
+
+ if guidance_scale > 1:
+ if isinstance(negative_prompt, list) and negative_prompt[0] is None:
+ negative_prompt = ""
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds(
+ prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype)
+ negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers]
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ # Pad to longest
+ if prompt_attention_mask is not None:
+ prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
+
+ if negative_prompt_embeds is not None:
+ if negative_prompt_attention_mask is not None:
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(
+ device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype
+ )
+ max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1])
+
+ prompt_embeds, prompt_attention_mask = self.pad_embedding(
+ prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
+ )
+ prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers]
+
+ negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding(
+ negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask
+ )
+ negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers]
+ else:
+ max_tokens = prompt_embeds.shape[1]
+ prompt_embeds, prompt_attention_mask = self.pad_embedding(
+ prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
+ )
+ negative_prompt_layers = None
+
+ dtype = self.text_encoder.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype)
+
+ return (
+ prompt_embeds,
+ negative_prompt_embeds,
+ text_ids,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ prompt_layers,
+ negative_prompt_layers,
+ )
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @staticmethod
+ # Based on diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ height = height // vae_scale_factor
+ width = width // vae_scale_factor
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ def _unpack_latents_no_patch(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ height = height // vae_scale_factor
+ width = width // vae_scale_factor
+
+ latents = latents.view(batch_size, height, width, channels)
+ latents = latents.permute(0, 3, 1, 2)
+
+ return latents
+
+ @staticmethod
+ def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.permute(0, 2, 3, 1)
+ latents = latents.reshape(batch_size, height * width, num_channels_latents)
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ do_patching=False,
+ ):
+ height = int(height) // self.vae_scale_factor
+ width = int(width) // self.vae_scale_factor
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ if do_patching:
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+ else:
+ latents = self._pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+
+ return latents, latent_image_ids
+
+ @staticmethod
+ def _prepare_attention_mask(attention_mask):
+ attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask)
+
+ # convert to 0 - keep, -inf ignore
+ attention_matrix = torch.where(
+ attention_matrix == 1, 0.0, -torch.inf
+ ) # Apply -inf to ignored tokens for nulling softmax score
+ return attention_matrix
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 30,
+ timesteps: List[int] = None,
+ guidance_scale: float = 5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 3000,
+ do_patching=False,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 3000): Maximum sequence length to use with the `prompt`.
+ do_patching (`bool`, *optional*, defaults to `False`): Whether to use patching.
+ Examples:
+ Returns:
+ [`~pipelines.flux.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.flux.BriaFiboPipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ prompt_embeds=prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ text_ids,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ prompt_layers,
+ negative_prompt_layers,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ guidance_scale=guidance_scale,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ num_images_per_prompt=num_images_per_prompt,
+ lora_scale=lora_scale,
+ )
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if guidance_scale > 1:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_layers = [
+ torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers))
+ ]
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ total_num_layers_transformer = len(self.transformer.transformer_blocks) + len(
+ self.transformer.single_transformer_blocks
+ )
+ if len(prompt_layers) >= total_num_layers_transformer:
+ # remove first layers
+ prompt_layers = prompt_layers[len(prompt_layers) - total_num_layers_transformer :]
+ else:
+ # duplicate last layer
+ prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers))
+
+ # 5. Prepare latent variables
+
+ num_channels_latents = self.transformer.config.in_channels
+ if do_patching:
+ num_channels_latents = int(num_channels_latents / 4)
+
+ latents, latent_image_ids = self.prepare_latents(
+ prompt_batch_size,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ do_patching,
+ )
+
+ latent_attention_mask = torch.ones(
+ [latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device
+ )
+ if guidance_scale > 1:
+ latent_attention_mask = latent_attention_mask.repeat(2, 1)
+
+ attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1)
+ attention_mask = self._prepare_attention_mask(attention_mask) # batch, seq => batch, seq, seq
+ attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting
+
+ if self._joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {}
+ self._joint_attention_kwargs["attention_mask"] = attention_mask
+
+ # Adapt scheduler to dynamic shifting (resolution dependent)
+
+ if do_patching:
+ seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2))
+ else:
+ seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor)
+
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+
+ mu = calculate_shift(
+ seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
+
+ # Init sigmas and timesteps according to shift size
+ # This changes the scheduler in-place according to the dynamic scheduling
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps=num_inference_steps,
+ device=device,
+ timesteps=None,
+ sigmas=sigmas,
+ mu=mu,
+ )
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # Support old different diffusers versions
+ if len(latent_image_ids.shape) == 3:
+ latent_image_ids = latent_image_ids[0]
+
+ if len(text_ids.shape) == 3:
+ text_ids = text_ids[0]
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0]).to(
+ device=latent_model_input.device, dtype=latent_model_input.dtype
+ )
+
+ # This is predicts "v" from flow-matching or eps from diffusion
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ text_encoder_layers=prompt_layers,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ )[0]
+
+ # perform guidance
+ if guidance_scale > 1:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ if do_patching:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ else:
+ latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor)
+
+ latents = latents.unsqueeze(dim=2)
+ latents_device = latents[0].device
+ latents_dtype = latents[0].dtype
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents_device, latents_dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents_device, latents_dtype
+ )
+ latents_scaled = [latent / latents_std + latents_mean for latent in latents]
+ latents_scaled = torch.cat(latents_scaled, dim=0)
+ image = []
+ for scaled_latent in latents_scaled:
+ curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0]
+ curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type)
+ image.append(curr_image)
+ if len(image) == 1:
+ image = image[0]
+ else:
+ image = np.stack(image, axis=0)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return BriaFiboPipelineOutput(images=image)
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 3000:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}")
diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_output.py b/src/diffusers/pipelines/bria_fibo/pipeline_output.py
new file mode 100644
index 000000000000..f459185a2c7c
--- /dev/null
+++ b/src/diffusers/pipelines/bria_fibo/pipeline_output.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class BriaFiboPipelineOutput(BaseOutput):
+ """
+ Output class for BriaFibo pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/pipelines/chroma/__init__.py b/src/diffusers/pipelines/chroma/__init__.py
new file mode 100644
index 000000000000..d9238b735c41
--- /dev/null
+++ b/src/diffusers/pipelines/chroma/__init__.py
@@ -0,0 +1,49 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_additional_imports = {}
+_import_structure = {"pipeline_output": ["ChromaPipelineOutput"]}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_chroma"] = ["ChromaPipeline"]
+ _import_structure["pipeline_chroma_img2img"] = ["ChromaImg2ImgPipeline"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ else:
+ from .pipeline_chroma import ChromaPipeline
+ from .pipeline_chroma_img2img import ChromaImg2ImgPipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
+ for name, value in _additional_imports.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py
new file mode 100644
index 000000000000..ed6c2c2105b6
--- /dev/null
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -0,0 +1,976 @@
+# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ChromaTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import ChromaPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import ChromaPipeline
+
+ >>> model_id = "lodestones/Chroma1-HD"
+ >>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
+ >>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
+ >>> pipe = ChromaPipeline.from_pretrained(
+ ... model_id,
+ ... transformer=transformer,
+ ... torch_dtype=torch.bfloat16,
+ ... )
+ >>> pipe.enable_model_cpu_offload()
+ >>> prompt = [
+ ... "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
+ ... ]
+ >>> negative_prompt = [
+ ... "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
+ ... ]
+ >>> image = pipe(prompt, negative_prompt=negative_prompt).images[0]
+ >>> image.save("chroma.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class ChromaPipeline(
+ DiffusionPipeline,
+ FluxLoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+ FluxIPAdapterMixin,
+):
+ r"""
+ The Chroma pipeline for text-to-image generation.
+
+ Reference: https://huggingface.co/lodestones/Chroma1-HD/
+
+ Args:
+ transformer ([`ChromaTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representation
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: T5EncoderModel,
+ tokenizer: T5TokenizerFast,
+ transformer: ChromaTransformer2DModel,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.default_sample_size = 128
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ tokenizer_mask = text_inputs.attention_mask
+
+ tokenizer_mask_device = tokenizer_mask.to(device)
+
+ # unlike FLUX, Chroma uses the attention mask when generating the T5 embedding
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ output_hidden_states=False,
+ attention_mask=tokenizer_mask_device,
+ )[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # for the text tokens, chroma requires that all except the first padding token are masked out during the forward pass through the transformer
+ seq_lengths = tokenizer_mask_device.sum(dim=1)
+ mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
+ attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ attention_mask = attention_mask.repeat(1, num_images_per_prompt)
+ attention_mask = attention_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ return prompt_embeds, attention_mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Union[str, List[str]] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ do_classifier_free_guidance: bool = True,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+ negative_text_ids = None
+
+ if do_classifier_free_guidance:
+ if negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = (
+ batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ )
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return (
+ prompt_embeds,
+ text_ids,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_text_ids,
+ negative_prompt_attention_mask,
+ )
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ return image_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+ ):
+ image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_ip_adapter_image in ip_adapter_image:
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+ image_embeds.append(single_image_embeds[None, :])
+ else:
+ if not isinstance(ip_adapter_image_embeds, list):
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
+
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_image_embeds in ip_adapter_image_embeds:
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for single_image_embeds in image_embeds:
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_embeds=None,
+ negative_prompt_attention_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError(
+ "Cannot provide `negative_prompt_embeds` without also providing `negative_prompt_attention_mask"
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ return latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+ return latents, latent_image_ids
+
+ def _prepare_attention_mask(
+ self,
+ batch_size,
+ sequence_length,
+ dtype,
+ attention_mask=None,
+ ):
+ if attention_mask is None:
+ return attention_mask
+
+ # Extend the prompt attention mask to account for image tokens in the final sequence
+ attention_mask = torch.cat(
+ [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)],
+ dim=1,
+ )
+
+ return attention_mask
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 35,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 5.0,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ not greater than `1`).
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 3.5):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_ip_adapter_image:
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ prompt_attention_mask (torch.Tensor, *optional*):
+ Attention mask for the prompt embeddings. Used to mask out padding tokens in the prompt sequence.
+ Chroma requires a single padding token remain unmasked. Please refer to
+ https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
+ negative_prompt_attention_mask (torch.Tensor, *optional*):
+ Attention mask for the negative prompt embeddings. Used to mask out padding tokens in the negative
+ prompt sequence. Chroma requires a single padding token remain unmasked. PLease refer to
+ https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.ChromaPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.chroma.ChromaPipelineOutput`] or `tuple`: [`~pipelines.chroma.ChromaPipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_embeds=negative_prompt_embeds,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ text_ids,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_text_ids,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, latent_image_ids = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+
+ attention_mask = self._prepare_attention_mask(
+ batch_size=latents.shape[0],
+ sequence_length=image_seq_len,
+ dtype=latents.dtype,
+ attention_mask=prompt_attention_mask,
+ )
+ negative_attention_mask = self._prepare_attention_mask(
+ batch_size=latents.shape[0],
+ sequence_length=image_seq_len,
+ dtype=latents.dtype,
+ attention_mask=negative_prompt_attention_mask,
+ )
+
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+ ):
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+ ):
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {}
+
+ image_embeds = None
+ negative_image_embeds = None
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+ negative_ip_adapter_image,
+ negative_ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ if image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ attention_mask=attention_mask,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ if negative_image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+ neg_noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=negative_text_ids,
+ img_ids=latent_image_ids,
+ attention_mask=negative_attention_mask,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return ChromaPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
new file mode 100644
index 000000000000..470c746e4146
--- /dev/null
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
@@ -0,0 +1,1060 @@
+# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ChromaTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import ChromaPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline
+
+ >>> model_id = "lodestones/Chroma1-HD"
+ >>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
+ >>> pipe = ChromaImg2ImgPipeline.from_pretrained(
+ ... model_id,
+ ... transformer=transformer,
+ ... torch_dtype=torch.bfloat16,
+ ... )
+ >>> pipe.enable_model_cpu_offload()
+ >>> init_image = load_image(
+ ... "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+ ... )
+ >>> prompt = "a scenic fastasy landscape with a river and mountains in the background, vibrant colors, detailed, high resolution"
+ >>> negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
+ >>> image = pipe(prompt, image=init_image, negative_prompt=negative_prompt).images[0]
+ >>> image.save("chroma-img2img.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class ChromaImg2ImgPipeline(
+ DiffusionPipeline,
+ FluxLoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+ FluxIPAdapterMixin,
+):
+ r"""
+ The Chroma pipeline for image-to-image generation.
+
+ Reference: https://huggingface.co/lodestones/Chroma1-HD/
+
+ Args:
+ transformer ([`ChromaTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representation
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: T5EncoderModel,
+ tokenizer: T5TokenizerFast,
+ transformer: ChromaTransformer2DModel,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.default_sample_size = 128
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ tokenizer_mask = text_inputs.attention_mask
+
+ tokenizer_mask_device = tokenizer_mask.to(device)
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ output_hidden_states=False,
+ attention_mask=tokenizer_mask_device,
+ )[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ seq_lengths = tokenizer_mask_device.sum(dim=1)
+ mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
+ attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ attention_mask = attention_mask.repeat(1, num_images_per_prompt)
+ attention_mask = attention_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ return prompt_embeds, attention_mask
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ return image_latents
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Union[str, List[str]] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ do_classifier_free_guidance: bool = True,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+ negative_text_ids = None
+
+ if do_classifier_free_guidance:
+ if negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = (
+ batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ )
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return (
+ prompt_embeds,
+ text_ids,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_text_ids,
+ negative_prompt_attention_mask,
+ )
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ return image_embeds
+
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+ ):
+ device = device or self._execution_device
+
+ image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_ip_adapter_image in ip_adapter_image:
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+ image_embeds.append(single_image_embeds[None, :])
+ else:
+ if not isinstance(ip_adapter_image_embeds, list):
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
+
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_image_embeds in ip_adapter_image_embeds:
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for single_image_embeds in image_embeds:
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ strength,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError(
+ "Cannot provide `negative_prompt_embeds` without also providing `negative_prompt_attention_mask"
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ def _prepare_latent_image_ids(height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ return latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(
+ self,
+ image,
+ timestep,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ shape = (batch_size, num_channels_latents, height, width)
+ latent_image_ids = self._prepare_latent_image_ids(height // 2, width // 2, device, dtype)
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ else:
+ image_latents = image
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+ return latents, latent_image_ids
+
+ def _prepare_attention_mask(
+ self,
+ batch_size,
+ sequence_length,
+ dtype,
+ attention_mask=None,
+ ):
+ if attention_mask is None:
+ return attention_mask
+
+ # Extend the prompt attention mask to account for image tokens in the final sequence
+ attention_mask = torch.cat(
+ [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
+ dim=1,
+ )
+ attention_mask = attention_mask.to(dtype)
+
+ return attention_mask
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ image: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 35,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 5.0,
+ strength: float = 0.9,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ not greater than `1`).
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 35):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 3.5):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ strength (`float, *optional*, defaults to 0.9):
+ Conceptually, indicates how much to transform the reference image. Must be between 0 and 1. image will
+ be used as a starting point, adding more noise to it the larger the strength. The number of denoising
+ steps depends on the amount of noise initially added. When strength is 1, added noise will be maximum
+ and the denoising process will run for the full number of iterations specified in num_inference_steps.
+ A value of 1, therefore, essentially ignores image.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_ip_adapter_image:
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ prompt_attention_mask (torch.Tensor, *optional*):
+ Attention mask for the prompt embeddings. Used to mask out padding tokens in the prompt sequence.
+ Chroma requires a single padding token remain unmasked. Please refer to
+ https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
+ negative_prompt_attention_mask (torch.Tensor, *optional*):
+ Attention mask for the negative prompt embeddings. Used to mask out padding tokens in the negative
+ prompt sequence. Chroma requires a single padding token remain unmasked. PLease refer to
+ https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.ChromaPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.chroma.ChromaPipelineOutput`] or `tuple`: [`~pipelines.chroma.ChromaPipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ strength,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Preprocess image
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
+ init_image = init_image.to(dtype=torch.float32)
+
+ # 3. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+
+ (
+ prompt_embeds,
+ text_ids,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_text_ids,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, latent_image_ids = self.prepare_latents(
+ init_image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ attention_mask = self._prepare_attention_mask(
+ batch_size=latents.shape[0],
+ sequence_length=image_seq_len,
+ dtype=latents.dtype,
+ attention_mask=prompt_attention_mask,
+ )
+ negative_attention_mask = self._prepare_attention_mask(
+ batch_size=latents.shape[0],
+ sequence_length=image_seq_len,
+ dtype=latents.dtype,
+ attention_mask=negative_prompt_attention_mask,
+ )
+
+ # 6. Prepare image embeddings
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+ ):
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+ ):
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {}
+
+ image_embeds = None
+ negative_image_embeds = None
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+ negative_ip_adapter_image,
+ negative_ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0])
+
+ if image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
+
+ noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ attention_mask=attention_mask,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ if negative_image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+
+ noise_pred_uncond = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=negative_text_ids,
+ img_ids=latent_image_ids,
+ attention_mask=negative_attention_mask,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return ChromaPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/chroma/pipeline_output.py b/src/diffusers/pipelines/chroma/pipeline_output.py
new file mode 100644
index 000000000000..951d132dba2e
--- /dev/null
+++ b/src/diffusers/pipelines/chroma/pipeline_output.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class ChromaPipelineOutput(BaseOutput):
+ """
+ Output class for Stable Diffusion pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/pipelines/chronoedit/__init__.py b/src/diffusers/pipelines/chronoedit/__init__.py
new file mode 100644
index 000000000000..cffe4660977f
--- /dev/null
+++ b/src/diffusers/pipelines/chronoedit/__init__.py
@@ -0,0 +1,47 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_chronoedit"] = ["ChronoEditPipeline"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_chronoedit import ChronoEditPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/chronoedit/pipeline_chronoedit.py b/src/diffusers/pipelines/chronoedit/pipeline_chronoedit.py
new file mode 100644
index 000000000000..79f6580fbed6
--- /dev/null
+++ b/src/diffusers/pipelines/chronoedit/pipeline_chronoedit.py
@@ -0,0 +1,752 @@
+# Copyright 2025 The ChronoEdit Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import PIL
+import regex as re
+import torch
+from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import WanLoraLoaderMixin
+from ...models import AutoencoderKLWan, ChronoEditTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import ChronoEditPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> import numpy as np
+ >>> from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
+ >>> from diffusers.utils import export_to_video, load_image
+ >>> from transformers import CLIPVisionModel
+
+ >>> # Available models: nvidia/ChronoEdit-14B-Diffusers
+ >>> model_id = "nvidia/ChronoEdit-14B-Diffusers"
+ >>> image_encoder = CLIPVisionModel.from_pretrained(
+ ... model_id, subfolder="image_encoder", torch_dtype=torch.float32
+ ... )
+ >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+ >>> transformer = ChronoEditTransformer3DModel.from_pretrained(
+ ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe = ChronoEditPipeline.from_pretrained(
+ ... model_id, vae=vae, image_encoder=image_encoder, transformer=transformer, torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> image = load_image("https://huggingface.co/spaces/nvidia/ChronoEdit/resolve/main/examples/3.png")
+ >>> max_area = 720 * 1280
+ >>> aspect_ratio = image.height / image.width
+ >>> mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+ >>> height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+ >>> width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+ >>> image = image.resize((width, height))
+ >>> prompt = (
+ ... "The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cup’s liquid, with gentle steam rising around it to blend with the cozy atmosphere. "
+ ... "The mouse’s pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacup’s floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood."
+ ... )
+
+ >>> output = pipe(
+ ... image=image,
+ ... prompt=prompt,
+ ... height=height,
+ ... width=width,
+ ... num_frames=5,
+ ... guidance_scale=5.0,
+ ... enable_temporal_reasoning=False,
+ ... num_temporal_reasoning_steps=0,
+ ... ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=16)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class ChronoEditPipeline(DiffusionPipeline, WanLoraLoaderMixin):
+ r"""
+ Pipeline for image-to-video generation using Wan.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ image_encoder ([`CLIPVisionModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically
+ the
+ [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large)
+ variant.
+ transformer ([`WanTransformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ image_encoder: CLIPVisionModel,
+ image_processor: CLIPImageProcessor,
+ transformer: ChronoEditTransformer3DModel,
+ vae: AutoencoderKLWan,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ image_encoder=image_encoder,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_processor=image_processor,
+ )
+
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+ self.image_processor = image_processor
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_image
+ def encode_image(
+ self,
+ image: PipelineImageInput,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+ image = self.image_processor(images=image, return_tensors="pt").to(device)
+ image_embeds = self.image_encoder(**image, output_hidden_states=True)
+ return image_embeds.hidden_states[-2]
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # modified from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ image,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ image_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if image is not None and image_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ if image is None and image_embeds is None:
+ raise ValueError(
+ "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined."
+ )
+ if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
+ raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ # modified from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ image: PipelineImageInput,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+
+ shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ image = image.unsqueeze(2) # [batch_size, channels, 1, height, width]
+ video_condition = torch.cat(
+ [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
+ )
+ video_condition = video_condition.to(device=device, dtype=self.vae.dtype)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+
+ if isinstance(generator, list):
+ latent_condition = [
+ retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
+ ]
+ latent_condition = torch.cat(latent_condition)
+ else:
+ latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
+ latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
+
+ latent_condition = latent_condition.to(dtype)
+ latent_condition = (latent_condition - latents_mean) * latents_std
+
+ mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
+ mask_lat_size[:, :, list(range(1, num_frames))] = 0
+ first_frame_mask = mask_lat_size[:, :, 0:1]
+ first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
+ mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
+ mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width)
+ mask_lat_size = mask_lat_size.transpose(1, 2)
+ mask_lat_size = mask_lat_size.to(latent_condition.device)
+
+ return latents, torch.concat([mask_lat_size, latent_condition], dim=1)
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ image_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ enable_temporal_reasoning: bool = False,
+ num_temporal_reasoning_steps: int = 0,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, defaults to `480`):
+ The height of the generated video.
+ width (`int`, defaults to `832`):
+ The width of the generated video.
+ num_frames (`int`, defaults to `81`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `negative_prompt` input argument.
+ image_embeds (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
+ image embeddings are generated from the `image` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`ChronoEditPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
+ truncated. If the prompt is shorter, it will be padded to this length.
+ enable_temporal_reasoning (`bool`, *optional*, defaults to `False`):
+ Whether to enable temporal reasoning.
+ num_temporal_reasoning_steps (`int`, *optional*, defaults to `0`):
+ The number of steps to enable temporal reasoning.
+
+ Examples:
+
+ Returns:
+ [`~ChronoEditPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`ChronoEditPipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ image,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ image_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ num_frames = 5 if not enable_temporal_reasoning else num_frames
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ # Encode image embedding
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ if image_embeds is None:
+ image_embeds = self.encode_image(image, device)
+ image_embeds = image_embeds.repeat(batch_size, 1, 1)
+ image_embeds = image_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.vae.config.z_dim
+ image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
+ latents, condition = self.prepare_latents(
+ image,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ if enable_temporal_reasoning and i == num_temporal_reasoning_steps:
+ latents = latents[:, :, [0, -1]]
+ condition = condition[:, :, [0, -1]]
+
+ for j in range(len(self.scheduler.model_outputs)):
+ if self.scheduler.model_outputs[j] is not None:
+ if latents.shape[-3] != self.scheduler.model_outputs[j].shape[-3]:
+ self.scheduler.model_outputs[j] = self.scheduler.model_outputs[j][:, :, [0, -1]]
+ if self.scheduler.last_sample is not None:
+ self.scheduler.last_sample = self.scheduler.last_sample[:, :, [0, -1]]
+
+ self._current_timestep = t
+ latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+ timestep = t.expand(latents.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states_image=image_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states_image=image_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ if enable_temporal_reasoning and latents.shape[2] > 2:
+ video_edit = self.vae.decode(latents[:, :, [0, -1]], return_dict=False)[0]
+ video_reason = self.vae.decode(latents[:, :, :-1], return_dict=False)[0]
+ video = torch.cat([video_reason, video_edit[:, :, 1:]], dim=2)
+ else:
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return ChronoEditPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/chronoedit/pipeline_output.py b/src/diffusers/pipelines/chronoedit/pipeline_output.py
new file mode 100644
index 000000000000..b1df5b9de35d
--- /dev/null
+++ b/src/diffusers/pipelines/chronoedit/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class ChronoEditPipelineOutput(BaseOutput):
+ r"""
+ Output class for ChronoEdit pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
index 99ae9025cd3e..4ac33b24bbe1 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -359,7 +359,7 @@ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -558,11 +558,11 @@ def __call__(
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -571,7 +571,7 @@ def __call__(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -645,7 +645,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
@@ -718,14 +718,15 @@ def __call__(
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- image_rotary_emb=image_rotary_emb,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
# perform guidance
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
index e37574ec9cb2..c1335839f848 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI, Alibaba-PAI and The HuggingFace Team.
+# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI, Alibaba-PAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -398,7 +398,7 @@ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -603,11 +603,11 @@ def __call__(
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 6.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -616,7 +616,7 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
control_video_latents (`torch.Tensor`, *optional*):
Pre-generated control latents, sampled from a Gaussian distribution, to be used as inputs for
controlled video generation. If not provided, `control_video` must be provided.
@@ -698,7 +698,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
@@ -784,14 +784,15 @@ def __call__(
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- image_rotary_emb=image_rotary_emb,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
# perform guidance
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
index 59d7c4cad547..c523c9adec98 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -28,11 +28,7 @@
from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from .pipeline_output import CogVideoXPipelineOutput
@@ -442,7 +438,7 @@ def get_timesteps(self, num_inference_steps, timesteps, strength, device):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -658,11 +654,11 @@ def __call__(
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -671,7 +667,7 @@ def __call__(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -747,7 +743,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
@@ -831,15 +827,16 @@ def __call__(
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- ofs=ofs_emb,
- image_rotary_emb=image_rotary_emb,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ ofs=ofs_emb,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
# perform guidance
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
index c4dc7e574f7e..897dc6d1b70a 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -418,7 +418,7 @@ def get_timesteps(self, num_inference_steps, timesteps, strength, device):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -628,11 +628,11 @@ def __call__(
strength (`float`, *optional*, defaults to 0.8):
Higher strength leads to more differences between original video and generated video.
guidance_scale (`float`, *optional*, defaults to 7.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -641,7 +641,7 @@ def __call__(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -718,7 +718,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
@@ -799,14 +799,15 @@ def __call__(
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- image_rotary_emb=image_rotary_emb,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
# perform guidance
diff --git a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
index 0cd3943fbcd2..304a5c5ad00b 100644
--- a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
+++ b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -319,7 +319,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -390,7 +390,7 @@ def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -453,11 +453,11 @@ def __call__(
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to `5.0`):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to `1`):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -466,7 +466,7 @@ def __call__(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -547,7 +547,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 8550fa94f9e4..22510f5d9d50 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -377,7 +377,7 @@ def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -453,11 +453,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to `5.0`):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to `1`):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -466,7 +466,7 @@ def __call__(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -619,22 +619,10 @@ def __call__(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0])
- noise_pred_cond = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- original_size=original_size,
- target_size=target_size,
- crop_coords=crops_coords_top_left,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
-
- # perform guidance
- if self.do_classifier_free_guidance:
- noise_pred_uncond = self.transformer(
+ with self.transformer.cache_context("cond"):
+ noise_pred_cond = self.transformer(
hidden_states=latent_model_input,
- encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
@@ -643,6 +631,19 @@ def __call__(
return_dict=False,
)[0]
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ with self.transformer.cache_context("uncond"):
+ noise_pred_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=negative_prompt_embeds,
+ timestep=timestep,
+ original_size=original_size,
+ target_size=target_size,
+ crop_coords=crops_coords_top_left,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = noise_pred_cond
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py
index 7613bc3d0f40..e26b7ba415de 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -409,7 +409,7 @@ def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -486,11 +486,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to `5.0`):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to `1`):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -499,7 +499,7 @@ def __call__(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/consisid/__init__.py b/src/diffusers/pipelines/consisid/__init__.py
index 5052e146f1df..7b9ba330fbd1 100644
--- a/src/diffusers/pipelines/consisid/__init__.py
+++ b/src/diffusers/pipelines/consisid/__init__.py
@@ -5,6 +5,7 @@
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
+ is_opencv_available,
is_torch_available,
is_transformers_available,
)
@@ -15,12 +16,12 @@
try:
- if not (is_transformers_available() and is_torch_available()):
+ if not (is_transformers_available() and is_torch_available() and is_opencv_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
- from ...utils import dummy_torch_and_transformers_objects # noqa F403
+ from ...utils import dummy_torch_and_transformers_and_opencv_objects # noqa F403
- _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_opencv_objects))
else:
_import_structure["pipeline_consisid"] = ["ConsisIDPipeline"]
diff --git a/src/diffusers/pipelines/consisid/consisid_utils.py b/src/diffusers/pipelines/consisid/consisid_utils.py
index 874b3d76149b..521d4d787e54 100644
--- a/src/diffusers/pipelines/consisid/consisid_utils.py
+++ b/src/diffusers/pipelines/consisid/consisid_utils.py
@@ -166,7 +166,7 @@ def process_face_embeddings(
raise RuntimeError("facexlib align face fail")
align_face = face_helper_1.cropped_faces[0] # (512, 512, 3) # RGB
- # incase insightface didn't detect face
+ # in case insightface didn't detect face
if id_ante_embedding is None:
logger.warning("Failed to detect face using insightface. Extracting embedding with align face")
id_ante_embedding = face_helper_2.get_feat(align_face)
@@ -294,7 +294,7 @@ def prepare_face_models(model_path, device, dtype):
Parameters:
- model_path: Path to the directory containing model files.
- - device: The device (e.g., 'cuda', 'cpu') where models will be loaded.
+ - device: The device (e.g., 'cuda', 'xpu', 'cpu') where models will be loaded.
- dtype: Data type (e.g., torch.float32) for model inference.
Returns:
diff --git a/src/diffusers/pipelines/consisid/pipeline_consisid.py b/src/diffusers/pipelines/consisid/pipeline_consisid.py
index 1a99c2a0e9ee..3e6c149d7f80 100644
--- a/src/diffusers/pipelines/consisid/pipeline_consisid.py
+++ b/src/diffusers/pipelines/consisid/pipeline_consisid.py
@@ -1,4 +1,4 @@
-# Copyright 2024 ConsisID Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 ConsisID Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,7 +16,6 @@
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-import cv2
import numpy as np
import PIL
import torch
@@ -29,12 +28,16 @@
from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDPMScheduler
-from ...utils import logging, replace_example_docstring
+from ...utils import is_opencv_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from .pipeline_output import ConsisIDPipelineOutput
+if is_opencv_available():
+ import cv2
+
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -537,7 +540,7 @@ def get_timesteps(self, num_inference_steps, timesteps, strength, device):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -712,11 +715,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 6):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
use_dynamic_cfg (`bool`, *optional*, defaults to `False`):
If True, dynamically adjusts the guidance scale during inference. This allows the model to use a
progressive guidance scale, improving the balance between text-guided generation and image quality over
@@ -730,7 +733,7 @@ def __call__(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -818,7 +821,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
index f0c71655e628..1fbdeb1f2741 100644
--- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
+++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,11 +18,7 @@
from ...models import UNet2DModel
from ...schedulers import CMStochasticIterativeScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
index a5e38278cdf2..fe0e69314cca 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -37,7 +37,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -579,7 +579,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -886,7 +886,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -979,8 +979,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -1339,7 +1339,7 @@ def __call__(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
- torch.cuda.empty_cache()
+ empty_device_cache()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
index 88c387d48dd2..e0f1879405aa 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
@@ -1,5 +1,5 @@
-# Copyright 2024 Salesforce.com, inc.
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 Salesforce.com, inc.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,16 +20,12 @@
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import PNDMScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..blip_diffusion.blip_image_processing import BlipImageProcessor
from ..blip_diffusion.modeling_blip2 import Blip2QFormerModel
from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
-from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
@@ -88,7 +84,7 @@
"""
-class BlipDiffusionControlNetPipeline(DiffusionPipeline):
+class BlipDiffusionControlNetPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
"""
Pipeline for Canny Edge based Controlled subject-driven generation using Blip Diffusion.
@@ -116,6 +112,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
Position of the context token in the text encoder.
"""
+ _last_supported_version = "0.33.1"
model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
def __init__(
@@ -149,7 +146,7 @@ def __init__(
def get_query_embeddings(self, input_image, src_subject):
return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)
- # from the original Blip Diffusion code, speciefies the target subject and augments the prompt by repeating it
+ # from the original Blip Diffusion code, specifies the target subject and augments the prompt by repeating it
def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
rv = []
for prompt, tgt_subject in zip(prompts, tgt_subjects):
@@ -278,13 +275,13 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by random sampling.
+ tensor will be generated by random sampling.
guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
height (`int`, *optional*, defaults to 512):
The height of the generated image.
width (`int`, *optional*, defaults to 512):
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
index be2874f48e69..12cc6f630d80 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -36,7 +36,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -557,7 +557,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -884,7 +884,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -977,8 +977,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -1311,7 +1311,7 @@ def __call__(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
- torch.cuda.empty_cache()
+ empty_device_cache()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index 16d3529ed38a..6de8e5747b02 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -38,7 +38,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -146,16 +146,13 @@ class StableDiffusionControlNetInpaintPipeline(
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
-
-
- This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting
+ > [!TIP] > This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting >
([stable-diffusion-v1-5/stable-diffusion-inpainting](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting))
- as well as default text-to-image Stable Diffusion checkpoints
+ > as well as default text-to-image Stable Diffusion checkpoints >
([stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)).
- Default text-to-image Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned on
- those, such as [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
-
-
+ > Default text-to-image Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned
+ on > those, such as
+ [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
Args:
vae ([`AutoencoderKL`]):
@@ -566,7 +563,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -657,7 +654,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -665,7 +662,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
# `prompt` needs more sophisticated handling when there are multiple
# conditionings.
@@ -976,7 +973,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -1089,8 +1086,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -1500,7 +1497,7 @@ def __call__(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
- torch.cuda.empty_cache()
+ empty_device_cache()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
index 5907b41f4e73..fb09d04832f3 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Harutatsu Akiyama, Jinbin Bai, and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Harutatsu Akiyama, Jinbin Bai, and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -36,10 +36,6 @@
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
-from ...models.attention_processor import (
- AttnProcessor2_0,
- XFormersAttnProcessor,
-)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -51,7 +47,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -149,7 +145,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -605,7 +601,7 @@ def prepare_ip_adapter_image_embeds(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -1132,21 +1128,12 @@ def _get_add_time_ids(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
@property
def guidance_scale(self):
@@ -1157,7 +1144,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -1283,11 +1270,11 @@ def __call__(
forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -1318,15 +1305,15 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1806,7 +1793,7 @@ def denoising_value_valid(dnv):
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
@@ -1858,7 +1845,7 @@ def denoising_value_valid(dnv):
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
- torch.cuda.empty_cache()
+ empty_device_cache()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 77d496cf831d..0e2a1441f8f6 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -39,10 +39,6 @@
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
-from ...models.attention_processor import (
- AttnProcessor2_0,
- XFormersAttnProcessor,
-)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -613,7 +609,7 @@ def prepare_ip_adapter_image_embeds(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -929,21 +925,12 @@ def _get_add_time_ids(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
@@ -985,7 +972,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -1108,8 +1095,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -1465,7 +1452,11 @@ def __call__(
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
- if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+ if (
+ torch.cuda.is_available()
+ and (is_unet_compiled and is_controlnet_compiled)
+ and is_torch_higher_equal_2_1
+ ):
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
index 04f069e12eb9..94c4c394465b 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -39,10 +39,6 @@
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
-from ...models.attention_processor import (
- AttnProcessor2_0,
- XFormersAttnProcessor,
-)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -53,7 +49,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -607,7 +603,7 @@ def prepare_ip_adapter_image_embeds(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -921,7 +917,7 @@ def prepare_latents(
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
- torch.cuda.empty_cache()
+ empty_device_cache()
image = image.to(device=device, dtype=dtype)
@@ -1044,21 +1040,12 @@ def _get_add_time_ids(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
@property
def guidance_scale(self):
@@ -1069,7 +1056,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -1174,11 +1161,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -1189,15 +1176,15 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -1632,7 +1619,7 @@ def __call__(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
- torch.cuda.empty_cache()
+ empty_device_cache()
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
index 8aae9ee7a281..e234015f8616 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Harutatsu Akiyama, Jinbin Bai, and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Harutatsu Akiyama, Jinbin Bai, and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,6 @@
import numpy as np
import PIL.Image
import torch
-import torch.nn.functional as F
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
@@ -35,10 +34,12 @@
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
-from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel
-from ...models.attention_processor import (
- AttnProcessor2_0,
- XFormersAttnProcessor,
+from ...models import (
+ AutoencoderKL,
+ ControlNetUnionModel,
+ ImageProjection,
+ MultiControlNetUnionModel,
+ UNet2DConditionModel,
)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
@@ -51,7 +52,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -134,7 +135,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -230,7 +231,9 @@ def __init__(
tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer,
unet: UNet2DConditionModel,
- controlnet: ControlNetUnionModel,
+ controlnet: Union[
+ ControlNetUnionModel, List[ControlNetUnionModel], Tuple[ControlNetUnionModel], MultiControlNetUnionModel
+ ],
scheduler: KarrasDiffusionSchedulers,
requires_aesthetics_score: bool = False,
force_zeros_for_empty_prompt: bool = True,
@@ -240,8 +243,8 @@ def __init__(
):
super().__init__()
- if not isinstance(controlnet, ControlNetUnionModel):
- raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetUnionModel(controlnet)
self.register_modules(
vae=vae,
@@ -587,7 +590,7 @@ def prepare_ip_adapter_image_embeds(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -660,6 +663,7 @@ def check_inputs(
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
+ control_mode=None,
callback_on_step_end_tensor_inputs=None,
padding_mask_crop=None,
):
@@ -747,25 +751,34 @@ def check_inputs(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
+ # `prompt` needs more sophisticated handling when there are multiple
+ # conditionings.
+ if isinstance(self.controlnet, MultiControlNetUnionModel):
+ if isinstance(prompt, list):
+ logger.warning(
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+ " prompts. The conditionings will be fixed across the prompts."
+ )
+
# Check `image`
- is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
- self.controlnet, torch._dynamo.eval_frame.OptimizedModule
- )
- if (
- isinstance(self.controlnet, ControlNetModel)
- or is_compiled
- and isinstance(self.controlnet._orig_mod, ControlNetModel)
- ):
- self.check_image(image, prompt, prompt_embeds)
- elif (
- isinstance(self.controlnet, ControlNetUnionModel)
- or is_compiled
- and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
- ):
- self.check_image(image, prompt, prompt_embeds)
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
- else:
- assert False
+ if isinstance(controlnet, ControlNetUnionModel):
+ for image_ in image:
+ self.check_image(image_, prompt, prompt_embeds)
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ if not isinstance(image, list):
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
+ elif not all(isinstance(i, list) for i in image):
+ raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.")
+ elif len(image) != len(self.controlnet.nets):
+ raise ValueError(
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+ )
+
+ for images_ in image:
+ for image_ in images_:
+ self.check_image(image_, prompt, prompt_embeds)
if not isinstance(control_guidance_start, (tuple, list)):
control_guidance_start = [control_guidance_start]
@@ -778,6 +791,12 @@ def check_inputs(
f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
)
+ if isinstance(controlnet, MultiControlNetUnionModel):
+ if len(control_guidance_start) != len(self.controlnet.nets):
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+ )
+
for start, end in zip(control_guidance_start, control_guidance_end):
if start >= end:
raise ValueError(
@@ -788,6 +807,28 @@ def check_inputs(
if end > 1.0:
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+ # Check `control_mode`
+ if isinstance(controlnet, ControlNetUnionModel):
+ if max(control_mode) >= controlnet.config.num_control_type:
+ raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.")
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
+ if max(_control_mode) >= _controlnet.config.num_control_type:
+ raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
+
+ # Equal number of `image` and `control_mode` elements
+ if isinstance(controlnet, ControlNetUnionModel):
+ if len(image) != len(control_mode):
+ raise ValueError("Expected len(control_image) == len(control_mode)")
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ if not all(isinstance(i, list) for i in control_mode):
+ raise ValueError(
+ "For multiple controlnets: elements of control_mode must be lists representing conditioning mode."
+ )
+
+ elif sum(len(x) for x in image) != sum(len(x) for x in control_mode):
+ raise ValueError("Expected len(control_image) == len(control_mode)")
+
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
@@ -1066,21 +1107,12 @@ def _get_add_time_ids(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
@property
def guidance_scale(self):
@@ -1091,7 +1123,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -1117,7 +1149,7 @@ def __call__(
prompt_2: Optional[Union[str, List[str]]] = None,
image: PipelineImageInput = None,
mask_image: PipelineImageInput = None,
- control_image: PipelineImageInput = None,
+ control_image: Union[PipelineImageInput, List[PipelineImageInput]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
padding_mask_crop: Optional[int] = None,
@@ -1145,7 +1177,7 @@ def __call__(
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
- control_mode: Optional[Union[int, List[int]]] = None,
+ control_mode: Optional[Union[int, List[int], List[List[int]]]] = None,
guidance_rescale: float = 0.0,
original_size: Tuple[int, int] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
@@ -1177,6 +1209,13 @@ def __call__(
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ control_image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+ images must be passed as a list such that each element of the list can be correctly batched for input
+ to a single ControlNet.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -1215,11 +1254,11 @@ def __call__(
forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -1250,15 +1289,15 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1269,6 +1308,22 @@ def __call__(
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+ the corresponding scale as a list.
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ The ControlNet encoder tries to recognize the content of the input image even if you remove all
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the ControlNet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the ControlNet stops applying.
+ control_mode (`int` or `List[int]` or `List[List[int]], *optional*):
+ The control condition types for the ControlNet. See the ControlNet's model card forinformation on the
+ available control modes. If multiple ControlNets are specified in `init`, control_mode should be a list
+ where each ControlNet should have its corresponding control mode list. Should reflect the order of
+ conditions in control_image.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
@@ -1333,22 +1388,6 @@ def __call__(
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
- # align format for control guidance
- if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
- control_guidance_start = len(control_guidance_end) * [control_guidance_start]
- elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
- control_guidance_end = len(control_guidance_start) * [control_guidance_end]
-
- # # 0.0 Default height and width to unet
- # height = height or self.unet.config.sample_size * self.vae_scale_factor
- # width = width or self.unet.config.sample_size * self.vae_scale_factor
-
- # 0.1 align format for control guidance
- if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
- control_guidance_start = len(control_guidance_end) * [control_guidance_start]
- elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
- control_guidance_end = len(control_guidance_start) * [control_guidance_end]
-
if not isinstance(control_image, list):
control_image = [control_image]
else:
@@ -1357,40 +1396,59 @@ def __call__(
if not isinstance(control_mode, list):
control_mode = [control_mode]
- if len(control_image) != len(control_mode):
- raise ValueError("Expected len(control_image) == len(control_type)")
+ if isinstance(controlnet, MultiControlNetUnionModel):
+ control_image = [[item] for item in control_image]
+ control_mode = [[item] for item in control_mode]
+
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
- num_control_type = controlnet.config.num_control_type
+ if isinstance(controlnet_conditioning_scale, float):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
# 1. Check inputs
- control_type = [0 for _ in range(num_control_type)]
- for _image, control_idx in zip(control_image, control_mode):
- control_type[control_idx] = 1
- self.check_inputs(
- prompt,
- prompt_2,
- _image,
- mask_image,
- strength,
- num_inference_steps,
- callback_steps,
- output_type,
- negative_prompt,
- negative_prompt_2,
- prompt_embeds,
- negative_prompt_embeds,
- ip_adapter_image,
- ip_adapter_image_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- controlnet_conditioning_scale,
- control_guidance_start,
- control_guidance_end,
- callback_on_step_end_tensor_inputs,
- padding_mask_crop,
- )
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ control_image,
+ mask_image,
+ strength,
+ num_inference_steps,
+ callback_steps,
+ output_type,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ control_mode,
+ callback_on_step_end_tensor_inputs,
+ padding_mask_crop,
+ )
- control_type = torch.Tensor(control_type)
+ if isinstance(controlnet, ControlNetUnionModel):
+ control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1)
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ control_type = [
+ torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1)
+ for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets)
+ ]
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
@@ -1483,21 +1541,55 @@ def denoising_value_valid(dnv):
init_image = init_image.to(dtype=torch.float32)
# 5.2 Prepare control images
- for idx, _ in enumerate(control_image):
- control_image[idx] = self.prepare_control_image(
- image=control_image[idx],
- width=width,
- height=height,
- batch_size=batch_size * num_images_per_prompt,
- num_images_per_prompt=num_images_per_prompt,
- device=device,
- dtype=controlnet.dtype,
- crops_coords=crops_coords,
- resize_mode=resize_mode,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
- guess_mode=guess_mode,
- )
- height, width = control_image[idx].shape[-2:]
+ if isinstance(controlnet, ControlNetUnionModel):
+ control_images = []
+
+ for image_ in control_image:
+ image_ = self.prepare_control_image(
+ image=image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ crops_coords=crops_coords,
+ resize_mode=resize_mode,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ control_images.append(image_)
+
+ control_image = control_images
+ height, width = control_image[0].shape[-2:]
+
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ control_images = []
+
+ for control_image_ in control_image:
+ images = []
+
+ for image_ in control_image_:
+ image_ = self.prepare_control_image(
+ image=image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ crops_coords=crops_coords,
+ resize_mode=resize_mode,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ images.append(image_)
+ control_images.append(images)
+
+ control_image = control_images
+ height, width = control_image[0][0].shape[-2:]
# 5.3 Prepare mask
mask = self.mask_processor.preprocess(
@@ -1559,10 +1651,11 @@ def denoising_value_valid(dnv):
# 8.2 Create tensor stating which controlnets to keep
controlnet_keep = []
for i in range(len(timesteps)):
- controlnet_keep.append(
- 1.0
- - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
- )
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps)
# 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
height, width = latents.shape[-2:]
@@ -1627,11 +1720,24 @@ def denoising_value_valid(dnv):
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps]
- control_type = (
- control_type.reshape(1, -1)
- .to(device, dtype=prompt_embeds.dtype)
- .repeat(batch_size * num_images_per_prompt * 2, 1)
+ control_type_repeat_factor = (
+ batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1)
)
+
+ if isinstance(controlnet, ControlNetUnionModel):
+ control_type = (
+ control_type.reshape(1, -1)
+ .to(self._execution_device, dtype=prompt_embeds.dtype)
+ .repeat(control_type_repeat_factor, 1)
+ )
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ control_type = [
+ _control_type.reshape(1, -1)
+ .to(self._execution_device, dtype=prompt_embeds.dtype)
+ .repeat(control_type_repeat_factor, 1)
+ for _control_type in control_type
+ ]
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
@@ -1715,7 +1821,7 @@ def denoising_value_valid(dnv):
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
@@ -1766,7 +1872,7 @@ def denoising_value_valid(dnv):
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
- torch.cuda.empty_cache()
+ empty_device_cache()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py
index ca931c221eec..40cc76cf70d8 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -44,14 +44,11 @@
MultiControlNetUnionModel,
UNet2DConditionModel,
)
-from ...models.attention_processor import (
- AttnProcessor2_0,
- XFormersAttnProcessor,
-)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -603,7 +600,7 @@ def prepare_ip_adapter_image_embeds(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -904,21 +901,12 @@ def _get_add_time_ids(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
@@ -960,7 +948,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -1082,8 +1070,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -1452,17 +1440,21 @@ def __call__(
is_controlnet_compiled = is_compiled_module(self.controlnet)
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
+ control_type_repeat_factor = (
+ batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1)
+ )
+
if isinstance(controlnet, ControlNetUnionModel):
control_type = (
control_type.reshape(1, -1)
.to(self._execution_device, dtype=prompt_embeds.dtype)
- .repeat(batch_size * num_images_per_prompt * 2, 1)
+ .repeat(control_type_repeat_factor, 1)
)
- if isinstance(controlnet, MultiControlNetUnionModel):
+ elif isinstance(controlnet, MultiControlNetUnionModel):
control_type = [
_control_type.reshape(1, -1)
.to(self._execution_device, dtype=prompt_embeds.dtype)
- .repeat(batch_size * num_images_per_prompt * 2, 1)
+ .repeat(control_type_repeat_factor, 1)
for _control_type in control_type
]
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py
index 87398395d99e..4d0093132b9c 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,7 +19,6 @@
import numpy as np
import PIL.Image
import torch
-import torch.nn.functional as F
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
@@ -38,10 +37,12 @@
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
-from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel
-from ...models.attention_processor import (
- AttnProcessor2_0,
- XFormersAttnProcessor,
+from ...models import (
+ AutoencoderKL,
+ ControlNetUnionModel,
+ ImageProjection,
+ MultiControlNetUnionModel,
+ UNet2DConditionModel,
)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
@@ -53,7 +54,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -262,7 +263,9 @@ def __init__(
tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer,
unet: UNet2DConditionModel,
- controlnet: ControlNetUnionModel,
+ controlnet: Union[
+ ControlNetUnionModel, List[ControlNetUnionModel], Tuple[ControlNetUnionModel], MultiControlNetUnionModel
+ ],
scheduler: KarrasDiffusionSchedulers,
requires_aesthetics_score: bool = False,
force_zeros_for_empty_prompt: bool = True,
@@ -272,8 +275,8 @@ def __init__(
):
super().__init__()
- if not isinstance(controlnet, ControlNetUnionModel):
- raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetUnionModel(controlnet)
self.register_modules(
vae=vae,
@@ -616,7 +619,7 @@ def prepare_ip_adapter_image_embeds(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -649,6 +652,7 @@ def check_inputs(
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
+ control_mode=None,
callback_on_step_end_tensor_inputs=None,
):
if strength < 0 or strength > 1:
@@ -722,28 +726,44 @@ def check_inputs(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
+ # `prompt` needs more sophisticated handling when there are multiple
+ # conditionings.
+ if isinstance(self.controlnet, MultiControlNetUnionModel):
+ if isinstance(prompt, list):
+ logger.warning(
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+ " prompts. The conditionings will be fixed across the prompts."
+ )
+
# Check `image`
- is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
- self.controlnet, torch._dynamo.eval_frame.OptimizedModule
- )
- if (
- isinstance(self.controlnet, ControlNetModel)
- or is_compiled
- and isinstance(self.controlnet._orig_mod, ControlNetModel)
- ):
- self.check_image(image, prompt, prompt_embeds)
- elif (
- isinstance(self.controlnet, ControlNetUnionModel)
- or is_compiled
- and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
- ):
- self.check_image(image, prompt, prompt_embeds)
- else:
- assert False
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ if isinstance(controlnet, ControlNetUnionModel):
+ for image_ in image:
+ self.check_image(image_, prompt, prompt_embeds)
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ if not isinstance(image, list):
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
+ elif not all(isinstance(i, list) for i in image):
+ raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.")
+ elif len(image) != len(self.controlnet.nets):
+ raise ValueError(
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+ )
+
+ for images_ in image:
+ for image_ in images_:
+ self.check_image(image_, prompt, prompt_embeds)
if not isinstance(control_guidance_start, (tuple, list)):
control_guidance_start = [control_guidance_start]
+ if isinstance(controlnet, MultiControlNetUnionModel):
+ if len(control_guidance_start) != len(self.controlnet.nets):
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+ )
+
if not isinstance(control_guidance_end, (tuple, list)):
control_guidance_end = [control_guidance_end]
@@ -762,6 +782,15 @@ def check_inputs(
if end > 1.0:
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+ # Check `control_mode`
+ if isinstance(controlnet, ControlNetUnionModel):
+ if max(control_mode) >= controlnet.config.num_control_type:
+ raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.")
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
+ if max(_control_mode) >= _controlnet.config.num_control_type:
+ raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
+
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
@@ -876,7 +905,7 @@ def prepare_latents(
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
- torch.cuda.empty_cache()
+ empty_device_cache()
image = image.to(device=device, dtype=dtype)
@@ -999,21 +1028,12 @@ def _get_add_time_ids(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
@property
def guidance_scale(self):
@@ -1024,7 +1044,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -1049,7 +1069,7 @@ def __call__(
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
image: PipelineImageInput = None,
- control_image: PipelineImageInput = None,
+ control_image: Union[PipelineImageInput, List[PipelineImageInput]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 0.8,
@@ -1074,7 +1094,7 @@ def __call__(
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
- control_mode: Optional[Union[int, List[int]]] = None,
+ control_mode: Optional[Union[int, List[int], List[List[int]]]] = None,
original_size: Tuple[int, int] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Tuple[int, int] = None,
@@ -1104,13 +1124,13 @@ def __call__(
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The initial image will be used as the starting point for the image generation process. Can also accept
image latents as `image`, if passing latents directly, it will not be encoded again.
- control_image (`PipelineImageInput`):
- The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
- the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also
- be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
- and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in
- init, images must be passed as a list such that each element of the list can be correctly batched for
- input to a single controlnet.
+ control_image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+ images must be passed as a list such that each element of the list can be correctly batched for input
+ to a single ControlNet.
height (`int`, *optional*, defaults to the size of control_image):
The height in pixels of the generated image. Anything below 512 pixels won't work well for
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
@@ -1129,11 +1149,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -1144,15 +1164,15 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -1184,16 +1204,21 @@ def __call__(
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
- The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
- to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
- corresponding scale as a list.
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+ the corresponding scale as a list.
guess_mode (`bool`, *optional*, defaults to `False`):
In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
- The percentage of total steps at which the controlnet starts applying.
+ The percentage of total steps at which the ControlNet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
- The percentage of total steps at which the controlnet stops applying.
+ The percentage of total steps at which the ControlNet stops applying.
+ control_mode (`int` or `List[int]` or `List[List[int]], *optional*):
+ The control condition types for the ControlNet. See the ControlNet's model card forinformation on the
+ available control modes. If multiple ControlNets are specified in `init`, control_mode should be a list
+ where each ControlNet should have its corresponding control mode list. Should reflect the order of
+ conditions in control_image
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -1273,12 +1298,6 @@ def __call__(
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
- # align format for control guidance
- if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
- control_guidance_start = len(control_guidance_end) * [control_guidance_start]
- elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
- control_guidance_end = len(control_guidance_start) * [control_guidance_end]
-
if not isinstance(control_image, list):
control_image = [control_image]
else:
@@ -1287,37 +1306,56 @@ def __call__(
if not isinstance(control_mode, list):
control_mode = [control_mode]
- if len(control_image) != len(control_mode):
- raise ValueError("Expected len(control_image) == len(control_type)")
+ if isinstance(controlnet, MultiControlNetUnionModel):
+ control_image = [[item] for item in control_image]
+ control_mode = [[item] for item in control_mode]
+
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
- num_control_type = controlnet.config.num_control_type
+ if isinstance(controlnet_conditioning_scale, float):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
# 1. Check inputs
- control_type = [0 for _ in range(num_control_type)]
- for _image, control_idx in zip(control_image, control_mode):
- control_type[control_idx] = 1
- self.check_inputs(
- prompt,
- prompt_2,
- _image,
- strength,
- num_inference_steps,
- callback_steps,
- negative_prompt,
- negative_prompt_2,
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ip_adapter_image,
- ip_adapter_image_embeds,
- controlnet_conditioning_scale,
- control_guidance_start,
- control_guidance_end,
- callback_on_step_end_tensor_inputs,
- )
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ control_image,
+ strength,
+ num_inference_steps,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ control_mode,
+ callback_on_step_end_tensor_inputs,
+ )
- control_type = torch.Tensor(control_type)
+ if isinstance(controlnet, ControlNetUnionModel):
+ control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1)
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ control_type = [
+ torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1)
+ for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets)
+ ]
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
@@ -1334,7 +1372,11 @@ def __call__(
device = self._execution_device
- global_pool_conditions = controlnet.config.global_pool_conditions
+ global_pool_conditions = (
+ controlnet.config.global_pool_conditions
+ if isinstance(controlnet, ControlNetUnionModel)
+ else controlnet.nets[0].config.global_pool_conditions
+ )
guess_mode = guess_mode or global_pool_conditions
# 3.1. Encode input prompt
@@ -1372,22 +1414,55 @@ def __call__(
self.do_classifier_free_guidance,
)
- # 4. Prepare image and controlnet_conditioning_image
+ # 4.1 Prepare image
image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
- for idx, _ in enumerate(control_image):
- control_image[idx] = self.prepare_control_image(
- image=control_image[idx],
- width=width,
- height=height,
- batch_size=batch_size * num_images_per_prompt,
- num_images_per_prompt=num_images_per_prompt,
- device=device,
- dtype=controlnet.dtype,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
- guess_mode=guess_mode,
- )
- height, width = control_image[idx].shape[-2:]
+ # 4.2 Prepare control images
+ if isinstance(controlnet, ControlNetUnionModel):
+ control_images = []
+
+ for image_ in control_image:
+ image_ = self.prepare_control_image(
+ image=image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ control_images.append(image_)
+
+ control_image = control_images
+ height, width = control_image[0].shape[-2:]
+
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ control_images = []
+
+ for control_image_ in control_image:
+ images = []
+
+ for image_ in control_image_:
+ image_ = self.prepare_control_image(
+ image=image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ images.append(image_)
+ control_images.append(images)
+
+ control_image = control_images
+ height, width = control_image[0][0].shape[-2:]
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -1414,10 +1489,11 @@ def __call__(
# 7.1 Create tensor stating which controlnets to keep
controlnet_keep = []
for i in range(len(timesteps)):
- controlnet_keep.append(
- 1.0
- - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
- )
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps)
# 7.2 Prepare added time ids & embeddings
original_size = original_size or (height, width)
@@ -1460,12 +1536,25 @@ def __call__(
prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device)
- control_type = (
- control_type.reshape(1, -1)
- .to(device, dtype=prompt_embeds.dtype)
- .repeat(batch_size * num_images_per_prompt * 2, 1)
+
+ control_type_repeat_factor = (
+ batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1)
)
+ if isinstance(controlnet, ControlNetUnionModel):
+ control_type = (
+ control_type.reshape(1, -1)
+ .to(self._execution_device, dtype=prompt_embeds.dtype)
+ .repeat(control_type_repeat_factor, 1)
+ )
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ control_type = [
+ _control_type.reshape(1, -1)
+ .to(self._execution_device, dtype=prompt_embeds.dtype)
+ .repeat(control_type_repeat_factor, 1)
+ for _control_type in control_type
+ ]
+
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1574,7 +1663,7 @@ def __call__(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
- torch.cuda.empty_cache()
+ empty_device_cache()
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
diff --git a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
index 3d4b19ea552c..d4c6f336dfef 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -394,12 +394,8 @@ def __call__(
jit (`bool`, defaults to `False`):
Whether to run `pmap` versions of the generation and safety scoring functions.
-
-
- This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a
- future release.
-
-
+ > [!WARNING] > This argument exists because `__call__` is not yet end-to-end pmap-able. It will be
+ removed in a > future release.
Examples:
diff --git a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
index 5ee712b5f116..2b5684de9511 100644
--- a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 HunyuanDiT Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -27,11 +27,7 @@
from ...models.embeddings import get_2d_rotary_pos_embed
from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from ...schedulers import DDPMScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
@@ -144,7 +140,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -463,7 +459,7 @@ def run_safety_checker(self, image, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -621,7 +617,7 @@ def guidance_rescale(self):
return self._guidance_rescale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -709,8 +705,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -746,7 +742,7 @@ def __call__(
inputs will be passed.
guidance_rescale (`float`, *optional*, defaults to 0.0):
Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise
- Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). See Section 3.4
original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
The original size of the image. Used to calculate the time ids.
target_size (`Tuple[int, int]`, *optional*):
@@ -1009,7 +1005,7 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
index 7f7acd882b59..d605eac1f2b1 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
+# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -266,7 +266,7 @@ def _get_t5_prompt_embeds(
return torch.zeros(
(
batch_size * num_images_per_prompt,
- self.tokenizer_max_length,
+ max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -355,7 +355,7 @@ def _get_clip_prompt_embeds(
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds
@@ -719,7 +719,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -877,11 +877,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 5.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
The percentage of total steps at which the ControlNet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
@@ -918,7 +918,7 @@ def __call__(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
index cb35f67fa112..9d0158c6b654 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI, The HuggingFace Team and The AlimamaCreative Team. All rights reserved.
+# Copyright 2025 Stability AI, The HuggingFace Team and The AlimamaCreative Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -284,7 +284,7 @@ def _get_t5_prompt_embeds(
return torch.zeros(
(
batch_size * num_images_per_prompt,
- self.tokenizer_max_length,
+ max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -373,7 +373,7 @@ def _get_clip_prompt_embeds(
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds
@@ -769,7 +769,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -928,11 +928,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 5.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
The percentage of total steps at which the ControlNet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
@@ -973,7 +973,7 @@ def __call__(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
index 901ca25c576c..3682ddc91156 100644
--- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
+++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -36,8 +36,8 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
-from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -98,6 +98,7 @@
class StableDiffusionControlNetXSPipeline(
+ DeprecatedPipelineMixin,
DiffusionPipeline,
StableDiffusionMixin,
TextualInversionLoaderMixin,
@@ -132,12 +133,13 @@ class StableDiffusionControlNetXSPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
+ _last_supported_version = "0.33.1"
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
@@ -440,7 +442,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -697,8 +699,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -739,7 +741,7 @@ def __call__(
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeine class.
+ `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
@@ -783,7 +785,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
@@ -851,7 +853,7 @@ def __call__(
for i, t in enumerate(timesteps):
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
- if is_controlnet_compiled and is_torch_higher_equal_2_1:
+ if torch.cuda.is_available() and is_controlnet_compiled and is_torch_higher_equal_2_1:
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -900,7 +902,7 @@ def __call__(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
- torch.cuda.empty_cache()
+ empty_device_cache()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
index acf1f5489ec1..7bf610f3a0ba 100644
--- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -32,21 +32,18 @@
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
-from ...models.attention_processor import (
- AttnProcessor2_0,
- XFormersAttnProcessor,
-)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
-from ..pipeline_utils import DiffusionPipeline
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -114,6 +111,7 @@
class StableDiffusionXLControlNetXSPipeline(
+ DeprecatedPipelineMixin,
DiffusionPipeline,
TextualInversionLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
@@ -158,6 +156,7 @@ class StableDiffusionXLControlNetXSPipeline(
watermarker is used.
"""
+ _last_supported_version = "0.33.1"
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
_optional_components = [
"tokenizer",
@@ -463,7 +462,7 @@ def encode_prompt(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -683,21 +682,12 @@ def _get_add_time_ids(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale
@@ -803,8 +793,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -880,7 +870,7 @@ def __call__(
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
- `._callback_tensor_inputs` attribute of your pipeine class.
+ `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
@@ -927,7 +917,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/cosmos/__init__.py b/src/diffusers/pipelines/cosmos/__init__.py
new file mode 100644
index 000000000000..2833c89abd5e
--- /dev/null
+++ b/src/diffusers/pipelines/cosmos/__init__.py
@@ -0,0 +1,54 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_cosmos2_text2image"] = ["Cosmos2TextToImagePipeline"]
+ _import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"]
+ _import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"]
+ _import_structure["pipeline_cosmos_video2world"] = ["CosmosVideoToWorldPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_cosmos2_text2image import Cosmos2TextToImagePipeline
+ from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline
+ from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline
+ from .pipeline_cosmos_video2world import CosmosVideoToWorldPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py
new file mode 100644
index 000000000000..66490c2be159
--- /dev/null
+++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py
@@ -0,0 +1,673 @@
+# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import T5EncoderModel, T5TokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...models import AutoencoderKLWan, CosmosTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import CosmosImagePipelineOutput
+
+
+if is_cosmos_guardrail_available():
+ from cosmos_guardrail import CosmosSafetyChecker
+else:
+
+ class CosmosSafetyChecker:
+ def __init__(self, *args, **kwargs):
+ raise ImportError(
+ "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
+ )
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import Cosmos2TextToImagePipeline
+
+ >>> # Available checkpoints: nvidia/Cosmos-Predict2-2B-Text2Image, nvidia/Cosmos-Predict2-14B-Text2Image
+ >>> model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
+ >>> pipe = Cosmos2TextToImagePipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
+ >>> negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
+
+ >>> output = pipe(
+ ... prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
+ ... ).images[0]
+ >>> output.save("output.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class Cosmos2TextToImagePipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using [Cosmos Predict2](https://github.com/nvidia-cosmos/cosmos-predict2).
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. Cosmos uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
+ [t5-11b](https://huggingface.co/google-t5/t5-11b) variant.
+ tokenizer (`T5TokenizerFast`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`CosmosTransformer3DModel`]):
+ Conditional Transformer to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+ # We mark safety_checker as optional here to get around some test failures, but it is not really optional
+ _optional_components = ["safety_checker"]
+
+ def __init__(
+ self,
+ text_encoder: T5EncoderModel,
+ tokenizer: T5TokenizerFast,
+ transformer: CosmosTransformer3DModel,
+ vae: AutoencoderKLWan,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ safety_checker: CosmosSafetyChecker = None,
+ ):
+ super().__init__()
+
+ if safety_checker is None:
+ safety_checker = CosmosSafetyChecker()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ )
+
+ self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ self.sigma_max = 80.0
+ self.sigma_min = 0.002
+ self.sigma_data = 1.0
+ self.final_sigmas_type = "sigma_min"
+ if self.scheduler is not None:
+ self.scheduler.register_to_config(
+ sigma_max=self.sigma_max,
+ sigma_min=self.sigma_min,
+ sigma_data=self.sigma_data,
+ final_sigmas_type=self.final_sigmas_type,
+ )
+
+ # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_tensors="pt",
+ return_length=True,
+ return_offsets_mapping=False,
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_attention_mask = text_inputs.attention_mask.bool().to(device)
+
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=prompt_attention_mask
+ ).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ lengths = prompt_attention_mask.sum(dim=1).cpu()
+ for i, length in enumerate(lengths):
+ prompt_embeds[i, length:] = 0
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt with num_videos_per_prompt->num_images_per_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = negative_prompt_embeds.shape
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: 16,
+ height: int = 768,
+ width: int = 1360,
+ num_frames: int = 1,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype) * self.scheduler.config.sigma_max
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+ shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents * self.scheduler.config.sigma_max
+
+ # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 768,
+ width: int = 1360,
+ num_inference_steps: int = 35,
+ guidance_scale: float = 7.0,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ height (`int`, defaults to `768`):
+ The height in pixels of the generated image.
+ width (`int`, defaults to `1360`):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, defaults to `35`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `7.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`CosmosImagePipelineOutput`] instead of a plain tuple.
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~CosmosImagePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`CosmosImagePipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if self.safety_checker is None:
+ raise ValueError(
+ f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
+ "[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
+ f"Please ensure that you are compliant with the license agreement."
+ )
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ num_frames = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
+
+ self._guidance_scale = guidance_scale
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ if self.safety_checker is not None:
+ self.safety_checker.to(device)
+ if prompt is not None:
+ prompt_list = [prompt] if isinstance(prompt, str) else prompt
+ for p in prompt_list:
+ if not self.safety_checker.check_text_safety(p):
+ raise ValueError(
+ f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
+ f"prompt abides by the NVIDIA Open Model License Agreement."
+ )
+ self.safety_checker.to("cpu")
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_images_per_prompt=num_images_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 4. Prepare timesteps
+ sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+ sigmas = torch.linspace(0, 1, num_inference_steps, dtype=sigmas_dtype)
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, device=device, sigmas=sigmas)
+ if self.scheduler.config.get("final_sigmas_type", "zero") == "sigma_min":
+ # Replace the last sigma (which is zero) with the minimum sigma value
+ self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2]
+
+ # 5. Prepare latent variables
+ transformer_dtype = self.transformer.dtype
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ current_sigma = self.scheduler.sigmas[i]
+
+ current_t = current_sigma / (current_sigma + 1)
+ c_in = 1 - current_t
+ c_skip = 1 - current_t
+ c_out = -current_t
+ timestep = current_t.expand(latents.shape[0]).to(transformer_dtype) # [B, 1, T, 1, 1]
+
+ latent_model_input = latents * c_in
+ latent_model_input = latent_model_input.to(transformer_dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ padding_mask=padding_mask,
+ return_dict=False,
+ )[0]
+ noise_pred = (c_skip * latents + c_out * noise_pred.float()).to(transformer_dtype)
+
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ padding_mask=padding_mask,
+ return_dict=False,
+ )[0]
+ noise_pred_uncond = (c_skip * latents + c_out * noise_pred_uncond.float()).to(transformer_dtype)
+ noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_uncond)
+
+ noise_pred = (latents - noise_pred) / current_sigma
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std / self.scheduler.config.sigma_data + latents_mean
+ video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
+
+ if self.safety_checker is not None:
+ self.safety_checker.to(device)
+ video = self.video_processor.postprocess_video(video, output_type="np")
+ video = (video * 255).astype(np.uint8)
+ video_batch = []
+ for vid in video:
+ vid = self.safety_checker.check_video_safety(vid)
+ video_batch.append(vid)
+ video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
+ video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ self.safety_checker.to("cpu")
+ else:
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ image = [batch[0] for batch in video]
+ if isinstance(video, torch.Tensor):
+ image = torch.stack(image)
+ elif isinstance(video, np.ndarray):
+ image = np.stack(image)
+ else:
+ image = latents[:, :, 0]
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return CosmosImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py
new file mode 100644
index 000000000000..23a74ad00f93
--- /dev/null
+++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py
@@ -0,0 +1,792 @@
+# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import T5EncoderModel, T5TokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...models import AutoencoderKLWan, CosmosTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import CosmosPipelineOutput
+
+
+if is_cosmos_guardrail_available():
+ from cosmos_guardrail import CosmosSafetyChecker
+else:
+
+ class CosmosSafetyChecker:
+ def __init__(self, *args, **kwargs):
+ raise ImportError(
+ "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
+ )
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import Cosmos2VideoToWorldPipeline
+ >>> from diffusers.utils import export_to_video, load_image
+
+ >>> # Available checkpoints: nvidia/Cosmos-Predict2-2B-Video2World, nvidia/Cosmos-Predict2-14B-Video2World
+ >>> model_id = "nvidia/Cosmos-Predict2-2B-Video2World"
+ >>> pipe = Cosmos2VideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
+ >>> negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yellow-scrubber.png"
+ ... )
+
+ >>> video = pipe(
+ ... image=image, prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
+ ... ).frames[0]
+ >>> export_to_video(video, "output.mp4", fps=16)
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class Cosmos2VideoToWorldPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for video-to-world generation using [Cosmos Predict2](https://github.com/nvidia-cosmos/cosmos-predict2).
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. Cosmos uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
+ [t5-11b](https://huggingface.co/google-t5/t5-11b) variant.
+ tokenizer (`T5TokenizerFast`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`CosmosTransformer3DModel`]):
+ Conditional Transformer to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+ # We mark safety_checker as optional here to get around some test failures, but it is not really optional
+ _optional_components = ["safety_checker"]
+
+ def __init__(
+ self,
+ text_encoder: T5EncoderModel,
+ tokenizer: T5TokenizerFast,
+ transformer: CosmosTransformer3DModel,
+ vae: AutoencoderKLWan,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ safety_checker: CosmosSafetyChecker = None,
+ ):
+ super().__init__()
+
+ if safety_checker is None:
+ safety_checker = CosmosSafetyChecker()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ )
+
+ self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ self.sigma_max = 80.0
+ self.sigma_min = 0.002
+ self.sigma_data = 1.0
+ self.final_sigmas_type = "sigma_min"
+ if self.scheduler is not None:
+ self.scheduler.register_to_config(
+ sigma_max=self.sigma_max,
+ sigma_min=self.sigma_min,
+ sigma_data=self.sigma_data,
+ final_sigmas_type=self.final_sigmas_type,
+ )
+
+ # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_tensors="pt",
+ return_length=True,
+ return_offsets_mapping=False,
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_attention_mask = text_inputs.attention_mask.bool().to(device)
+
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=prompt_attention_mask
+ ).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ lengths = prompt_attention_mask.sum(dim=1).cpu()
+ for i, length in enumerate(lengths):
+ prompt_embeds[i, length:] = 0
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = negative_prompt_embeds.shape
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def prepare_latents(
+ self,
+ video: torch.Tensor,
+ batch_size: int,
+ num_channels_latents: 16,
+ height: int = 704,
+ width: int = 1280,
+ num_frames: int = 93,
+ do_classifier_free_guidance: bool = True,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ num_cond_frames = video.size(2)
+ if num_cond_frames >= num_frames:
+ # Take the last `num_frames` frames for conditioning
+ num_cond_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ video = video[:, :, -num_frames:]
+ else:
+ num_cond_latent_frames = (num_cond_frames - 1) // self.vae_scale_factor_temporal + 1
+ num_padding_frames = num_frames - num_cond_frames
+ last_frame = video[:, :, -1:]
+ padding = last_frame.repeat(1, 1, num_padding_frames, 1, 1)
+ video = torch.cat([video, padding], dim=2)
+
+ if isinstance(generator, list):
+ init_latents = [
+ retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i])
+ for i in range(batch_size)
+ ]
+ else:
+ init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
+
+ init_latents = torch.cat(init_latents, dim=0).to(dtype)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
+ )
+ init_latents = (init_latents - latents_mean) / latents_std * self.scheduler.config.sigma_data
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+ shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ latents = latents * self.scheduler.config.sigma_max
+
+ padding_shape = (batch_size, 1, num_latent_frames, latent_height, latent_width)
+ ones_padding = latents.new_ones(padding_shape)
+ zeros_padding = latents.new_zeros(padding_shape)
+
+ cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
+ cond_indicator[:, :, :num_cond_latent_frames] = 1.0
+ cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
+
+ uncond_indicator = uncond_mask = None
+ if do_classifier_free_guidance:
+ uncond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
+ uncond_indicator[:, :, :num_cond_latent_frames] = 1.0
+ uncond_mask = uncond_indicator * ones_padding + (1 - uncond_indicator) * zeros_padding
+
+ return latents, init_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask
+
+ # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput = None,
+ video: List[PipelineImageInput] = None,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 704,
+ width: int = 1280,
+ num_frames: int = 93,
+ num_inference_steps: int = 35,
+ guidance_scale: float = 7.0,
+ fps: int = 16,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ sigma_conditioning: float = 0.0001,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*):
+ The image to be used as a conditioning input for the video generation.
+ video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*):
+ The video to be used as a conditioning input for the video generation.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ height (`int`, defaults to `704`):
+ The height in pixels of the generated image.
+ width (`int`, defaults to `1280`):
+ The width in pixels of the generated image.
+ num_frames (`int`, defaults to `93`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `35`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `7.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`.
+ fps (`int`, defaults to `16`):
+ The frames per second of the generated video.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
+ the prompt is shorter than this length, it will be padded.
+ sigma_conditioning (`float`, defaults to `0.0001`):
+ The sigma value used for scaling conditioning latents. Ideally, it should not be changed or should be
+ set to a small value close to zero.
+
+ Examples:
+
+ Returns:
+ [`~CosmosPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where
+ the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if self.safety_checker is None:
+ raise ValueError(
+ f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
+ "[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
+ f"Please ensure that you are compliant with the license agreement."
+ )
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
+
+ self._guidance_scale = guidance_scale
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ if self.safety_checker is not None:
+ self.safety_checker.to(device)
+ if prompt is not None:
+ prompt_list = [prompt] if isinstance(prompt, str) else prompt
+ for p in prompt_list:
+ if not self.safety_checker.check_text_safety(p):
+ raise ValueError(
+ f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
+ f"prompt abides by the NVIDIA Open Model License Agreement."
+ )
+ self.safety_checker.to("cpu")
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 4. Prepare timesteps
+ sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+ sigmas = torch.linspace(0, 1, num_inference_steps, dtype=sigmas_dtype)
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, device=device, sigmas=sigmas)
+ if self.scheduler.config.final_sigmas_type == "sigma_min":
+ # Replace the last sigma (which is zero) with the minimum sigma value
+ self.scheduler.sigmas[-1] = self.scheduler.sigmas[-2]
+
+ # 5. Prepare latent variables
+ vae_dtype = self.vae.dtype
+ transformer_dtype = self.transformer.dtype
+
+ if image is not None:
+ video = self.video_processor.preprocess(image, height, width).unsqueeze(2)
+ else:
+ video = self.video_processor.preprocess_video(video, height, width)
+ video = video.to(device=device, dtype=vae_dtype)
+
+ num_channels_latents = self.transformer.config.in_channels - 1
+ latents, conditioning_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask = self.prepare_latents(
+ video,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ self.do_classifier_free_guidance,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+ unconditioning_latents = None
+
+ cond_mask = cond_mask.to(transformer_dtype)
+ if self.do_classifier_free_guidance:
+ uncond_mask = uncond_mask.to(transformer_dtype)
+ unconditioning_latents = conditioning_latents
+
+ padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
+ sigma_conditioning = torch.tensor(sigma_conditioning, dtype=torch.float32, device=device)
+ t_conditioning = sigma_conditioning / (sigma_conditioning + 1)
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ current_sigma = self.scheduler.sigmas[i]
+
+ current_t = current_sigma / (current_sigma + 1)
+ c_in = 1 - current_t
+ c_skip = 1 - current_t
+ c_out = -current_t
+ timestep = current_t.view(1, 1, 1, 1, 1).expand(
+ latents.size(0), -1, latents.size(2), -1, -1
+ ) # [B, 1, T, 1, 1]
+
+ cond_latent = latents * c_in
+ cond_latent = cond_indicator * conditioning_latents + (1 - cond_indicator) * cond_latent
+ cond_latent = cond_latent.to(transformer_dtype)
+ cond_timestep = cond_indicator * t_conditioning + (1 - cond_indicator) * timestep
+ cond_timestep = cond_timestep.to(transformer_dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=cond_latent,
+ timestep=cond_timestep,
+ encoder_hidden_states=prompt_embeds,
+ fps=fps,
+ condition_mask=cond_mask,
+ padding_mask=padding_mask,
+ return_dict=False,
+ )[0]
+ noise_pred = (c_skip * latents + c_out * noise_pred.float()).to(transformer_dtype)
+ noise_pred = cond_indicator * conditioning_latents + (1 - cond_indicator) * noise_pred
+
+ if self.do_classifier_free_guidance:
+ uncond_latent = latents * c_in
+ uncond_latent = uncond_indicator * unconditioning_latents + (1 - uncond_indicator) * uncond_latent
+ uncond_latent = uncond_latent.to(transformer_dtype)
+ uncond_timestep = uncond_indicator * t_conditioning + (1 - uncond_indicator) * timestep
+ uncond_timestep = uncond_timestep.to(transformer_dtype)
+
+ noise_pred_uncond = self.transformer(
+ hidden_states=uncond_latent,
+ timestep=uncond_timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ fps=fps,
+ condition_mask=uncond_mask,
+ padding_mask=padding_mask,
+ return_dict=False,
+ )[0]
+ noise_pred_uncond = (c_skip * latents + c_out * noise_pred_uncond.float()).to(transformer_dtype)
+ noise_pred_uncond = (
+ uncond_indicator * unconditioning_latents + (1 - uncond_indicator) * noise_pred_uncond
+ )
+ noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_uncond)
+
+ noise_pred = (latents - noise_pred) / current_sigma
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents = latents * latents_std / self.scheduler.config.sigma_data + latents_mean
+ video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
+
+ if self.safety_checker is not None:
+ self.safety_checker.to(device)
+ video = self.video_processor.postprocess_video(video, output_type="np")
+ video = (video * 255).astype(np.uint8)
+ video_batch = []
+ for vid in video:
+ vid = self.safety_checker.check_video_safety(vid)
+ video_batch.append(vid)
+ video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
+ video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ self.safety_checker.to("cpu")
+ else:
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return CosmosPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py
new file mode 100644
index 000000000000..f0aa1ecf0e0f
--- /dev/null
+++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py
@@ -0,0 +1,664 @@
+# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import T5EncoderModel, T5TokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...models import AutoencoderKLCosmos, CosmosTransformer3DModel
+from ...schedulers import EDMEulerScheduler
+from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import CosmosPipelineOutput
+
+
+if is_cosmos_guardrail_available():
+ from cosmos_guardrail import CosmosSafetyChecker
+else:
+
+ class CosmosSafetyChecker:
+ def __init__(self, *args, **kwargs):
+ raise ImportError(
+ "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
+ )
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import CosmosTextToWorldPipeline
+ >>> from diffusers.utils import export_to_video
+
+ >>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"
+ >>> pipe = CosmosTextToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A sleek, humanoid robot stands in a vast warehouse filled with neatly stacked cardboard boxes on industrial shelves. The robot's metallic body gleams under the bright, even lighting, highlighting its futuristic design and intricate joints. A glowing blue light emanates from its chest, adding a touch of advanced technology. The background is dominated by rows of boxes, suggesting a highly organized storage system. The floor is lined with wooden pallets, enhancing the industrial setting. The camera remains static, capturing the robot's poised stance amidst the orderly environment, with a shallow depth of field that keeps the focus on the robot while subtly blurring the background for a cinematic effect."
+
+ >>> output = pipe(prompt=prompt).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=30)
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class CosmosTextToWorldPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-world generation using [Cosmos Predict1](https://github.com/nvidia-cosmos/cosmos-predict1).
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. Cosmos uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
+ [t5-11b](https://huggingface.co/google-t5/t5-11b) variant.
+ tokenizer (`T5TokenizerFast`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`CosmosTransformer3DModel`]):
+ Conditional Transformer to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLCosmos`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+ # We mark safety_checker as optional here to get around some test failures, but it is not really optional
+ _optional_components = ["safety_checker"]
+
+ def __init__(
+ self,
+ text_encoder: T5EncoderModel,
+ tokenizer: T5TokenizerFast,
+ transformer: CosmosTransformer3DModel,
+ vae: AutoencoderKLCosmos,
+ scheduler: EDMEulerScheduler,
+ safety_checker: CosmosSafetyChecker = None,
+ ):
+ super().__init__()
+
+ if safety_checker is None:
+ safety_checker = CosmosSafetyChecker()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ )
+
+ self.vae_scale_factor_temporal = (
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 8
+ )
+ self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_tensors="pt",
+ return_length=True,
+ return_offsets_mapping=False,
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_attention_mask = text_inputs.attention_mask.bool().to(device)
+
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=prompt_attention_mask
+ ).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ lengths = prompt_attention_mask.sum(dim=1).cpu()
+ for i, length in enumerate(lengths):
+ prompt_embeds[i, length:] = 0
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = negative_prompt_embeds.shape
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: 16,
+ height: int = 704,
+ width: int = 1280,
+ num_frames: int = 121,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype) * self.scheduler.config.sigma_max
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+ shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents * self.scheduler.config.sigma_max
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 704,
+ width: int = 1280,
+ num_frames: int = 121,
+ num_inference_steps: int = 36,
+ guidance_scale: float = 7.0,
+ fps: int = 30,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ height (`int`, defaults to `720`):
+ The height in pixels of the generated image.
+ width (`int`, defaults to `1280`):
+ The width in pixels of the generated image.
+ num_frames (`int`, defaults to `121`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `36`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `7.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`.
+ fps (`int`, defaults to `30`):
+ The frames per second of the generated video.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~CosmosPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where
+ the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if self.safety_checker is None:
+ raise ValueError(
+ f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
+ "[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
+ f"Please ensure that you are compliant with the license agreement."
+ )
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
+
+ self._guidance_scale = guidance_scale
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ if self.safety_checker is not None:
+ self.safety_checker.to(device)
+ if prompt is not None:
+ prompt_list = [prompt] if isinstance(prompt, str) else prompt
+ for p in prompt_list:
+ if not self.safety_checker.check_text_safety(p):
+ raise ValueError(
+ f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
+ f"prompt abides by the NVIDIA Open Model License Agreement."
+ )
+ self.safety_checker.to("cpu")
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device)
+
+ # 5. Prepare latent variables
+ transformer_dtype = self.transformer.dtype
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ timestep = t.expand(latents.shape[0]).to(transformer_dtype)
+
+ latent_model_input = latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ latent_model_input = latent_model_input.to(transformer_dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ fps=fps,
+ padding_mask=padding_mask,
+ return_dict=False,
+ )[0]
+
+ sample = latents
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ fps=fps,
+ padding_mask=padding_mask,
+ return_dict=False,
+ )[0]
+ noise_pred = torch.cat([noise_pred_uncond, noise_pred])
+ sample = torch.cat([sample, sample])
+
+ # pred_original_sample (x0)
+ noise_pred = self.scheduler.step(noise_pred, t, sample, return_dict=False)[1]
+ self.scheduler._step_index -= 1
+
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
+ noise_pred = noise_pred_cond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+
+ # pred_sample (eps)
+ latents = self.scheduler.step(
+ noise_pred, t, latents, return_dict=False, pred_original_sample=noise_pred
+ )[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ if self.vae.config.latents_mean is not None:
+ latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std
+ latents_mean = (
+ torch.tensor(latents_mean)
+ .view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)]
+ .to(latents)
+ )
+ latents_std = (
+ torch.tensor(latents_std)
+ .view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)]
+ .to(latents)
+ )
+ latents = latents * latents_std / self.scheduler.config.sigma_data + latents_mean
+ else:
+ latents = latents / self.scheduler.config.sigma_data
+ video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
+
+ if self.safety_checker is not None:
+ self.safety_checker.to(device)
+ video = self.video_processor.postprocess_video(video, output_type="np")
+ video = (video * 255).astype(np.uint8)
+ video_batch = []
+ for vid in video:
+ vid = self.safety_checker.check_video_safety(vid)
+ video_batch.append(vid)
+ video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
+ video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ self.safety_checker.to("cpu")
+ else:
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return CosmosPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py
new file mode 100644
index 000000000000..cd5a734cc311
--- /dev/null
+++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py
@@ -0,0 +1,826 @@
+# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import T5EncoderModel, T5TokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...models import AutoencoderKLCosmos, CosmosTransformer3DModel
+from ...schedulers import EDMEulerScheduler
+from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import CosmosPipelineOutput
+
+
+if is_cosmos_guardrail_available():
+ from cosmos_guardrail import CosmosSafetyChecker
+else:
+
+ class CosmosSafetyChecker:
+ def __init__(self, *args, **kwargs):
+ raise ImportError(
+ "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
+ )
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ Image conditioning:
+
+ ```python
+ >>> import torch
+ >>> from diffusers import CosmosVideoToWorldPipeline
+ >>> from diffusers.utils import export_to_video, load_image
+
+ >>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"
+ >>> pipe = CosmosVideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered clouds, suggesting a bright, sunny day."
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg"
+ ... )
+
+ >>> video = pipe(image=image, prompt=prompt).frames[0]
+ >>> export_to_video(video, "output.mp4", fps=30)
+ ```
+
+ Video conditioning:
+
+ ```python
+ >>> import torch
+ >>> from diffusers import CosmosVideoToWorldPipeline
+ >>> from diffusers.utils import export_to_video, load_video
+
+ >>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"
+ >>> pipe = CosmosVideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+ >>> pipe.transformer = torch.compile(pipe.transformer)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
+ >>> video = load_video(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
+ ... )[
+ ... :21
+ ... ] # This example uses only the first 21 frames
+
+ >>> video = pipe(video=video, prompt=prompt).frames[0]
+ >>> export_to_video(video, "output.mp4", fps=30)
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class CosmosVideoToWorldPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for image-to-world and video-to-world generation using [Cosmos
+ Predict-1](https://github.com/nvidia-cosmos/cosmos-predict1).
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. Cosmos uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
+ [t5-11b](https://huggingface.co/google-t5/t5-11b) variant.
+ tokenizer (`T5TokenizerFast`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`CosmosTransformer3DModel`]):
+ Conditional Transformer to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLCosmos`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+ # We mark safety_checker as optional here to get around some test failures, but it is not really optional
+ _optional_components = ["safety_checker"]
+
+ def __init__(
+ self,
+ text_encoder: T5EncoderModel,
+ tokenizer: T5TokenizerFast,
+ transformer: CosmosTransformer3DModel,
+ vae: AutoencoderKLCosmos,
+ scheduler: EDMEulerScheduler,
+ safety_checker: CosmosSafetyChecker = None,
+ ):
+ super().__init__()
+
+ if safety_checker is None:
+ safety_checker = CosmosSafetyChecker()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ )
+
+ self.vae_scale_factor_temporal = (
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 8
+ )
+ self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_tensors="pt",
+ return_length=True,
+ return_offsets_mapping=False,
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_attention_mask = text_inputs.attention_mask.bool().to(device)
+
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=prompt_attention_mask
+ ).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ lengths = prompt_attention_mask.sum(dim=1).cpu()
+ for i, length in enumerate(lengths):
+ prompt_embeds[i, length:] = 0
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = negative_prompt_embeds.shape
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def prepare_latents(
+ self,
+ video: torch.Tensor,
+ batch_size: int,
+ num_channels_latents: 16,
+ height: int = 704,
+ width: int = 1280,
+ num_frames: int = 121,
+ do_classifier_free_guidance: bool = True,
+ input_frames_guidance: bool = False,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ num_cond_frames = video.size(2)
+ if num_cond_frames >= num_frames:
+ # Take the last `num_frames` frames for conditioning
+ num_cond_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ video = video[:, :, -num_frames:]
+ else:
+ num_cond_latent_frames = (num_cond_frames - 1) // self.vae_scale_factor_temporal + 1
+ num_padding_frames = num_frames - num_cond_frames
+ padding = video.new_zeros(video.size(0), video.size(1), num_padding_frames, video.size(3), video.size(4))
+ video = torch.cat([video, padding], dim=2)
+
+ if isinstance(generator, list):
+ init_latents = [
+ retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i])
+ for i in range(batch_size)
+ ]
+ else:
+ init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
+
+ init_latents = torch.cat(init_latents, dim=0).to(dtype)
+
+ if self.vae.config.latents_mean is not None:
+ latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std
+ latents_mean = (
+ torch.tensor(latents_mean)
+ .view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : init_latents.size(2)]
+ .to(init_latents)
+ )
+ latents_std = (
+ torch.tensor(latents_std)
+ .view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : init_latents.size(2)]
+ .to(init_latents)
+ )
+ init_latents = (init_latents - latents_mean) * self.scheduler.config.sigma_data / latents_std
+ else:
+ init_latents = init_latents * self.scheduler.config.sigma_data
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+ shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ latents = latents * self.scheduler.config.sigma_max
+
+ padding_shape = (batch_size, 1, num_latent_frames, latent_height, latent_width)
+ ones_padding = latents.new_ones(padding_shape)
+ zeros_padding = latents.new_zeros(padding_shape)
+
+ cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
+ cond_indicator[:, :, :num_cond_latent_frames] = 1.0
+ cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
+
+ uncond_indicator = uncond_mask = None
+ if do_classifier_free_guidance:
+ uncond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
+ uncond_indicator[:, :, :num_cond_latent_frames] = 1.0
+ uncond_mask = zeros_padding
+ if not input_frames_guidance:
+ uncond_mask = uncond_indicator * ones_padding + (1 - uncond_indicator) * zeros_padding
+
+ return latents, init_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ image=None,
+ video=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if image is None and video is None:
+ raise ValueError("Either `image` or `video` has to be provided.")
+ if image is not None and video is not None:
+ raise ValueError("Only one of `image` or `video` has to be provided.")
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput = None,
+ video: List[PipelineImageInput] = None,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 704,
+ width: int = 1280,
+ num_frames: int = 121,
+ num_inference_steps: int = 36,
+ guidance_scale: float = 7.0,
+ input_frames_guidance: bool = False,
+ augment_sigma: float = 0.001,
+ fps: int = 30,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ height (`int`, defaults to `720`):
+ The height in pixels of the generated image.
+ width (`int`, defaults to `1280`):
+ The width in pixels of the generated image.
+ num_frames (`int`, defaults to `121`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `36`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `7.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`.
+ fps (`int`, defaults to `30`):
+ The frames per second of the generated video.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~CosmosPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where
+ the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if self.safety_checker is None:
+ raise ValueError(
+ f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
+ "[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
+ f"Please ensure that you are compliant with the license agreement."
+ )
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs, image, video)
+
+ self._guidance_scale = guidance_scale
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ if self.safety_checker is not None:
+ self.safety_checker.to(device)
+ if prompt is not None:
+ prompt_list = [prompt] if isinstance(prompt, str) else prompt
+ for p in prompt_list:
+ if not self.safety_checker.check_text_safety(p):
+ raise ValueError(
+ f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
+ f"prompt abides by the NVIDIA Open Model License Agreement."
+ )
+ self.safety_checker.to("cpu")
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device)
+
+ # 5. Prepare latent variables
+ vae_dtype = self.vae.dtype
+ transformer_dtype = self.transformer.dtype
+
+ if image is not None:
+ video = self.video_processor.preprocess(image, height, width).unsqueeze(2)
+ else:
+ video = self.video_processor.preprocess_video(video, height, width)
+ video = video.to(device=device, dtype=vae_dtype)
+
+ num_channels_latents = self.transformer.config.in_channels - 1
+ latents, conditioning_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask = self.prepare_latents(
+ video,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ self.do_classifier_free_guidance,
+ input_frames_guidance,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+ cond_mask = cond_mask.to(transformer_dtype)
+ if self.do_classifier_free_guidance:
+ uncond_mask = uncond_mask.to(transformer_dtype)
+
+ augment_sigma = torch.tensor([augment_sigma], device=device, dtype=torch.float32)
+ padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ timestep = t.expand(latents.shape[0]).to(transformer_dtype)
+
+ current_sigma = self.scheduler.sigmas[i]
+ is_augment_sigma_greater = augment_sigma >= current_sigma
+
+ c_in_augment = self.scheduler._get_conditioning_c_in(augment_sigma)
+ c_in_original = self.scheduler._get_conditioning_c_in(current_sigma)
+
+ current_cond_indicator = cond_indicator * 0 if is_augment_sigma_greater else cond_indicator
+ cond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32)
+ cond_latent = conditioning_latents + cond_noise * augment_sigma[:, None, None, None, None]
+ cond_latent = cond_latent * c_in_augment / c_in_original
+ cond_latent = current_cond_indicator * cond_latent + (1 - current_cond_indicator) * latents
+ cond_latent = self.scheduler.scale_model_input(cond_latent, t)
+ cond_latent = cond_latent.to(transformer_dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=cond_latent,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ fps=fps,
+ condition_mask=cond_mask,
+ padding_mask=padding_mask,
+ return_dict=False,
+ )[0]
+
+ sample = latents
+ if self.do_classifier_free_guidance:
+ current_uncond_indicator = uncond_indicator * 0 if is_augment_sigma_greater else uncond_indicator
+ uncond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32)
+ uncond_latent = conditioning_latents + uncond_noise * augment_sigma[:, None, None, None, None]
+ uncond_latent = uncond_latent * c_in_augment / c_in_original
+ uncond_latent = current_uncond_indicator * uncond_latent + (1 - current_uncond_indicator) * latents
+ uncond_latent = self.scheduler.scale_model_input(uncond_latent, t)
+ uncond_latent = uncond_latent.to(transformer_dtype)
+
+ noise_pred_uncond = self.transformer(
+ hidden_states=uncond_latent,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ fps=fps,
+ condition_mask=uncond_mask,
+ padding_mask=padding_mask,
+ return_dict=False,
+ )[0]
+ noise_pred = torch.cat([noise_pred_uncond, noise_pred])
+ sample = torch.cat([sample, sample])
+
+ # pred_original_sample (x0)
+ noise_pred = self.scheduler.step(noise_pred, t, sample, return_dict=False)[1]
+ self.scheduler._step_index -= 1
+
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
+ noise_pred_uncond = (
+ current_uncond_indicator * conditioning_latents
+ + (1 - current_uncond_indicator) * noise_pred_uncond
+ )
+ noise_pred_cond = (
+ current_cond_indicator * conditioning_latents + (1 - current_cond_indicator) * noise_pred_cond
+ )
+ noise_pred = noise_pred_cond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+ else:
+ noise_pred = (
+ current_cond_indicator * conditioning_latents + (1 - current_cond_indicator) * noise_pred
+ )
+
+ # pred_sample (eps)
+ latents = self.scheduler.step(
+ noise_pred, t, latents, return_dict=False, pred_original_sample=noise_pred
+ )[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ if self.vae.config.latents_mean is not None:
+ latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std
+ latents_mean = (
+ torch.tensor(latents_mean)
+ .view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)]
+ .to(latents)
+ )
+ latents_std = (
+ torch.tensor(latents_std)
+ .view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)]
+ .to(latents)
+ )
+ latents = latents * latents_std / self.scheduler.config.sigma_data + latents_mean
+ else:
+ latents = latents / self.scheduler.config.sigma_data
+ video = self.vae.decode(latents.to(vae_dtype), return_dict=False)[0]
+
+ if self.safety_checker is not None:
+ self.safety_checker.to(device)
+ video = self.video_processor.postprocess_video(video, output_type="np")
+ video = (video * 255).astype(np.uint8)
+ video_batch = []
+ for vid in video:
+ vid = self.safety_checker.check_video_safety(vid)
+ video_batch.append(vid)
+ video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
+ video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ self.safety_checker.to("cpu")
+ else:
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return CosmosPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/cosmos/pipeline_output.py b/src/diffusers/pipelines/cosmos/pipeline_output.py
new file mode 100644
index 000000000000..ec5f4826f62a
--- /dev/null
+++ b/src/diffusers/pipelines/cosmos/pipeline_output.py
@@ -0,0 +1,40 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+import torch
+
+from diffusers.utils import BaseOutput, get_logger
+
+
+logger = get_logger(__name__)
+
+
+@dataclass
+class CosmosPipelineOutput(BaseOutput):
+ r"""
+ Output class for Cosmos any-to-world/video pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
+
+
+@dataclass
+class CosmosImagePipelineOutput(BaseOutput):
+ """
+ Output class for Cosmos any-to-image pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
index 34b2a3945572..5a70c4f5ff9a 100644
--- a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
+++ b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,7 +21,7 @@
from ...schedulers import SchedulerMixin
from ...utils import is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
+from ..pipeline_utils import AudioPipelineOutput, DeprecatedPipelineMixin, DiffusionPipeline
if is_torch_xla_available():
@@ -34,7 +34,7 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-class DanceDiffusionPipeline(DiffusionPipeline):
+class DanceDiffusionPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
r"""
Pipeline for audio generation.
@@ -49,6 +49,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
[`IPNDMScheduler`].
"""
+ _last_supported_version = "0.33.1"
model_cpu_offload_seq = "unet"
def __init__(self, unet: UNet1DModel, scheduler: SchedulerMixin):
@@ -97,7 +98,7 @@ def __call__(
for i, audio in enumerate(audios):
write(f"maestro_test_{i}.wav", pipe.unet.sample_rate, audio.transpose())
- # To dislay in google colab
+ # To display in google colab
import IPython.display as ipd
for audio in audios:
diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py
index 1fd8ce4e6570..39587ca5221d 100644
--- a/src/diffusers/pipelines/ddim/pipeline_ddim.py
+++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -77,9 +77,9 @@ def __call__(
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. A value of `0` corresponds to
- DDIM and `1` corresponds to DDPM.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. A value of `0`
+ corresponds to DDIM and `1` corresponds to DDPM.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
index 1c5ac4baeae0..0d7766a8cfd0 100644
--- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py
index 150978de6e5e..8fa31f8504d3 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py
@@ -336,7 +336,7 @@ def run_safety_checker(self, image, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -484,7 +484,7 @@ def _clean_caption(self, caption):
# &
caption = re.sub(r"&", "", caption)
- # ip adresses:
+ # ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -581,11 +581,11 @@ def __call__(
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -597,8 +597,8 @@ def __call__(
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
@@ -655,7 +655,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
index a92d7be6a11c..507927faf61b 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
@@ -361,7 +361,7 @@ def run_safety_checker(self, image, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -528,7 +528,7 @@ def _clean_caption(self, caption):
# &
caption = re.sub(r"&", "", caption)
- # ip adresses:
+ # ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -706,11 +706,11 @@ def __call__(
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 10.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -718,8 +718,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
@@ -775,7 +775,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
index b23ea39bb292..9bc15c3c6f62 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
@@ -281,7 +281,7 @@ def _clean_caption(self, caption):
# &
caption = re.sub(r"&", "", caption)
- # ip adresses:
+ # ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -514,7 +514,7 @@ def run_safety_checker(self, image, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -793,11 +793,11 @@ def __call__(
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -805,8 +805,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
@@ -870,7 +870,7 @@ def __call__(
# 2. Define call parameters
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py
index 030821b789aa..9d6cf62020a9 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py
@@ -365,7 +365,7 @@ def run_safety_checker(self, image, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -568,7 +568,7 @@ def _clean_caption(self, caption):
# &
caption = re.sub(r"&", "", caption)
- # ip adresses:
+ # ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -806,11 +806,11 @@ def __call__(
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -818,8 +818,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
@@ -882,7 +882,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
index bdad9c29b18f..0122c164d8b8 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
@@ -283,7 +283,7 @@ def _clean_caption(self, caption):
# &
caption = re.sub(r"&", "", caption)
- # ip adresses:
+ # ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -516,7 +516,7 @@ def run_safety_checker(self, image, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -889,11 +889,11 @@ def __call__(
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -901,8 +901,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
@@ -967,7 +967,7 @@ def __call__(
# 2. Define call parameters
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
index 012c4ca6d448..ffa60575fe33 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
@@ -239,7 +239,7 @@ def _clean_caption(self, caption):
# &
caption = re.sub(r"&", "", caption)
- # ip adresses:
+ # ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -472,7 +472,7 @@ def run_safety_checker(self, image, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -656,11 +656,11 @@ def __call__(
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -668,8 +668,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
@@ -739,7 +739,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py
index 48c0aa4f6d76..6f484aa3e298 100644
--- a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py
+++ b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -68,7 +68,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -185,8 +185,8 @@ class AltDiffusionPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -266,8 +266,8 @@ def __init__(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -557,7 +557,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -686,7 +686,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -754,8 +754,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -780,7 +780,7 @@ def __call__(
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
@@ -942,7 +942,7 @@ def __call__(
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py
index fa70689d790d..d6bf90120755 100644
--- a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py
+++ b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -213,8 +213,8 @@ class AltDiffusionImg2ImgPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -294,8 +294,8 @@ def __init__(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -585,7 +585,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -754,7 +754,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -828,8 +828,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
diff --git a/src/diffusers/pipelines/deprecated/audio_diffusion/mel.py b/src/diffusers/pipelines/deprecated/audio_diffusion/mel.py
index 3426c3ad0428..0902f993a060 100644
--- a/src/diffusers/pipelines/deprecated/audio_diffusion/mel.py
+++ b/src/diffusers/pipelines/deprecated/audio_diffusion/mel.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py b/src/diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py
index 47044e050acf..81fa999eb1fb 100644
--- a/src/diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py
+++ b/src/diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -115,8 +115,8 @@ def __call__(
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) used to denoise.
None
eta (`float`):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
noise (`torch.Tensor`):
A noise tensor of shape `(batch_size, 1, height, width)` or `None`.
encoding (`torch.Tensor`):
diff --git a/src/diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
index 7fe5d59f771d..0bb24ed0b1ce 100644
--- a/src/diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
+++ b/src/diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/deprecated/pndm/pipeline_pndm.py b/src/diffusers/pipelines/deprecated/pndm/pipeline_pndm.py
index ef78af1940ce..71e3e156e0e4 100644
--- a/src/diffusers/pipelines/deprecated/pndm/pipeline_pndm.py
+++ b/src/diffusers/pipelines/deprecated/pndm/pipeline_pndm.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -95,7 +95,7 @@ def __call__(
returned where the first element is a list with the generated images.
"""
# For more information on the sampling method you can take a look at Algorithm 2 of
- # the official paper: https://arxiv.org/pdf/2202.09778.pdf
+ # the official paper: https://huggingface.co/papers/2202.09778
# Sample gaussian noise to begin loop
image = randn_tensor(
diff --git a/src/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py b/src/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py
index 843528a532f1..56c6007ae886 100644
--- a/src/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py
+++ b/src/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py
@@ -1,4 +1,4 @@
-# Copyright 2024 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved.
+# Copyright 2025 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -124,10 +124,11 @@ def __call__(
DDIM and 1.0 is the DDPM scheduler.
jump_length (`int`, *optional*, defaults to 10):
The number of steps taken forward in time before going backward in time for a single jump ("j" in
- RePaint paper). Take a look at Figure 9 and 10 in the [paper](https://arxiv.org/pdf/2201.09865.pdf).
+ RePaint paper). Take a look at Figure 9 and 10 in the
+ [paper](https://huggingface.co/papers/2201.09865).
jump_n_sample (`int`, *optional*, defaults to 10):
The number of times to make a forward time jump for a given chosen time sample. Take a look at Figure 9
- and 10 in the [paper](https://arxiv.org/pdf/2201.09865.pdf).
+ and 10 in the [paper](https://huggingface.co/papers/2201.09865).
generator (`torch.Generator`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
diff --git a/src/diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py
index b0bb114a81b7..3f04db7ad699 100644
--- a/src/diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py
+++ b/src/diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py b/src/diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py
index 8664c2fb6711..b26e84f72869 100644
--- a/src/diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py
+++ b/src/diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py
@@ -1,5 +1,5 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py b/src/diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py
index e777e844935e..8985a6f88800 100644
--- a/src/diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py
+++ b/src/diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py
@@ -1,5 +1,5 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py b/src/diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py
index 1259f0bf056a..25ad4a4ccfd2 100644
--- a/src/diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py
+++ b/src/diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py
@@ -1,5 +1,5 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py b/src/diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py
index b8ac8e1416bf..be07b1b15ea8 100644
--- a/src/diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py
+++ b/src/diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py
@@ -1,5 +1,5 @@
# Copyright 2022 The Music Spectrogram Diffusion Authors.
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py
index 1752540e8f79..08f8c7e26fae 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -115,7 +115,7 @@ def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta):
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
- # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502
pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
# 4. Clip "predicted x_0"
@@ -127,7 +127,7 @@ def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta):
variance = scheduler._get_variance(timestep, prev_timestep)
std_dev_t = eta * variance ** (0.5)
- # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # 6. compute "direction pointing to x_t" of formula (12) from https://huggingface.co/papers/2010.02502
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred
noise = (prev_latents - (alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction)) / (
@@ -162,8 +162,8 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Sta
instance of [`DDIMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -226,8 +226,8 @@ def __init__(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -522,7 +522,7 @@ def check_inputs(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -678,8 +678,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -717,7 +717,7 @@ def __call__(
from diffusers import CycleDiffusionPipeline, DDIMScheduler
# load the pipeline
- # make sure you're logged in with `huggingface-cli login`
+ # make sure you're logged in with `hf auth login`
model_id_or_path = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")
@@ -790,7 +790,7 @@ def __call__(
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py
index e9553a8d99b0..fcd8bf317adf 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py
@@ -62,7 +62,8 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -337,19 +338,19 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter will be modulated by `strength`.
guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (?) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (?) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`np.random.RandomState`, *optional*):
A np.random.RandomState to make generation deterministic.
prompt_embeds (`np.ndarray`, *optional*):
@@ -404,7 +405,7 @@ def __call__(
image = preprocess(image)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
@@ -455,7 +456,7 @@ def __call__(
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (?) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to ? in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to ? in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py
index f9c9c37c4867..ba0dd66c2938 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -111,7 +111,8 @@ class StableDiffusionInpaintPipelineLegacy(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -196,8 +197,8 @@ def __init__(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -468,7 +469,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -605,11 +606,11 @@ def __call__(
The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
@@ -620,8 +621,8 @@ def __call__(
Use predicted noise instead of random noise when constructing noisy versions of the original image in
the reverse diffusion process
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
@@ -672,7 +673,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py
index 06db871daf62..b7a0be57c12b 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py
@@ -1,4 +1,4 @@
-# Copyright 2024 TIME Authors and The HuggingFace Team. All rights reserved."
+# Copyright 2025 TIME Authors and The HuggingFace Team. All rights reserved."
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -64,8 +64,8 @@ class StableDiffusionModelEditingPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
with_to_k ([`bool`]):
@@ -402,7 +402,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -500,7 +500,8 @@ def edit_model(
restart_params: bool = True,
):
r"""
- Apply model editing via closed-form solution (see Eq. 5 in the TIME [paper](https://arxiv.org/abs/2303.08084)).
+ Apply model editing via closed-form solution (see Eq. 5 in the TIME
+ [paper](https://huggingface.co/papers/2303.08084)).
Args:
source_prompt (`str`):
@@ -509,7 +510,8 @@ def edit_model(
The destination prompt. Must contain all words from `source_prompt` with additional ones to specify the
target edit.
lamb (`float`, *optional*, defaults to 0.1):
- The lambda parameter specifying the regularization intesity. Smaller values increase the editing power.
+ The lambda parameter specifying the regularization intensity. Smaller values increase the editing
+ power.
restart_params (`bool`, *optional*, defaults to True):
Restart the model parameters to their pre-trained version before editing. This is done to avoid edit
compounding. When it is `False`, edits accumulate.
@@ -574,7 +576,7 @@ def edit_model(
idxs_replace.append(76)
idxs_replaces.append(idxs_replace)
- # prepare batch: for each pair of setences, old context and new values
+ # prepare batch: for each pair of sentences, old context and new values
contexts, valuess = [], []
for old_emb, new_emb, idxs_replace in zip(old_embs, new_embs, idxs_replaces):
context = old_emb.detach()
@@ -653,8 +655,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -731,7 +733,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py
index d486a32f6a4c..c236e73bf448 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py
@@ -1,4 +1,4 @@
-# Copyright 2024 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -46,10 +46,12 @@
>>> from diffusers import DDPMParallelScheduler
>>> from diffusers import StableDiffusionParadigmsPipeline
- >>> scheduler = DDPMParallelScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
+ >>> scheduler = DDPMParallelScheduler.from_pretrained(
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="scheduler"
+ ... )
>>> pipe = StableDiffusionParadigmsPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", scheduler=scheduler, torch_dtype=torch.float16
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", scheduler=scheduler, torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
@@ -95,8 +97,8 @@ class StableDiffusionParadigmsPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -385,7 +387,7 @@ def run_safety_checker(self, image, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -537,8 +539,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -599,7 +601,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
index 509f25620950..2a461ae20cc9 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -303,7 +303,8 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin
[`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], or [`DDPMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
requires_safety_checker (bool):
@@ -616,7 +617,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -848,10 +849,10 @@ def __call__(
instead.
source_embeds (`torch.Tensor`):
Source concept embeddings. Generation of the embeddings as per the [original
- paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction.
+ paper](https://huggingface.co/papers/2302.03027). Used in discovering the edit direction.
target_embeds (`torch.Tensor`):
Target concept embeddings. Generation of the embeddings as per the [original
- paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction.
+ paper](https://huggingface.co/papers/2302.03027). Used in discovering the edit direction.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -860,11 +861,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -872,15 +873,15 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -939,7 +940,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
@@ -1140,18 +1141,18 @@ def invert(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 1):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -1199,7 +1200,7 @@ def invert(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py b/src/diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py
index 023edb4ce4bd..50b8b0bcbc1d 100644
--- a/src/diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py
+++ b/src/diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
index bc276811ff4a..7c25713cd1d7 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
@@ -374,21 +374,21 @@ def __init__(
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
- down_block_types: Tuple[str] = (
+ down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlockFlat",
"CrossAttnDownBlockFlat",
"CrossAttnDownBlockFlat",
"DownBlockFlat",
),
mid_block_type: Optional[str] = "UNetMidBlockFlatCrossAttn",
- up_block_types: Tuple[str] = (
+ up_block_types: Tuple[str, ...] = (
"UpBlockFlat",
"CrossAttnUpBlockFlat",
"CrossAttnUpBlockFlat",
"CrossAttnUpBlockFlat",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
@@ -964,7 +964,7 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
fn_recursive_set_attention_slice(module, reversed_slice_size)
def enable_freeu(self, s1, s2, b1, b2):
- r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+ r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497.
The suffixes after the scaling factors represent the stage blocks where they are being applied.
@@ -1000,11 +1000,7 @@ def fuse_qkv_projections(self):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -1021,11 +1017,7 @@ def fuse_qkv_projections(self):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
@@ -1097,7 +1089,7 @@ def forward(
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
added_cond_kwargs: (`dict`, *optional*):
- A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
are passed along to the UNet blocks.
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
additional residuals to be added to UNet long skip connections from down blocks to up blocks for
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py
index 4fb437958abd..9ff8e9857791 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py
@@ -38,8 +38,8 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -118,8 +118,8 @@ def image_variation(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -230,8 +230,8 @@ def text_to_image(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -339,8 +339,8 @@ def dual_guided(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
index 0065279bc0b1..0252f4f6af7f 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -315,7 +315,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -424,8 +424,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -493,7 +493,7 @@ def __call__(
batch_size = len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
index 7dfc7e961825..034a0226419b 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,6 @@
import numpy as np
import PIL.Image
import torch
-import torch.utils.checkpoint
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from ....image_processor import VaeImageProcessor
@@ -175,7 +174,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -276,8 +275,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -338,7 +337,7 @@ def __call__(
batch_size = 1 if isinstance(image, PIL.Image.Image) else len(image)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
index 1d6771793f39..2f54f4fc98a4 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,7 +16,6 @@
from typing import Callable, List, Optional, Union
import torch
-import torch.utils.checkpoint
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer
from ....image_processor import VaeImageProcessor
@@ -232,7 +231,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -362,8 +361,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -416,7 +415,7 @@ def __call__(
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py b/src/diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py
index 8dee000df05f..e8617a54b691 100644
--- a/src/diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py
+++ b/src/diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Microsoft and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Microsoft and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/dit/pipeline_dit.py b/src/diffusers/pipelines/dit/pipeline_dit.py
index 8aee0fadaf69..68ff6c9b559a 100644
--- a/src/diffusers/pipelines/dit/pipeline_dit.py
+++ b/src/diffusers/pipelines/dit/pipeline_dit.py
@@ -4,7 +4,7 @@
# Copyright (c) 2021 OpenAI
# MIT License
#
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -46,7 +46,9 @@ class DiTPipeline(DiffusionPipeline):
Parameters:
transformer ([`DiTTransformer2DModel`]):
- A class conditioned `DiTTransformer2DModel` to denoise the encoded image latents.
+ A class conditioned `DiTTransformer2DModel` to denoise the encoded image latents. Initially published as
+ [`Transformer2DModel`](https://huggingface.co/facebook/DiT-XL-2-256/blob/main/transformer/config.json#L2)
+ in the config, but the mismatch can be ignored.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
scheduler ([`DDIMScheduler`]):
diff --git a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate.py b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate.py
index 25975b04f395..92239c0d32f0 100755
--- a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate.py
+++ b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate.py
@@ -101,7 +101,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -404,7 +404,7 @@ def encode_prompt(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -507,7 +507,7 @@ def guidance_rescale(self):
return self._guidance_rescale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -732,7 +732,7 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py
index 1d2c508675f1..f74a11f87d75 100755
--- a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py
+++ b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py
@@ -177,7 +177,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -515,7 +515,7 @@ def encode_prompt(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -653,7 +653,7 @@ def guidance_rescale(self):
return self._guidance_rescale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -956,7 +956,7 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py
index 15745ecca3f0..b16ef92d8e6b 100755
--- a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py
+++ b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py
@@ -199,7 +199,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -557,7 +557,7 @@ def encode_prompt(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -771,7 +771,7 @@ def guidance_rescale(self):
return self._guidance_rescale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -849,7 +849,7 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- A parameter defined in the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the
+ A parameter defined in the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies to the
[`~schedulers.DDIMScheduler`] and is ignored in other schedulers. It adjusts noise level during the
inference process.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -883,7 +883,8 @@ def __call__(
inputs will be passed, facilitating enhanced logging or monitoring of the generation process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
Rescale parameter for adjusting noise configuration based on guidance rescale. Based on findings from
- [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
strength (`float`, *optional*, defaults to 1.0):
Affects the overall styling or quality of the generated output. Values closer to 1 usually provide
direct adherence to prompts.
@@ -1130,7 +1131,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.transformer` or your `mask_image` or `image` input."
)
@@ -1180,7 +1181,7 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py
index 72e1b578f2ca..ea25c148e2f1 100644
--- a/src/diffusers/pipelines/flux/__init__.py
+++ b/src/diffusers/pipelines/flux/__init__.py
@@ -33,6 +33,8 @@
_import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"]
_import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
_import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
+ _import_structure["pipeline_flux_kontext"] = ["FluxKontextPipeline"]
+ _import_structure["pipeline_flux_kontext_inpaint"] = ["FluxKontextInpaintPipeline"]
_import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -52,6 +54,8 @@
from .pipeline_flux_fill import FluxFillPipeline
from .pipeline_flux_img2img import FluxImg2ImgPipeline
from .pipeline_flux_inpaint import FluxInpaintPipeline
+ from .pipeline_flux_kontext import FluxKontextPipeline
+ from .pipeline_flux_kontext_inpaint import FluxKontextInpaintPipeline
from .pipeline_flux_prior_redux import FluxPriorReduxPipeline
else:
import sys
diff --git a/src/diffusers/pipelines/flux/modeling_flux.py b/src/diffusers/pipelines/flux/modeling_flux.py
index 5ff60f774d19..d7f2f45812b3 100644
--- a/src/diffusers/pipelines/flux/modeling_flux.py
+++ b/src/diffusers/pipelines/flux/modeling_flux.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index 862c279cfaf3..5041e352f73d 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -32,6 +32,7 @@
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -310,7 +311,7 @@ def _get_clip_prompt_embeds(
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -490,14 +491,6 @@ def check_inputs(
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
- if prompt_embeds is not None and negative_prompt_embeds is not None:
- if prompt_embeds.shape != negative_prompt_embeds.shape:
- raise ValueError(
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
- f" {negative_prompt_embeds.shape}."
- )
-
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
@@ -553,6 +546,12 @@ def enable_vae_slicing(self):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -560,6 +559,12 @@ def disable_vae_slicing(self):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -568,6 +573,12 @@ def enable_vae_tiling(self):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -575,6 +586,12 @@ def disable_vae_tiling(self):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def prepare_latents(
@@ -682,7 +699,8 @@ def __call__(
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
true_cfg_scale (`float`, *optional*, defaults to 1.0):
- When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
+ True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and
+ `negative_prompt` is provided.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -695,11 +713,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
+ a model to generate images more aligned with `prompt` at the expense of lower image quality.
+
+ Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -708,7 +726,7 @@ def __call__(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -821,7 +839,7 @@ def __call__(
(
negative_prompt_embeds,
negative_pooled_prompt_embeds,
- _,
+ negative_text_ids,
) = self.encode_prompt(
prompt=negative_prompt,
prompt_2=negative_prompt_2,
@@ -848,6 +866,8 @@ def __call__(
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
+ sigmas = None
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
@@ -906,6 +926,9 @@ def __call__(
)
# 6. Denoising loop
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
+ self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
@@ -917,32 +940,35 @@ def __call__(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
- noise_pred = self.transformer(
- hidden_states=latents,
- timestep=timestep / 1000,
- guidance=guidance,
- pooled_projections=pooled_prompt_embeds,
- encoder_hidden_states=prompt_embeds,
- txt_ids=text_ids,
- img_ids=latent_image_ids,
- joint_attention_kwargs=self.joint_attention_kwargs,
- return_dict=False,
- )[0]
-
- if do_true_cfg:
- if negative_image_embeds is not None:
- self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
- neg_noise_pred = self.transformer(
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
- pooled_projections=negative_pooled_prompt_embeds,
- encoder_hidden_states=negative_prompt_embeds,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
+
+ if do_true_cfg:
+ if negative_image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=negative_text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py
index 113b0dd7291f..848d7bd39254 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_control.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -163,9 +164,9 @@ class FluxControlPipeline(
TextualInversionLoaderMixin,
):
r"""
- The Flux pipeline for controllable text-to-image generation.
+ The Flux pipeline for controllable text-to-image generation with image conditions.
- Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+ Reference: https://bfl.ai/flux-1-tools
Args:
transformer ([`FluxTransformer2DModel`]):
@@ -324,7 +325,7 @@ def _get_clip_prompt_embeds(
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -496,6 +497,12 @@ def enable_vae_slicing(self):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -503,6 +510,12 @@ def disable_vae_slicing(self):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -511,6 +524,12 @@ def enable_vae_tiling(self):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -518,6 +537,12 @@ def disable_vae_tiling(self):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
@@ -661,11 +686,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
+ a model to generate images more aligned with prompt at the expense of lower image quality.
+
+ Guidance-distilled models approximates true classifier-free guidance for `guidance_scale` > 1. Refer to
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -674,7 +699,7 @@ def __call__(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py
index c269be15a4b2..262345c75afc 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -335,7 +335,7 @@ def _get_clip_prompt_embeds(
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -699,11 +699,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -712,7 +712,7 @@ def __call__(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
index af7e8b53fad3..6915a83a7ca7 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -35,6 +35,7 @@
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -374,7 +375,7 @@ def _get_clip_prompt_embeds(
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -577,6 +578,12 @@ def enable_vae_slicing(self):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -584,6 +591,12 @@ def disable_vae_slicing(self):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -592,6 +605,12 @@ def enable_vae_tiling(self):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -599,6 +618,12 @@ def disable_vae_tiling(self):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def prepare_latents(
@@ -838,7 +863,7 @@ def __call__(
1)`, or `(H, W)`.
mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
`Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
- latents tensor will ge generated by `mask_image`.
+ latents tensor will be generated by `mask_image`.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -857,11 +882,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -870,7 +895,7 @@ def __call__(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
index f3f1d90204d6..507ec687347c 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -341,7 +341,7 @@ def _get_clip_prompt_embeds(
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -733,11 +733,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
The percentage of total steps at which the ControlNet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
@@ -764,7 +764,7 @@ def __call__(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
index ddd5372b4dd8..582c7bbad84e 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
@@ -335,7 +335,7 @@ def _get_clip_prompt_embeds(
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -687,7 +687,8 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598).
control_mode (`int` or `List[int]`, *optional*):
The mode for the ControlNet. If multiple ControlNets are used, this should be a list.
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
@@ -800,17 +801,20 @@ def __call__(
)
height, width = control_image.shape[-2:]
- control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
- control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+ # xlab controlnet has a input_hint_block and instantx controlnet does not
+ controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
+ if self.controlnet.input_hint_block is None:
+ control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
- height_control_image, width_control_image = control_image.shape[2:]
- control_image = self._pack_latents(
- control_image,
- batch_size * num_images_per_prompt,
- num_channels_latents,
- height_control_image,
- width_control_image,
- )
+ height_control_image, width_control_image = control_image.shape[2:]
+ control_image = self._pack_latents(
+ control_image,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height_control_image,
+ width_control_image,
+ )
if control_mode is not None:
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
@@ -819,7 +823,9 @@ def __call__(
elif isinstance(self.controlnet, FluxMultiControlNetModel):
control_images = []
- for control_image_ in control_image:
+ # xlab controlnet has a input_hint_block and instantx controlnet does not
+ controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True
+ for i, control_image_ in enumerate(control_image):
control_image_ = self.prepare_image(
image=control_image_,
width=width,
@@ -831,17 +837,18 @@ def __call__(
)
height, width = control_image_.shape[-2:]
- control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
- control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+ if self.controlnet.nets[0].input_hint_block is None:
+ control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
+ control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
- height_control_image, width_control_image = control_image_.shape[2:]
- control_image_ = self._pack_latents(
- control_image_,
- batch_size * num_images_per_prompt,
- num_channels_latents,
- height_control_image,
- width_control_image,
- )
+ height_control_image, width_control_image = control_image_.shape[2:]
+ control_image_ = self._pack_latents(
+ control_image_,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height_control_image,
+ width_control_image,
+ )
control_images.append(control_image_)
@@ -955,6 +962,7 @@ def __call__(
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
+ controlnet_blocks_repeat=controlnet_blocks_repeat,
)[0]
latents_dtype = latents.dtype
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
index bff625367bc9..f7f34ef231af 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -346,7 +346,7 @@ def _get_clip_prompt_embeds(
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -507,7 +507,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -515,7 +515,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@@ -801,7 +801,8 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598).
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
The percentage of total steps at which the ControlNet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py
index 1816b3ca6d9b..5cb9c82204b2 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -224,11 +225,13 @@ def __init__(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
- latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+ self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+ self.image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
+ )
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor * 2,
- vae_latent_channels=latent_channels,
+ vae_latent_channels=self.latent_channels,
do_normalize=False,
do_binarize=True,
do_convert_grayscale=True,
@@ -417,7 +420,7 @@ def prepare_mask_latents(
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -493,10 +496,38 @@ def encode_prompt(
return prompt_embeds, pooled_prompt_embeds, text_ids
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
def check_inputs(
self,
prompt,
prompt_2,
+ strength,
height,
width,
prompt_embeds=None,
@@ -507,6 +538,9 @@ def check_inputs(
mask_image=None,
masked_image_latents=None,
):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
@@ -600,6 +634,12 @@ def enable_vae_slicing(self):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -607,6 +647,12 @@ def disable_vae_slicing(self):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -615,6 +661,12 @@ def enable_vae_tiling(self):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -622,11 +674,19 @@ def disable_vae_tiling(self):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
- # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
+ # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
def prepare_latents(
self,
+ image,
+ timestep,
batch_size,
num_channels_latents,
height,
@@ -636,28 +696,41 @@ def prepare_latents(
generator,
latents=None,
):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
-
shape = (batch_size, num_channels_latents, height, width)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
if latents is not None:
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
- if isinstance(generator, list) and len(generator) != batch_size:
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ else:
+ image_latents = image
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
-
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-
return latents, latent_image_ids
@property
@@ -687,6 +760,7 @@ def __call__(
masked_image_latents: Optional[torch.FloatTensor] = None,
height: Optional[int] = None,
width: Optional[int] = None,
+ strength: float = 1.0,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 30.0,
@@ -726,11 +800,17 @@ def __call__(
1)`, or `(H, W)`.
mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
`Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
- latents tensor will ge generated by `mask_image`.
+ latents tensor will be generated by `mask_image`.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ strength (`float`, *optional*, defaults to 1.0):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
@@ -739,11 +819,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 30.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -752,7 +832,7 @@ def __call__(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -794,6 +874,7 @@ def __call__(
self.check_inputs(
prompt,
prompt_2,
+ strength,
height,
width,
prompt_embeds=prompt_embeds,
@@ -809,6 +890,9 @@ def __call__(
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
+ init_image = init_image.to(dtype=torch.float32)
+
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -838,9 +922,37 @@ def __call__(
lora_scale=lora_scale,
)
- # 4. Prepare latent variables
+ # 4. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 5. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
latents, latent_image_ids = self.prepare_latents(
+ init_image,
+ latent_timestep,
batch_size * num_images_per_prompt,
num_channels_latents,
height,
@@ -851,17 +963,16 @@ def __call__(
latents,
)
- # 5. Prepare mask and masked image latents
+ # 6. Prepare mask and masked image latents
if masked_image_latents is not None:
masked_image_latents = masked_image_latents.to(latents.device)
else:
- image = self.image_processor.preprocess(image, height=height, width=width)
mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)
- masked_image = image * (1 - mask_image)
+ masked_image = init_image * (1 - mask_image)
masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)
- height, width = image.shape[-2:]
+ height, width = init_image.shape[-2:]
mask, masked_image_latents = self.prepare_mask_latents(
mask_image,
masked_image,
@@ -876,23 +987,6 @@ def __call__(
)
masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)
- # 6. Prepare timesteps
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
- image_seq_len = latents.shape[1]
- mu = calculate_shift(
- image_seq_len,
- self.scheduler.config.get("base_image_seq_len", 256),
- self.scheduler.config.get("max_image_seq_len", 4096),
- self.scheduler.config.get("base_shift", 0.5),
- self.scheduler.config.get("max_shift", 1.15),
- )
- timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler,
- num_inference_steps,
- device,
- sigmas=sigmas,
- mu=mu,
- )
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
index 64cd6ac45f1a..ab9140dae921 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,6 +33,7 @@
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -333,7 +334,7 @@ def _get_clip_prompt_embeds(
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -607,6 +608,63 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
return latents
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_tiling()
+
def prepare_latents(
self,
image,
@@ -741,11 +799,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -754,7 +812,7 @@ def __call__(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
index 27b9e0cd45fa..3bfe82cf4382 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -337,7 +337,7 @@ def _get_clip_prompt_embeds(
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -574,7 +574,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -582,7 +582,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@@ -834,7 +834,7 @@ def __call__(
1)`, or `(H, W)`.
mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
`Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
- latents tensor will ge generated by `mask_image`.
+ latents tensor will be generated by `mask_image`.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -860,11 +860,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -873,7 +873,7 @@ def __call__(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -1193,6 +1193,11 @@ def __call__(
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
+ if padding_mask_crop is not None:
+ image = [
+ self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image
+ ]
+
# Offload all models
self.maybe_free_model_hooks()
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
new file mode 100644
index 000000000000..94ae460afcd0
--- /dev/null
+++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
@@ -0,0 +1,1159 @@
+# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, FluxTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import FluxPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import FluxKontextPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = FluxKontextPipeline.from_pretrained(
+ ... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
+ ... ).convert("RGB")
+ >>> prompt = "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"
+ >>> image = pipe(
+ ... image=image,
+ ... prompt=prompt,
+ ... guidance_scale=2.5,
+ ... generator=torch.Generator().manual_seed(42),
+ ... ).images[0]
+ >>> image.save("output.png")
+ ```
+"""
+
+PREFERRED_KONTEXT_RESOLUTIONS = [
+ (672, 1568),
+ (688, 1504),
+ (720, 1456),
+ (752, 1392),
+ (800, 1328),
+ (832, 1248),
+ (880, 1184),
+ (944, 1104),
+ (1024, 1024),
+ (1104, 944),
+ (1184, 880),
+ (1248, 832),
+ (1328, 800),
+ (1392, 752),
+ (1456, 720),
+ (1504, 688),
+ (1568, 672),
+]
+
+
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class FluxKontextPipeline(
+ DiffusionPipeline,
+ FluxLoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+ FluxIPAdapterMixin,
+):
+ r"""
+ The Flux Kontext pipeline for image-to-image and text-to-image generation.
+
+ Reference: https://bfl.ai/announcements/flux-1-kontext-dev
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ return image_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+ ):
+ image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_ip_adapter_image in ip_adapter_image:
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+ image_embeds.append(single_image_embeds[None, :])
+ else:
+ if not isinstance(ip_adapter_image_embeds, list):
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
+
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_image_embeds in ip_adapter_image_embeds:
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for single_image_embeds in image_embeds:
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ return latents
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
+
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_tiling()
+
+ def prepare_latents(
+ self,
+ image: Optional[torch.Tensor],
+ batch_size: int,
+ num_channels_latents: int,
+ height: int,
+ width: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ shape = (batch_size, num_channels_latents, height, width)
+
+ image_latents = image_ids = None
+ if image is not None:
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ else:
+ image_latents = image
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ image_latent_height, image_latent_width = image_latents.shape[2:]
+ image_latents = self._pack_latents(
+ image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
+ )
+ image_ids = self._prepare_latent_image_ids(
+ batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype
+ )
+ # image ids are the same as latent ids with the first dimension set to 1 instead of 0
+ image_ids[..., 0] = 1
+
+ latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ return latents, image_latents, latent_ids, image_ids
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: Optional[PipelineImageInput] = None,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ true_cfg_scale: float = 1.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ max_area: int = 1024**2,
+ _auto_resize: bool = True,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
+ latents as `image`, but if passing latents directly it is not encoded again.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ will be used instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 3.5):
+ Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
+ a model to generate images more aligned with prompt at the expense of lower image quality.
+
+ Guidance-distilled models approximates true classifier-free guidance for `guidance_scale` > 1. Refer to
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_ip_adapter_image:
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512):
+ Maximum sequence length to use with the `prompt`.
+ max_area (`int`, defaults to `1024 ** 2`):
+ The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
+ area while maintaining the aspect ratio.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_height, original_width = height, width
+ aspect_ratio = width / height
+ width = round((max_area * aspect_ratio) ** 0.5)
+ height = round((max_area / aspect_ratio) ** 0.5)
+
+ multiple_of = self.vae_scale_factor * 2
+ width = width // multiple_of * multiple_of
+ height = height // multiple_of * multiple_of
+
+ if height != original_height or width != original_width:
+ logger.warning(
+ f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
+ )
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+ if do_true_cfg:
+ (
+ negative_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ negative_text_ids,
+ ) = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_2=negative_prompt_2,
+ prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 3. Preprocess image
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+ img = image[0] if isinstance(image, list) else image
+ image_height, image_width = self.image_processor.get_default_height_width(img)
+ aspect_ratio = image_width / image_height
+ if _auto_resize:
+ # Kontext is trained on specific resolutions, using one of them is recommended
+ _, image_width, image_height = min(
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+ )
+ image_width = image_width // multiple_of * multiple_of
+ image_height = image_height // multiple_of * multiple_of
+ image = self.image_processor.resize(image, image_height, image_width)
+ image = self.image_processor.preprocess(image, image_height, image_width)
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, image_latents, latent_ids, image_ids = self.prepare_latents(
+ image,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ if image_ids is not None:
+ latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+ ):
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+ ):
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {}
+
+ image_embeds = None
+ negative_image_embeds = None
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+ negative_ip_adapter_image,
+ negative_ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+
+ # 6. Denoising loop
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ if image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
+
+ latent_model_input = latents
+ if image_latents is not None:
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred[:, : latents.size(1)]
+
+ if do_true_cfg:
+ if negative_image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=negative_text_ids,
+ img_ids=latent_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py
new file mode 100644
index 000000000000..b6f957981e14
--- /dev/null
+++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py
@@ -0,0 +1,1485 @@
+# Copyright 2025 ZenAI. All rights reserved.
+# author: @vuongminh1907
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, FluxTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import FluxPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ # Inpainting with text only
+ ```py
+ >>> import torch
+ >>> from diffusers import FluxKontextInpaintPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> prompt = "Change the yellow dinosaur to green one"
+ >>> img_url = (
+ ... "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_input.jpeg?raw=true"
+ ... )
+ >>> mask_url = (
+ ... "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_mask.png?raw=true"
+ ... )
+
+ >>> source = load_image(img_url)
+ >>> mask = load_image(mask_url)
+
+ >>> pipe = FluxKontextInpaintPipeline.from_pretrained(
+ ... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> image = pipe(prompt=prompt, image=source, mask_image=mask, strength=1.0).images[0]
+ >>> image.save("kontext_inpainting_normal.png")
+ ```
+
+ # Inpainting with image conditioning
+ ```py
+ >>> import torch
+ >>> from diffusers import FluxKontextInpaintPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = FluxKontextInpaintPipeline.from_pretrained(
+ ... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> prompt = "Replace this ball"
+ >>> img_url = "https://images.pexels.com/photos/39362/the-ball-stadion-football-the-pitch-39362.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
+ >>> mask_url = (
+ ... "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/ball_mask.png?raw=true"
+ ... )
+ >>> image_reference_url = (
+ ... "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTah3x6OL_ECMBaZ5ZlJJhNsyC-OSMLWAI-xw&s"
+ ... )
+
+ >>> source = load_image(img_url)
+ >>> mask = load_image(mask_url)
+ >>> image_reference = load_image(image_reference_url)
+
+ >>> mask = pipe.mask_processor.blur(mask, blur_factor=12)
+ >>> image = pipe(
+ ... prompt=prompt, image=source, mask_image=mask, image_reference=image_reference, strength=1.0
+ ... ).images[0]
+ >>> image.save("kontext_inpainting_ref.png")
+ ```
+"""
+
+PREFERRED_KONTEXT_RESOLUTIONS = [
+ (672, 1568),
+ (688, 1504),
+ (720, 1456),
+ (752, 1392),
+ (800, 1328),
+ (832, 1248),
+ (880, 1184),
+ (944, 1104),
+ (1024, 1024),
+ (1104, 944),
+ (1184, 880),
+ (1248, 832),
+ (1328, 800),
+ (1392, 752),
+ (1456, 720),
+ (1504, 688),
+ (1568, 672),
+]
+
+
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class FluxKontextInpaintPipeline(
+ DiffusionPipeline,
+ FluxLoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+ FluxIPAdapterMixin,
+):
+ r"""
+ The Flux Kontext pipeline for text-to-image generation.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2,
+ vae_latent_channels=self.latent_channels,
+ do_normalize=False,
+ do_binarize=True,
+ do_convert_grayscale=True,
+ )
+
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ return image_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+ ):
+ image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_ip_adapter_image in ip_adapter_image:
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+ image_embeds.append(single_image_embeds[None, :])
+ else:
+ if not isinstance(ip_adapter_image_embeds, list):
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
+
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_image_embeds in ip_adapter_image_embeds:
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for single_image_embeds in image_embeds:
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ image,
+ mask_image,
+ strength,
+ height,
+ width,
+ output_type,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ padding_mask_crop=None,
+ max_sequence_length=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if padding_mask_crop is not None:
+ if not isinstance(image, PIL.Image.Image):
+ raise ValueError(
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+ )
+ if not isinstance(mask_image, PIL.Image.Image):
+ raise ValueError(
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+ f" {type(mask_image)}."
+ )
+ if output_type != "pil":
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ return latents
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
+
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_tiling()
+
+ def prepare_latents(
+ self,
+ image: Optional[torch.Tensor],
+ timestep: int,
+ batch_size: int,
+ num_channels_latents: int,
+ height: int,
+ width: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ image_reference: Optional[torch.Tensor] = None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ shape = (batch_size, num_channels_latents, height, width)
+
+ # Prepare image latents
+ image_latents = image_ids = None
+ if image is not None:
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ else:
+ image_latents = image
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ # Prepare image reference latents
+ image_reference_latents = image_reference_ids = None
+ if image_reference is not None:
+ image_reference = image_reference.to(device=device, dtype=dtype)
+ if image_reference.shape[1] != self.latent_channels:
+ image_reference_latents = self._encode_vae_image(image=image_reference, generator=generator)
+ else:
+ image_reference_latents = image_reference
+ if batch_size > image_reference_latents.shape[0] and batch_size % image_reference_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_reference_latents.shape[0]
+ image_reference_latents = torch.cat([image_reference_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_reference_latents.shape[0] and batch_size % image_reference_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image_reference` of batch size {image_reference_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_reference_latents = torch.cat([image_reference_latents], dim=0)
+
+ latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+ if latents is None:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+ else:
+ noise = latents.to(device=device, dtype=dtype)
+ latents = noise
+
+ image_latent_height, image_latent_width = image_latents.shape[2:]
+ image_latents = self._pack_latents(
+ image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
+ )
+ image_ids = self._prepare_latent_image_ids(
+ batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype
+ )
+ # image ids are the same as latent ids with the first dimension set to 1 instead of 0
+ image_ids[..., 0] = 1
+
+ if image_reference_latents is not None:
+ image_reference_latent_height, image_reference_latent_width = image_reference_latents.shape[2:]
+ image_reference_latents = self._pack_latents(
+ image_reference_latents,
+ batch_size,
+ num_channels_latents,
+ image_reference_latent_height,
+ image_reference_latent_width,
+ )
+ image_reference_ids = self._prepare_latent_image_ids(
+ batch_size, image_reference_latent_height // 2, image_reference_latent_width // 2, device, dtype
+ )
+ # image_reference_ids are the same as latent ids with the first dimension set to 1 instead of 0
+ image_reference_ids[..., 0] = 1
+
+ noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ return latents, image_latents, image_reference_latents, latent_ids, image_ids, image_reference_ids, noise
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents
+ def prepare_mask_latents(
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(mask, size=(height, width))
+ mask = mask.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ if masked_image.shape[1] == 16:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
+
+ masked_image_latents = (
+ masked_image_latents - self.vae.config.shift_factor
+ ) * self.vae.config.scaling_factor
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ masked_image_latents = self._pack_latents(
+ masked_image_latents,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+ mask = self._pack_latents(
+ mask.repeat(1, num_channels_latents, 1, 1),
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+
+ return mask, masked_image_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: Optional[PipelineImageInput] = None,
+ image_reference: Optional[PipelineImageInput] = None,
+ mask_image: PipelineImageInput = None,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ true_cfg_scale: float = 1.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 1.0,
+ padding_mask_crop: Optional[int] = None,
+ num_inference_steps: int = 28,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ max_area: int = 1024**2,
+ _auto_resize: bool = True,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be be inpainted (which parts of the image
+ to be masked out with `mask_image` and repainted according to `prompt` and `image_reference`). For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
+ latents as `image`, but if passing latents directly it is not encoded again.
+ image_reference (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point for the
+ masked area. For both numpy array and pytorch tensor, the expected value range is between `[0, 1]` If
+ it's a tensor or a list or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)` If it is
+ a numpy array or a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can
+ also accept image latents as `image`, but if passing latents directly it is not encoded again.
+ mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+ color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
+ H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
+ 1)`, or `(H, W)`.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ will be used instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and
+ `negative_prompt` is provided.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ strength (`float`, *optional*, defaults to 1.0):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
+ The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
+ image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
+ with the same aspect ration of the image and contains all masked area, and then expand that area based
+ on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
+ resizing to the original image size for inpainting. This is useful when the masked area is small while
+ the image is large and contain information irrelevant for inpainting, such as background.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 3.5):
+ Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
+ a model to generate images more aligned with `prompt` at the expense of lower image quality.
+
+ Guidance-distilled models approximates true classifier-free guidance for `guidance_scale` > 1. Refer to
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_ip_adapter_image:
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512):
+ Maximum sequence length to use with the `prompt`.
+ max_area (`int`, defaults to `1024 ** 2`):
+ The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
+ area while maintaining the aspect ratio.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_height, original_width = height, width
+ aspect_ratio = width / height
+ width = round((max_area * aspect_ratio) ** 0.5)
+ height = round((max_area / aspect_ratio) ** 0.5)
+
+ multiple_of = self.vae_scale_factor * 2
+ width = width // multiple_of * multiple_of
+ height = height // multiple_of * multiple_of
+
+ if height != original_height or width != original_width:
+ logger.warning(
+ f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ image,
+ mask_image,
+ strength,
+ height,
+ width,
+ output_type=output_type,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ padding_mask_crop=padding_mask_crop,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Preprocess image
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+ if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
+ image = torch.cat(image, dim=0)
+ img = image[0] if isinstance(image, list) else image
+ image_height, image_width = self.image_processor.get_default_height_width(img)
+ aspect_ratio = image_width / image_height
+ if _auto_resize:
+ # Kontext is trained on specific resolutions, using one of them is recommended
+ _, image_width, image_height = min(
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+ )
+ image_width = image_width // multiple_of * multiple_of
+ image_height = image_height // multiple_of * multiple_of
+ image = self.image_processor.resize(image, image_height, image_width)
+
+ # Choose the resolution of the image to be the same as the image
+ width = image_width
+ height = image_height
+
+ # 2.1 Preprocess mask
+ if padding_mask_crop is not None:
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+ resize_mode = "fill"
+ else:
+ crops_coords = None
+ resize_mode = "default"
+
+ image = self.image_processor.preprocess(
+ image, image_height, image_width, crops_coords=crops_coords, resize_mode=resize_mode
+ )
+ else:
+ raise ValueError("image must be provided correctly for inpainting")
+
+ init_image = image.to(dtype=torch.float32)
+
+ # 2.1 Preprocess image_reference
+ if image_reference is not None and not (
+ isinstance(image_reference, torch.Tensor) and image_reference.size(1) == self.latent_channels
+ ):
+ if (
+ isinstance(image_reference, list)
+ and isinstance(image_reference[0], torch.Tensor)
+ and image_reference[0].ndim == 4
+ ):
+ image_reference = torch.cat(image_reference, dim=0)
+ img_reference = image_reference[0] if isinstance(image_reference, list) else image_reference
+ image_reference_height, image_reference_width = self.image_processor.get_default_height_width(
+ img_reference
+ )
+ aspect_ratio = image_reference_width / image_reference_height
+ if _auto_resize:
+ # Kontext is trained on specific resolutions, using one of them is recommended
+ _, image_reference_width, image_reference_height = min(
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+ )
+ image_reference_width = image_reference_width // multiple_of * multiple_of
+ image_reference_height = image_reference_height // multiple_of * multiple_of
+ image_reference = self.image_processor.resize(
+ image_reference, image_reference_height, image_reference_width
+ )
+ image_reference = self.image_processor.preprocess(
+ image_reference,
+ image_reference_height,
+ image_reference_width,
+ crops_coords=crops_coords,
+ resize_mode=resize_mode,
+ )
+ else:
+ image_reference = None
+
+ # 3. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
+ )
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+ if do_true_cfg:
+ (
+ negative_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ negative_text_ids,
+ ) = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_2=negative_prompt_2,
+ prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, image_latents, image_reference_latents, latent_ids, image_ids, image_reference_ids, noise = (
+ self.prepare_latents(
+ init_image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image_reference,
+ )
+ )
+
+ if image_reference_ids is not None:
+ latent_ids = torch.cat([latent_ids, image_reference_ids], dim=0) # dim 0 is sequence dimension
+ elif image_ids is not None:
+ latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension
+
+ mask_condition = self.mask_processor.preprocess(
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+ )
+
+ masked_image = init_image * (mask_condition < 0.5)
+
+ mask, _ = self.prepare_mask_latents(
+ mask_condition,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+ ):
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+ ):
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {}
+
+ image_embeds = None
+ negative_image_embeds = None
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+ negative_ip_adapter_image,
+ negative_ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ if image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
+
+ latent_model_input = latents
+ if image_reference_latents is not None:
+ latent_model_input = torch.cat([latents, image_reference_latents], dim=1)
+ elif image_latents is not None:
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred[:, : latents.size(1)]
+
+ if do_true_cfg:
+ if negative_image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=negative_text_ids,
+ img_ids=latent_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ init_latents_proper = image_latents
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.scale_noise(
+ init_latents_proper, torch.tensor([noise_timestep]), noise
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index f53958df2ed0..e79db337b2e3 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -292,7 +292,7 @@ def _get_clip_prompt_embeds(
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
diff --git a/src/diffusers/pipelines/flux/pipeline_output.py b/src/diffusers/pipelines/flux/pipeline_output.py
index 388824e89f87..69e742d3e026 100644
--- a/src/diffusers/pipelines/flux/pipeline_output.py
+++ b/src/diffusers/pipelines/flux/pipeline_output.py
@@ -11,12 +11,14 @@
@dataclass
class FluxPipelineOutput(BaseOutput):
"""
- Output class for Stable Diffusion pipelines.
+ Output class for Flux image generation pipelines.
Args:
- images (`List[PIL.Image.Image]` or `np.ndarray`)
- List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
- num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
+ height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion
+ pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be
+ passed to the decoder.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/pipelines/flux2/__init__.py b/src/diffusers/pipelines/flux2/__init__.py
new file mode 100644
index 000000000000..d986c9a63011
--- /dev/null
+++ b/src/diffusers/pipelines/flux2/__init__.py
@@ -0,0 +1,47 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_additional_imports = {}
+_import_structure = {"pipeline_output": ["Flux2PipelineOutput"]}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_flux2"] = ["Flux2Pipeline"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ else:
+ from .pipeline_flux2 import Flux2Pipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
+ for name, value in _additional_imports.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/flux2/image_processor.py b/src/diffusers/pipelines/flux2/image_processor.py
new file mode 100644
index 000000000000..f1a8742491f7
--- /dev/null
+++ b/src/diffusers/pipelines/flux2/image_processor.py
@@ -0,0 +1,178 @@
+# Copyright 2025 The Black Forest Labs Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import List
+
+import PIL.Image
+
+from ...configuration_utils import register_to_config
+from ...image_processor import VaeImageProcessor
+
+
+class Flux2ImageProcessor(VaeImageProcessor):
+ r"""
+ Image processor to preprocess the reference (character) image for the Flux2 model.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
+ `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
+ vae_scale_factor (`int`, *optional*, defaults to `16`):
+ VAE (spatial) scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of
+ this factor.
+ vae_latent_channels (`int`, *optional*, defaults to `32`):
+ VAE latent channels.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image to [-1,1].
+ do_convert_rgb (`bool`, *optional*, defaults to be `True`):
+ Whether to convert the images to RGB format.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ do_resize: bool = True,
+ vae_scale_factor: int = 16,
+ vae_latent_channels: int = 32,
+ do_normalize: bool = True,
+ do_convert_rgb: bool = True,
+ ):
+ super().__init__(
+ do_resize=do_resize,
+ vae_scale_factor=vae_scale_factor,
+ vae_latent_channels=vae_latent_channels,
+ do_normalize=do_normalize,
+ do_convert_rgb=do_convert_rgb,
+ )
+
+ @staticmethod
+ def check_image_input(
+ image: PIL.Image.Image, max_aspect_ratio: int = 8, min_side_length: int = 64, max_area: int = 1024 * 1024
+ ) -> PIL.Image.Image:
+ """
+ Check if image meets minimum size and aspect ratio requirements.
+
+ Args:
+ image: PIL Image to validate
+ max_aspect_ratio: Maximum allowed aspect ratio (width/height or height/width)
+ min_side_length: Minimum pixels required for width and height
+ max_area: Maximum allowed area in pixels²
+
+ Returns:
+ The input image if valid
+
+ Raises:
+ ValueError: If image is too small or aspect ratio is too extreme
+ """
+ if not isinstance(image, PIL.Image.Image):
+ raise ValueError(f"Image must be a PIL.Image.Image, got {type(image)}")
+
+ width, height = image.size
+
+ # Check minimum dimensions
+ if width < min_side_length or height < min_side_length:
+ raise ValueError(
+ f"Image too small: {width}×{height}. Both dimensions must be at least {min_side_length}px"
+ )
+
+ # Check aspect ratio
+ aspect_ratio = max(width / height, height / width)
+ if aspect_ratio > max_aspect_ratio:
+ raise ValueError(
+ f"Aspect ratio too extreme: {width}×{height} (ratio: {aspect_ratio:.1f}:1). "
+ f"Maximum allowed ratio is {max_aspect_ratio}:1"
+ )
+
+ return image
+
+ @staticmethod
+ def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image:
+ image_width, image_height = image.size
+
+ scale = math.sqrt(target_area / (image_width * image_height))
+ width = int(image_width * scale)
+ height = int(image_height * scale)
+
+ return image.resize((width, height), PIL.Image.Resampling.LANCZOS)
+
+ @staticmethod
+ def _resize_if_exceeds_area(image, target_area=1024 * 1024) -> PIL.Image.Image:
+ image_width, image_height = image.size
+ pixel_count = image_width * image_height
+ if pixel_count <= target_area:
+ return image
+ return Flux2ImageProcessor._resize_to_target_area(image, target_area)
+
+ def _resize_and_crop(
+ self,
+ image: PIL.Image.Image,
+ width: int,
+ height: int,
+ ) -> PIL.Image.Image:
+ r"""
+ center crop the image to the specified width and height.
+
+ Args:
+ image (`PIL.Image.Image`):
+ The image to resize and crop.
+ width (`int`):
+ The width to resize the image to.
+ height (`int`):
+ The height to resize the image to.
+
+ Returns:
+ `PIL.Image.Image`:
+ The resized and cropped image.
+ """
+ image_width, image_height = image.size
+
+ left = (image_width - width) // 2
+ top = (image_height - height) // 2
+ right = left + width
+ bottom = top + height
+
+ return image.crop((left, top, right, bottom))
+
+ # Taken from
+ # https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/sampling.py#L310C1-L339C19
+ @staticmethod
+ def concatenate_images(images: List[PIL.Image.Image]) -> PIL.Image.Image:
+ """
+ Concatenate a list of PIL images horizontally with center alignment and white background.
+ """
+
+ # If only one image, return a copy of it
+ if len(images) == 1:
+ return images[0].copy()
+
+ # Convert all images to RGB if not already
+ images = [img.convert("RGB") if img.mode != "RGB" else img for img in images]
+
+ # Calculate dimensions for horizontal concatenation
+ total_width = sum(img.width for img in images)
+ max_height = max(img.height for img in images)
+
+ # Create new image with white background
+ background_color = (255, 255, 255)
+ new_img = PIL.Image.new("RGB", (total_width, max_height), background_color)
+
+ # Paste images with center alignment
+ x_offset = 0
+ for img in images:
+ y_offset = (max_height - img.height) // 2
+ new_img.paste(img, (x_offset, y_offset))
+ x_offset += img.width
+
+ return new_img
diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py
new file mode 100644
index 000000000000..b54a43dd89a5
--- /dev/null
+++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py
@@ -0,0 +1,1032 @@
+# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import AutoProcessor, Mistral3ForConditionalGeneration
+
+from ...loaders import Flux2LoraLoaderMixin
+from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .image_processor import Flux2ImageProcessor
+from .pipeline_output import Flux2PipelineOutput
+from .system_messages import SYSTEM_MESSAGE, SYSTEM_MESSAGE_UPSAMPLING_I2I, SYSTEM_MESSAGE_UPSAMPLING_T2I
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import Flux2Pipeline
+
+ >>> pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> prompt = "A cat holding a sign that says hello world"
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0]
+ >>> image.save("flux.png")
+ ```
+"""
+
+UPSAMPLING_MAX_IMAGE_SIZE = 768**2
+
+
+# Adapted from
+# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L68
+def format_input(
+ prompts: List[str],
+ system_message: str = SYSTEM_MESSAGE,
+ images: Optional[Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]]] = None,
+):
+ """
+ Format a batch of text prompts into the conversation format expected by apply_chat_template. Optionally, add images
+ to the input.
+
+ Args:
+ prompts: List of text prompts
+ system_message: System message to use (default: CREATIVE_SYSTEM_MESSAGE)
+ images (optional): List of images to add to the input.
+
+ Returns:
+ List of conversations, where each conversation is a list of message dicts
+ """
+ # Remove [IMG] tokens from prompts to avoid Pixtral validation issues
+ # when truncation is enabled. The processor counts [IMG] tokens and fails
+ # if the count changes after truncation.
+ cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]
+
+ if images is None or len(images) == 0:
+ return [
+ [
+ {
+ "role": "system",
+ "content": [{"type": "text", "text": system_message}],
+ },
+ {"role": "user", "content": [{"type": "text", "text": prompt}]},
+ ]
+ for prompt in cleaned_txt
+ ]
+ else:
+ assert len(images) == len(prompts), "Number of images must match number of prompts"
+ messages = [
+ [
+ {
+ "role": "system",
+ "content": [{"type": "text", "text": system_message}],
+ },
+ ]
+ for _ in cleaned_txt
+ ]
+
+ for i, (el, images) in enumerate(zip(messages, images)):
+ # optionally add the images per batch element.
+ if images is not None:
+ el.append(
+ {
+ "role": "user",
+ "content": [{"type": "image", "image": image_obj} for image_obj in images],
+ }
+ )
+ # add the text.
+ el.append(
+ {
+ "role": "user",
+ "content": [{"type": "text", "text": cleaned_txt[i]}],
+ }
+ )
+
+ return messages
+
+
+# Adapted from
+# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L49C5-L66C19
+def _validate_and_process_images(
+ images: List[List[PIL.Image.Image]] | List[PIL.Image.Image],
+ image_processor: Flux2ImageProcessor,
+ upsampling_max_image_size: int,
+) -> List[List[PIL.Image.Image]]:
+ # Simple validation: ensure it's a list of PIL images or list of lists of PIL images
+ if not images:
+ return []
+
+ # Check if it's a list of lists or a list of images
+ if isinstance(images[0], PIL.Image.Image):
+ # It's a list of images, convert to list of lists
+ images = [[im] for im in images]
+
+ # potentially concatenate multiple images to reduce the size
+ images = [[image_processor.concatenate_images(img_i)] if len(img_i) > 1 else img_i for img_i in images]
+
+ # cap the pixels
+ images = [
+ [image_processor._resize_if_exceeds_area(img_i, upsampling_max_image_size) for img_i in img_i]
+ for img_i in images
+ ]
+ return images
+
+
+# Taken from
+# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/sampling.py#L251
+def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
+ a1, b1 = 8.73809524e-05, 1.89833333
+ a2, b2 = 0.00016927, 0.45666666
+
+ if image_seq_len > 4300:
+ mu = a2 * image_seq_len + b2
+ return float(mu)
+
+ m_200 = a2 * image_seq_len + b2
+ m_10 = a1 * image_seq_len + b1
+
+ a = (m_200 - m_10) / 190.0
+ b = m_200 - 200.0 * a
+ mu = a * num_steps + b
+
+ return float(mu)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
+ r"""
+ The Flux2 pipeline for text-to-image generation.
+
+ Reference: [https://bfl.ai/blog/flux-2](https://bfl.ai/blog/flux-2)
+
+ Args:
+ transformer ([`Flux2Transformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLFlux2`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`Mistral3ForConditionalGeneration`]):
+ [Mistral3ForConditionalGeneration](https://huggingface.co/docs/transformers/en/model_doc/mistral3#transformers.Mistral3ForConditionalGeneration)
+ tokenizer (`AutoProcessor`):
+ Tokenizer of class
+ [PixtralProcessor](https://huggingface.co/docs/transformers/en/model_doc/pixtral#transformers.PixtralProcessor).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLFlux2,
+ text_encoder: Mistral3ForConditionalGeneration,
+ tokenizer: AutoProcessor,
+ transformer: Flux2Transformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ transformer=transformer,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.tokenizer_max_length = 512
+ self.default_sample_size = 128
+
+ self.system_message = SYSTEM_MESSAGE
+ self.system_message_upsampling_t2i = SYSTEM_MESSAGE_UPSAMPLING_T2I
+ self.system_message_upsampling_i2i = SYSTEM_MESSAGE_UPSAMPLING_I2I
+ self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE
+
+ @staticmethod
+ def _get_mistral_3_small_prompt_embeds(
+ text_encoder: Mistral3ForConditionalGeneration,
+ tokenizer: AutoProcessor,
+ prompt: Union[str, List[str]],
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ max_sequence_length: int = 512,
+ system_message: str = SYSTEM_MESSAGE,
+ hidden_states_layers: List[int] = (10, 20, 30),
+ ):
+ dtype = text_encoder.dtype if dtype is None else dtype
+ device = text_encoder.device if device is None else device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ # Format input messages
+ messages_batch = format_input(prompts=prompt, system_message=system_message)
+
+ # Process all messages at once
+ inputs = tokenizer.apply_chat_template(
+ messages_batch,
+ add_generation_prompt=False,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ padding="max_length",
+ truncation=True,
+ max_length=max_sequence_length,
+ )
+
+ # Move to device
+ input_ids = inputs["input_ids"].to(device)
+ attention_mask = inputs["attention_mask"].to(device)
+
+ # Forward pass through the model
+ output = text_encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_hidden_states=True,
+ use_cache=False,
+ )
+
+ # Only use outputs from intermediate layers and stack them
+ out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
+ out = out.to(dtype=dtype, device=device)
+
+ batch_size, num_channels, seq_len, hidden_dim = out.shape
+ prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
+
+ return prompt_embeds
+
+ @staticmethod
+ def _prepare_text_ids(
+ x: torch.Tensor, # (B, L, D) or (L, D)
+ t_coord: Optional[torch.Tensor] = None,
+ ):
+ B, L, _ = x.shape
+ out_ids = []
+
+ for i in range(B):
+ t = torch.arange(1) if t_coord is None else t_coord[i]
+ h = torch.arange(1)
+ w = torch.arange(1)
+ l = torch.arange(L)
+
+ coords = torch.cartesian_prod(t, h, w, l)
+ out_ids.append(coords)
+
+ return torch.stack(out_ids)
+
+ @staticmethod
+ def _prepare_latent_ids(
+ latents: torch.Tensor, # (B, C, H, W)
+ ):
+ r"""
+ Generates 4D position coordinates (T, H, W, L) for latent tensors.
+
+ Args:
+ latents (torch.Tensor):
+ Latent tensor of shape (B, C, H, W)
+
+ Returns:
+ torch.Tensor:
+ Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0,
+ H=[0..H-1], W=[0..W-1], L=0
+ """
+
+ batch_size, _, height, width = latents.shape
+
+ t = torch.arange(1) # [0] - time dimension
+ h = torch.arange(height)
+ w = torch.arange(width)
+ l = torch.arange(1) # [0] - layer dimension
+
+ # Create position IDs: (H*W, 4)
+ latent_ids = torch.cartesian_prod(t, h, w, l)
+
+ # Expand to batch: (B, H*W, 4)
+ latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
+
+ return latent_ids
+
+ @staticmethod
+ def _prepare_image_ids(
+ image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...]
+ scale: int = 10,
+ ):
+ r"""
+ Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.
+
+ This function creates a unique coordinate for every pixel/patch across all input latent with different
+ dimensions.
+
+ Args:
+ image_latents (List[torch.Tensor]):
+ A list of image latent feature tensors, typically of shape (C, H, W).
+ scale (int, optional):
+ A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th
+ latent is: 'scale + scale * i'. Defaults to 10.
+
+ Returns:
+ torch.Tensor:
+ The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all
+ input latents.
+
+ Coordinate Components (Dimension 4):
+ - T (Time): The unique index indicating which latent image the coordinate belongs to.
+ - H (Height): The row index within that latent image.
+ - W (Width): The column index within that latent image.
+ - L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1)
+ """
+
+ if not isinstance(image_latents, list):
+ raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
+
+ # create time offset for each reference image
+ t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
+ t_coords = [t.view(-1) for t in t_coords]
+
+ image_latent_ids = []
+ for x, t in zip(image_latents, t_coords):
+ x = x.squeeze(0)
+ _, height, width = x.shape
+
+ x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
+ image_latent_ids.append(x_ids)
+
+ image_latent_ids = torch.cat(image_latent_ids, dim=0)
+ image_latent_ids = image_latent_ids.unsqueeze(0)
+
+ return image_latent_ids
+
+ @staticmethod
+ def _patchify_latents(latents):
+ batch_size, num_channels_latents, height, width = latents.shape
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 1, 3, 5, 2, 4)
+ latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
+ return latents
+
+ @staticmethod
+ def _unpatchify_latents(latents):
+ batch_size, num_channels_latents, height, width = latents.shape
+ latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
+ latents = latents.permute(0, 1, 4, 2, 5, 3)
+ latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
+ return latents
+
+ @staticmethod
+ def _pack_latents(latents):
+ """
+ pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)
+ """
+
+ batch_size, num_channels, height, width = latents.shape
+ latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
+
+ return latents
+
+ @staticmethod
+ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]:
+ """
+ using position ids to scatter tokens into place
+ """
+ x_list = []
+ for data, pos in zip(x, x_ids):
+ _, ch = data.shape # noqa: F841
+ h_ids = pos[:, 1].to(torch.int64)
+ w_ids = pos[:, 2].to(torch.int64)
+
+ h = torch.max(h_ids) + 1
+ w = torch.max(w_ids) + 1
+
+ flat_ids = h_ids * w + w_ids
+
+ out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
+ out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
+
+ # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W)
+
+ out = out.view(h, w, ch).permute(2, 0, 1)
+ x_list.append(out)
+
+ return torch.stack(x_list, dim=0)
+
+ def upsample_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ images: Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]] = None,
+ temperature: float = 0.15,
+ device: torch.device = None,
+ ) -> List[str]:
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ device = self.text_encoder.device if device is None else device
+
+ # Set system message based on whether images are provided
+ if images is None or len(images) == 0 or images[0] is None:
+ system_message = SYSTEM_MESSAGE_UPSAMPLING_T2I
+ else:
+ system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I
+
+ # Validate and process the input images
+ if images:
+ images = _validate_and_process_images(images, self.image_processor, self.upsampling_max_image_size)
+
+ # Format input messages
+ messages_batch = format_input(prompts=prompt, system_message=system_message, images=images)
+
+ # Process all messages at once
+ # with image processing a too short max length can throw an error in here.
+ inputs = self.tokenizer.apply_chat_template(
+ messages_batch,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ padding="max_length",
+ truncation=True,
+ max_length=2048,
+ )
+
+ # Move to device
+ inputs["input_ids"] = inputs["input_ids"].to(device)
+ inputs["attention_mask"] = inputs["attention_mask"].to(device)
+
+ if "pixel_values" in inputs:
+ inputs["pixel_values"] = inputs["pixel_values"].to(device, self.text_encoder.dtype)
+
+ # Generate text using the model's generate method
+ generated_ids = self.text_encoder.generate(
+ **inputs,
+ max_new_tokens=512,
+ do_sample=True,
+ temperature=temperature,
+ use_cache=True,
+ )
+
+ # Decode only the newly generated tokens (skip input tokens)
+ # Extract only the generated portion
+ input_length = inputs["input_ids"].shape[1]
+ generated_tokens = generated_ids[:, input_length:]
+
+ upsampled_prompt = self.tokenizer.tokenizer.batch_decode(
+ generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
+ )
+ return upsampled_prompt
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 512,
+ text_encoder_out_layers: Tuple[int] = (10, 20, 30),
+ ):
+ device = device or self._execution_device
+
+ if prompt is None:
+ prompt = ""
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_mistral_3_small_prompt_embeds(
+ text_encoder=self.text_encoder,
+ tokenizer=self.tokenizer,
+ prompt=prompt,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ system_message=self.system_message,
+ hidden_states_layers=text_encoder_out_layers,
+ )
+
+ batch_size, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ text_ids = self._prepare_text_ids(prompt_embeds)
+ text_ids = text_ids.to(device)
+ return prompt_embeds, text_ids
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if image.ndim != 4:
+ raise ValueError(f"Expected image dims 4, got {image.ndim}.")
+
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
+ image_latents = self._patchify_latents(image_latents)
+
+ latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
+ latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
+ image_latents = (image_latents - latents_bn_mean) / latents_bn_std
+
+ return image_latents
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_latents_channels,
+ height,
+ width,
+ dtype,
+ device,
+ generator: torch.Generator,
+ latents: Optional[torch.Tensor] = None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, num_latents_channels * 4, height // 2, width // 2)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ latent_ids = self._prepare_latent_ids(latents)
+ latent_ids = latent_ids.to(device)
+
+ latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C]
+ return latents, latent_ids
+
+ def prepare_image_latents(
+ self,
+ images: List[torch.Tensor],
+ batch_size,
+ generator: torch.Generator,
+ device,
+ dtype,
+ ):
+ image_latents = []
+ for image in images:
+ image = image.to(device=device, dtype=dtype)
+ imagge_latent = self._encode_vae_image(image=image, generator=generator)
+ image_latents.append(imagge_latent) # (1, 128, 32, 32)
+
+ image_latent_ids = self._prepare_image_ids(image_latents)
+
+ # Pack each latent and concatenate
+ packed_latents = []
+ for latent in image_latents:
+ # latent: (1, 128, 32, 32)
+ packed = self._pack_latents(latent) # (1, 1024, 128)
+ packed = packed.squeeze(0) # (1024, 128) - remove batch dim
+ packed_latents.append(packed)
+
+ # Concatenate all reference tokens along sequence dimension
+ image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128)
+ image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128)
+
+ image_latents = image_latents.repeat(batch_size, 1, 1)
+ image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
+ image_latent_ids = image_latent_ids.to(device)
+
+ return image_latents, image_latent_ids
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if (
+ height is not None
+ and height % (self.vae_scale_factor * 2) != 0
+ or width is not None
+ and width % (self.vae_scale_factor * 2) != 0
+ ):
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: Optional[float] = 4.0,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ text_encoder_out_layers: Tuple[int] = (10, 20, 30),
+ caption_upsample_temperature: float = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
+ latents as `image`, but if passing latents directly it is not encoded again.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ guidance_scale (`float`, *optional*, defaults to 1.0):
+ Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
+ a model to generate images more aligned with `prompt` at the expense of lower image quality.
+
+ Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+ text_encoder_out_layers (`Tuple[int]`):
+ Layer indices to use in the `text_encoder` to derive the final prompt embeddings.
+ caption_upsample_temperature (`float`):
+ When specified, we will try to perform caption upsampling for potentially improved outputs. We
+ recommend setting it to 0.15 if caption upsampling is to be performed.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ prompt_embeds=prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. prepare text embeddings
+ if caption_upsample_temperature:
+ prompt = self.upsample_prompt(
+ prompt, images=image, temperature=caption_upsample_temperature, device=device
+ )
+ prompt_embeds, text_ids = self.encode_prompt(
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ text_encoder_out_layers=text_encoder_out_layers,
+ )
+
+ # 4. process images
+ if image is not None and not isinstance(image, list):
+ image = [image]
+
+ condition_images = None
+ if image is not None:
+ for img in image:
+ self.image_processor.check_image_input(img)
+
+ condition_images = []
+ for img in image:
+ image_width, image_height = img.size
+ if image_width * image_height > 1024 * 1024:
+ img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
+ image_width, image_height = img.size
+
+ multiple_of = self.vae_scale_factor * 2
+ image_width = (image_width // multiple_of) * multiple_of
+ image_height = (image_height // multiple_of) * multiple_of
+ img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
+ condition_images.append(img)
+ height = height or image_height
+ width = width or image_width
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 5. prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, latent_ids = self.prepare_latents(
+ batch_size=batch_size * num_images_per_prompt,
+ num_latents_channels=num_channels_latents,
+ height=height,
+ width=width,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ generator=generator,
+ latents=latents,
+ )
+
+ image_latents = None
+ image_latent_ids = None
+ if condition_images is not None:
+ image_latents, image_latent_ids = self.prepare_image_latents(
+ images=condition_images,
+ batch_size=batch_size * num_images_per_prompt,
+ generator=generator,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+
+ # 6. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
+ sigmas = None
+ image_seq_len = latents.shape[1]
+ mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+
+ # 7. Denoising loop
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ latent_model_input = latents.to(self.transformer.dtype)
+ latent_image_ids = latent_ids
+
+ if image_latents is not None:
+ latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype)
+ latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input, # (B, image_seq_len, C)
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids, # B, text_seq_len, 4
+ img_ids=latent_image_ids, # B, image_seq_len, 4
+ joint_attention_kwargs=self._attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ noise_pred = noise_pred[:, : latents.size(1) :]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents_with_ids(latents, latent_ids)
+
+ latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
+ latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
+ latents.device, latents.dtype
+ )
+ latents = latents * latents_bn_std + latents_bn_mean
+ latents = self._unpatchify_latents(latents)
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return Flux2PipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/flux2/pipeline_output.py b/src/diffusers/pipelines/flux2/pipeline_output.py
new file mode 100644
index 000000000000..58e8ad49c210
--- /dev/null
+++ b/src/diffusers/pipelines/flux2/pipeline_output.py
@@ -0,0 +1,23 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class Flux2PipelineOutput(BaseOutput):
+ """
+ Output class for Flux2 image generation pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
+ height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion
+ pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be
+ passed to the decoder.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/pipelines/flux2/system_messages.py b/src/diffusers/pipelines/flux2/system_messages.py
new file mode 100644
index 000000000000..ecdb1371f0d4
--- /dev/null
+++ b/src/diffusers/pipelines/flux2/system_messages.py
@@ -0,0 +1,33 @@
+# docstyle-ignore
+"""
+These system prompts come from:
+https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/system_messages.py#L54
+"""
+
+# docstyle-ignore
+SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object
+attribution and actions without speculation."""
+
+# docstyle-ignore
+SYSTEM_MESSAGE_UPSAMPLING_T2I = """You are an expert prompt engineer for FLUX.2 by Black Forest Labs. Rewrite user prompts to be more descriptive while strictly preserving their core subject and intent.
+
+Guidelines:
+1. Structure: Keep structured inputs structured (enhance within fields). Convert natural language to detailed paragraphs.
+2. Details: Add concrete visual specifics - form, scale, textures, materials, lighting (quality, direction, color), shadows, spatial relationships, and environmental context.
+3. Text in Images: Put ALL text in quotation marks, matching the prompt's language. Always provide explicit quoted text for objects that would contain text in reality (signs, labels, screens, etc.) - without it, the model generates gibberish.
+
+Output only the revised prompt and nothing else."""
+
+# docstyle-ignore
+SYSTEM_MESSAGE_UPSAMPLING_I2I = """You are FLUX.2 by Black Forest Labs, an image-editing expert. You convert editing requests into one concise instruction (50-80 words, ~30 for brief requests).
+
+Rules:
+- Single instruction only, no commentary
+- Use clear, analytical language (avoid "whimsical," "cascading," etc.)
+- Specify what changes AND what stays the same (face, lighting, composition)
+- Reference actual image elements
+- Turn negatives into positives ("don't change X" → "keep X")
+- Make abstractions concrete ("futuristic" → "glowing cyan neon, metallic panels")
+- Keep content PG-13
+
+Output only the final instruction in plain text and nothing else."""
diff --git a/src/diffusers/pipelines/free_init_utils.py b/src/diffusers/pipelines/free_init_utils.py
index 1fb67592ca4f..4495c5ea2683 100644
--- a/src/diffusers/pipelines/free_init_utils.py
+++ b/src/diffusers/pipelines/free_init_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,7 +33,7 @@ def enable_free_init(
spatial_stop_frequency: float = 0.25,
temporal_stop_frequency: float = 0.25,
):
- """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537.
+ """Enables the FreeInit mechanism as in https://huggingface.co/papers/2312.07537.
This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit).
diff --git a/src/diffusers/pipelines/free_noise_utils.py b/src/diffusers/pipelines/free_noise_utils.py
index dc0071a494e3..2910afaf237b 100644
--- a/src/diffusers/pipelines/free_noise_utils.py
+++ b/src/diffusers/pipelines/free_noise_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -143,7 +143,7 @@ def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
class AnimateDiffFreeNoiseMixin:
- r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169)."""
+ r"""Mixin class for [FreeNoise](https://huggingface.co/papers/2310.15169)."""
def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
r"""Helper function to enable FreeNoise in transformer blocks."""
@@ -341,9 +341,9 @@ def _encode_prompt_free_noise(
start_tensor = negative_prompt_embeds[i].unsqueeze(0)
end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0)
- negative_prompt_interpolation_embeds[
- start_frame : end_frame + 1
- ] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)
+ negative_prompt_interpolation_embeds[start_frame : end_frame + 1] = (
+ self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)
+ )
prompt_embeds = prompt_interpolation_embeds
negative_prompt_embeds = negative_prompt_interpolation_embeds
@@ -478,7 +478,7 @@ def enable_free_noise(
Must be one of ["shuffle_context", "repeat_context", "random"].
- "shuffle_context"
Shuffles a fixed batch of `context_length` latents to create a final latent of size
- `num_frames`. This is usually the best setting for most generation scenarious. However, there
+ `num_frames`. This is usually the best setting for most generation scenarios. However, there
might be visible repetition noticeable in the kinds of motion/animation generated.
- "repeated_context"
Repeats a fixed batch of `context_length` latents to create a final latent of size
diff --git a/src/diffusers/pipelines/hidream_image/__init__.py b/src/diffusers/pipelines/hidream_image/__init__.py
new file mode 100644
index 000000000000..498df900e68b
--- /dev/null
+++ b/src/diffusers/pipelines/hidream_image/__init__.py
@@ -0,0 +1,47 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_additional_imports = {}
+_import_structure = {"pipeline_output": ["HiDreamImagePipelineOutput"]}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_hidream_image"] = ["HiDreamImagePipeline"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ else:
+ from .pipeline_hidream_image import HiDreamImagePipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
+ for name, value in _additional_imports.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py
new file mode 100644
index 000000000000..b6af23bca8fd
--- /dev/null
+++ b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py
@@ -0,0 +1,1050 @@
+# Copyright 2025 HiDream-ai Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from transformers import (
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ LlamaForCausalLM,
+ PreTrainedTokenizerFast,
+ T5EncoderModel,
+ T5Tokenizer,
+)
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import HiDreamImageLoraLoaderMixin
+from ...models import AutoencoderKL, HiDreamImageTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import HiDreamImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
+ >>> from diffusers import HiDreamImagePipeline
+
+
+ >>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
+ >>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
+ ... "meta-llama/Meta-Llama-3.1-8B-Instruct",
+ ... output_hidden_states=True,
+ ... output_attentions=True,
+ ... torch_dtype=torch.bfloat16,
+ ... )
+
+ >>> pipe = HiDreamImagePipeline.from_pretrained(
+ ... "HiDream-ai/HiDream-I1-Full",
+ ... tokenizer_4=tokenizer_4,
+ ... text_encoder_4=text_encoder_4,
+ ... torch_dtype=torch.bfloat16,
+ ... )
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> image = pipe(
+ ... 'A cat holding a sign that says "Hi-Dreams.ai".',
+ ... height=1024,
+ ... width=1024,
+ ... guidance_scale=5.0,
+ ... num_inference_steps=50,
+ ... generator=torch.Generator("cuda").manual_seed(0),
+ ... ).images[0]
+ >>> image.save("output.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class HiDreamImagePipeline(DiffusionPipeline, HiDreamImageLoraLoaderMixin):
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds_t5", "prompt_embeds_llama3", "pooled_prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer_2: CLIPTokenizer,
+ text_encoder_3: T5EncoderModel,
+ tokenizer_3: T5Tokenizer,
+ text_encoder_4: LlamaForCausalLM,
+ tokenizer_4: PreTrainedTokenizerFast,
+ transformer: HiDreamImageTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ text_encoder_3=text_encoder_3,
+ text_encoder_4=text_encoder_4,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ tokenizer_3=tokenizer_3,
+ tokenizer_4=tokenizer_4,
+ scheduler=scheduler,
+ transformer=transformer,
+ )
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ )
+ # HiDreamImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.default_sample_size = 128
+ if getattr(self, "tokenizer_4", None) is not None:
+ self.tokenizer_4.pad_token = self.tokenizer_4.eos_token
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ max_sequence_length: int = 128,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder_3.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ text_inputs = self.tokenizer_3(
+ prompt,
+ padding="max_length",
+ max_length=min(max_sequence_length, self.tokenizer_3.model_max_length),
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ attention_mask = text_inputs.attention_mask
+ untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_3.batch_decode(
+ untruncated_ids[:, min(max_sequence_length, self.tokenizer_3.model_max_length) - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {min(max_sequence_length, self.tokenizer_3.model_max_length)} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=attention_mask.to(device))[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ return prompt_embeds
+
+ def _get_clip_prompt_embeds(
+ self,
+ tokenizer,
+ text_encoder,
+ prompt: Union[str, List[str]],
+ max_sequence_length: int = 128,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=min(max_sequence_length, 218),
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, 218 - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {218} tokens: {removed_text}"
+ )
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ return prompt_embeds
+
+ def _get_llama3_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ max_sequence_length: int = 128,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder_4.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ text_inputs = self.tokenizer_4(
+ prompt,
+ padding="max_length",
+ max_length=min(max_sequence_length, self.tokenizer_4.model_max_length),
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ attention_mask = text_inputs.attention_mask
+ untruncated_ids = self.tokenizer_4(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_4.batch_decode(
+ untruncated_ids[:, min(max_sequence_length, self.tokenizer_4.model_max_length) - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {min(max_sequence_length, self.tokenizer_4.model_max_length)} tokens: {removed_text}"
+ )
+
+ outputs = self.text_encoder_4(
+ text_input_ids.to(device),
+ attention_mask=attention_mask.to(device),
+ output_hidden_states=True,
+ output_attentions=True,
+ )
+
+ prompt_embeds = outputs.hidden_states[1:]
+ prompt_embeds = torch.stack(prompt_embeds, dim=0)
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ prompt_3: Optional[Union[str, List[str]]] = None,
+ prompt_4: Optional[Union[str, List[str]]] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
+ negative_prompt_4: Optional[Union[str, List[str]]] = None,
+ prompt_embeds_t5: Optional[List[torch.FloatTensor]] = None,
+ prompt_embeds_llama3: Optional[List[torch.FloatTensor]] = None,
+ negative_prompt_embeds_t5: Optional[List[torch.FloatTensor]] = None,
+ negative_prompt_embeds_llama3: Optional[List[torch.FloatTensor]] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 128,
+ lora_scale: Optional[float] = None,
+ ):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = pooled_prompt_embeds.shape[0]
+
+ device = device or self._execution_device
+
+ if pooled_prompt_embeds is None:
+ pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
+ self.tokenizer, self.text_encoder, prompt, max_sequence_length, device, dtype
+ )
+
+ if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if len(negative_prompt) > 1 and len(negative_prompt) != batch_size:
+ raise ValueError(f"negative_prompt must be of length 1 or {batch_size}")
+
+ negative_pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
+ self.tokenizer, self.text_encoder, negative_prompt, max_sequence_length, device, dtype
+ )
+
+ if negative_pooled_prompt_embeds_1.shape[0] == 1 and batch_size > 1:
+ negative_pooled_prompt_embeds_1 = negative_pooled_prompt_embeds_1.repeat(batch_size, 1)
+
+ if pooled_prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ if len(prompt_2) > 1 and len(prompt_2) != batch_size:
+ raise ValueError(f"prompt_2 must be of length 1 or {batch_size}")
+
+ pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
+ self.tokenizer_2, self.text_encoder_2, prompt_2, max_sequence_length, device, dtype
+ )
+
+ if pooled_prompt_embeds_2.shape[0] == 1 and batch_size > 1:
+ pooled_prompt_embeds_2 = pooled_prompt_embeds_2.repeat(batch_size, 1)
+
+ if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+ negative_prompt_2 = [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+
+ if len(negative_prompt_2) > 1 and len(negative_prompt_2) != batch_size:
+ raise ValueError(f"negative_prompt_2 must be of length 1 or {batch_size}")
+
+ negative_pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
+ self.tokenizer_2, self.text_encoder_2, negative_prompt_2, max_sequence_length, device, dtype
+ )
+
+ if negative_pooled_prompt_embeds_2.shape[0] == 1 and batch_size > 1:
+ negative_pooled_prompt_embeds_2 = negative_pooled_prompt_embeds_2.repeat(batch_size, 1)
+
+ if pooled_prompt_embeds is None:
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1)
+
+ if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
+ negative_pooled_prompt_embeds = torch.cat(
+ [negative_pooled_prompt_embeds_1, negative_pooled_prompt_embeds_2], dim=-1
+ )
+
+ if prompt_embeds_t5 is None:
+ prompt_3 = prompt_3 or prompt
+ prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
+
+ if len(prompt_3) > 1 and len(prompt_3) != batch_size:
+ raise ValueError(f"prompt_3 must be of length 1 or {batch_size}")
+
+ prompt_embeds_t5 = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype)
+
+ if prompt_embeds_t5.shape[0] == 1 and batch_size > 1:
+ prompt_embeds_t5 = prompt_embeds_t5.repeat(batch_size, 1, 1)
+
+ if do_classifier_free_guidance and negative_prompt_embeds_t5 is None:
+ negative_prompt_3 = negative_prompt_3 or negative_prompt
+ negative_prompt_3 = [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
+
+ if len(negative_prompt_3) > 1 and len(negative_prompt_3) != batch_size:
+ raise ValueError(f"negative_prompt_3 must be of length 1 or {batch_size}")
+
+ negative_prompt_embeds_t5 = self._get_t5_prompt_embeds(
+ negative_prompt_3, max_sequence_length, device, dtype
+ )
+
+ if negative_prompt_embeds_t5.shape[0] == 1 and batch_size > 1:
+ negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(batch_size, 1, 1)
+
+ if prompt_embeds_llama3 is None:
+ prompt_4 = prompt_4 or prompt
+ prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4
+
+ if len(prompt_4) > 1 and len(prompt_4) != batch_size:
+ raise ValueError(f"prompt_4 must be of length 1 or {batch_size}")
+
+ prompt_embeds_llama3 = self._get_llama3_prompt_embeds(prompt_4, max_sequence_length, device, dtype)
+
+ if prompt_embeds_llama3.shape[0] == 1 and batch_size > 1:
+ prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, batch_size, 1, 1)
+
+ if do_classifier_free_guidance and negative_prompt_embeds_llama3 is None:
+ negative_prompt_4 = negative_prompt_4 or negative_prompt
+ negative_prompt_4 = [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4
+
+ if len(negative_prompt_4) > 1 and len(negative_prompt_4) != batch_size:
+ raise ValueError(f"negative_prompt_4 must be of length 1 or {batch_size}")
+
+ negative_prompt_embeds_llama3 = self._get_llama3_prompt_embeds(
+ negative_prompt_4, max_sequence_length, device, dtype
+ )
+
+ if negative_prompt_embeds_llama3.shape[0] == 1 and batch_size > 1:
+ negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, batch_size, 1, 1)
+
+ # duplicate pooled_prompt_embeds for each generation per prompt
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ # duplicate t5_prompt_embeds for batch_size and num_images_per_prompt
+ bs_embed, seq_len, _ = prompt_embeds_t5.shape
+ if bs_embed == 1 and batch_size > 1:
+ prompt_embeds_t5 = prompt_embeds_t5.repeat(batch_size, 1, 1)
+ elif bs_embed > 1 and bs_embed != batch_size:
+ raise ValueError(f"cannot duplicate prompt_embeds_t5 of batch size {bs_embed}")
+ prompt_embeds_t5 = prompt_embeds_t5.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_t5 = prompt_embeds_t5.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # duplicate llama3_prompt_embeds for batch_size and num_images_per_prompt
+ _, bs_embed, seq_len, dim = prompt_embeds_llama3.shape
+ if bs_embed == 1 and batch_size > 1:
+ prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, batch_size, 1, 1)
+ elif bs_embed > 1 and bs_embed != batch_size:
+ raise ValueError(f"cannot duplicate prompt_embeds_llama3 of batch size {bs_embed}")
+ prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, 1, num_images_per_prompt, 1)
+ prompt_embeds_llama3 = prompt_embeds_llama3.view(-1, batch_size * num_images_per_prompt, seq_len, dim)
+
+ if do_classifier_free_guidance:
+ # duplicate negative_pooled_prompt_embeds for batch_size and num_images_per_prompt
+ bs_embed, seq_len = negative_pooled_prompt_embeds.shape
+ if bs_embed == 1 and batch_size > 1:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(batch_size, 1)
+ elif bs_embed > 1 and bs_embed != batch_size:
+ raise ValueError(f"cannot duplicate negative_pooled_prompt_embeds of batch size {bs_embed}")
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ # duplicate negative_t5_prompt_embeds for batch_size and num_images_per_prompt
+ bs_embed, seq_len, _ = negative_prompt_embeds_t5.shape
+ if bs_embed == 1 and batch_size > 1:
+ negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(batch_size, 1, 1)
+ elif bs_embed > 1 and bs_embed != batch_size:
+ raise ValueError(f"cannot duplicate negative_prompt_embeds_t5 of batch size {bs_embed}")
+ negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds_t5 = negative_prompt_embeds_t5.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # duplicate negative_prompt_embeds_llama3 for batch_size and num_images_per_prompt
+ _, bs_embed, seq_len, dim = negative_prompt_embeds_llama3.shape
+ if bs_embed == 1 and batch_size > 1:
+ negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, batch_size, 1, 1)
+ elif bs_embed > 1 and bs_embed != batch_size:
+ raise ValueError(f"cannot duplicate negative_prompt_embeds_llama3 of batch size {bs_embed}")
+ negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, 1, num_images_per_prompt, 1)
+ negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.view(
+ -1, batch_size * num_images_per_prompt, seq_len, dim
+ )
+
+ return (
+ prompt_embeds_t5,
+ negative_prompt_embeds_t5,
+ prompt_embeds_llama3,
+ negative_prompt_embeds_llama3,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ )
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_tiling()
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ prompt_3,
+ prompt_4,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ negative_prompt_3=None,
+ negative_prompt_4=None,
+ prompt_embeds_t5=None,
+ prompt_embeds_llama3=None,
+ negative_prompt_embeds_t5=None,
+ negative_prompt_embeds_llama3=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and pooled_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `pooled_prompt_embeds`: {pooled_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and pooled_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `pooled_prompt_embeds`: {pooled_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_3 is not None and prompt_embeds_t5 is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds_t5`: {prompt_embeds_t5}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_4 is not None and prompt_embeds_llama3 is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_4`: {prompt_4} and `prompt_embeds_llama3`: {prompt_embeds_llama3}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `pooled_prompt_embeds`. Cannot leave both `prompt` and `pooled_prompt_embeds` undefined."
+ )
+ elif prompt is None and prompt_embeds_t5 is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds_t5`. Cannot leave both `prompt` and `prompt_embeds_t5` undefined."
+ )
+ elif prompt is None and prompt_embeds_llama3 is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds_llama3`. Cannot leave both `prompt` and `prompt_embeds_llama3` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+ elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
+ raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
+ elif prompt_4 is not None and (not isinstance(prompt_4, str) and not isinstance(prompt_4, list)):
+ raise ValueError(f"`prompt_4` has to be of type `str` or `list` but is {type(prompt_4)}")
+
+ if negative_prompt is not None and negative_pooled_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_pooled_prompt_embeds`:"
+ f" {negative_pooled_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_pooled_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_pooled_prompt_embeds`:"
+ f" {negative_pooled_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_3 is not None and negative_prompt_embeds_t5 is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds_t5`:"
+ f" {negative_prompt_embeds_t5}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_4 is not None and negative_prompt_embeds_llama3 is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_4`: {negative_prompt_4} and `negative_prompt_embeds_llama3`:"
+ f" {negative_prompt_embeds_llama3}. Please make sure to only forward one of the two."
+ )
+
+ if pooled_prompt_embeds is not None and negative_pooled_prompt_embeds is not None:
+ if pooled_prompt_embeds.shape != negative_pooled_prompt_embeds.shape:
+ raise ValueError(
+ "`pooled_prompt_embeds` and `negative_pooled_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `pooled_prompt_embeds` {pooled_prompt_embeds.shape} != `negative_pooled_prompt_embeds`"
+ f" {negative_pooled_prompt_embeds.shape}."
+ )
+ if prompt_embeds_t5 is not None and negative_prompt_embeds_t5 is not None:
+ if prompt_embeds_t5.shape != negative_prompt_embeds_t5.shape:
+ raise ValueError(
+ "`prompt_embeds_t5` and `negative_prompt_embeds_t5` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds_t5` {prompt_embeds_t5.shape} != `negative_prompt_embeds_t5`"
+ f" {negative_prompt_embeds_t5.shape}."
+ )
+ if prompt_embeds_llama3 is not None and negative_prompt_embeds_llama3 is not None:
+ if prompt_embeds_llama3.shape != negative_prompt_embeds_llama3.shape:
+ raise ValueError(
+ "`prompt_embeds_llama3` and `negative_prompt_embeds_llama3` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds_llama3` {prompt_embeds_llama3.shape} != `negative_prompt_embeds_llama3`"
+ f" {negative_prompt_embeds_llama3.shape}."
+ )
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ prompt_3: Optional[Union[str, List[str]]] = None,
+ prompt_4: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
+ negative_prompt_4: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds_t5: Optional[torch.FloatTensor] = None,
+ prompt_embeds_llama3: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds_t5: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds_llama3: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 128,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ will be used instead.
+ prompt_3 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
+ will be used instead.
+ prompt_4 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_4` and `text_encoder_4`. If not defined, `prompt` is
+ will be used instead.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 3.5):
+ Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
+ a model to generate images more aligned with `prompt` at the expense of lower image quality.
+
+ Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
+ `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
+ negative_prompt_4 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_4` and
+ `text_encoder_4`. If not defined, `negative_prompt` is used in all the text-encoders.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 128): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.hidream_image.HiDreamImagePipelineOutput`] or `tuple`:
+ [`~pipelines.hidream_image.HiDreamImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated. images.
+ """
+
+ prompt_embeds = kwargs.get("prompt_embeds", None)
+ negative_prompt_embeds = kwargs.get("negative_prompt_embeds", None)
+
+ if prompt_embeds is not None:
+ deprecation_message = "The `prompt_embeds` argument is deprecated. Please use `prompt_embeds_t5` and `prompt_embeds_llama3` instead."
+ deprecate("prompt_embeds", "0.35.0", deprecation_message)
+ prompt_embeds_t5 = prompt_embeds[0]
+ prompt_embeds_llama3 = prompt_embeds[1]
+
+ if negative_prompt_embeds is not None:
+ deprecation_message = "The `negative_prompt_embeds` argument is deprecated. Please use `negative_prompt_embeds_t5` and `negative_prompt_embeds_llama3` instead."
+ deprecate("negative_prompt_embeds", "0.35.0", deprecation_message)
+ negative_prompt_embeds_t5 = negative_prompt_embeds[0]
+ negative_prompt_embeds_llama3 = negative_prompt_embeds[1]
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ division = self.vae_scale_factor * 2
+ S_max = (self.default_sample_size * self.vae_scale_factor) ** 2
+ scale = S_max / (width * height)
+ scale = math.sqrt(scale)
+ width, height = int(width * scale // division * division), int(height * scale // division * division)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ prompt_3,
+ prompt_4,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ negative_prompt_4=negative_prompt_4,
+ prompt_embeds_t5=prompt_embeds_t5,
+ prompt_embeds_llama3=prompt_embeds_llama3,
+ negative_prompt_embeds_t5=negative_prompt_embeds_t5,
+ negative_prompt_embeds_llama3=negative_prompt_embeds_llama3,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ elif pooled_prompt_embeds is not None:
+ batch_size = pooled_prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode prompt
+ lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
+ (
+ prompt_embeds_t5,
+ negative_prompt_embeds_t5,
+ prompt_embeds_llama3,
+ negative_prompt_embeds_llama3,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_3=prompt_3,
+ prompt_4=prompt_4,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ negative_prompt_4=negative_prompt_4,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ prompt_embeds_t5=prompt_embeds_t5,
+ prompt_embeds_llama3=prompt_embeds_llama3,
+ negative_prompt_embeds_t5=negative_prompt_embeds_t5,
+ negative_prompt_embeds_llama3=negative_prompt_embeds_llama3,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds_t5 = torch.cat([negative_prompt_embeds_t5, prompt_embeds_t5], dim=0)
+ prompt_embeds_llama3 = torch.cat([negative_prompt_embeds_llama3, prompt_embeds_llama3], dim=1)
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ pooled_prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare timesteps
+ mu = calculate_shift(self.transformer.max_seq)
+ scheduler_kwargs = {"mu": mu}
+ if isinstance(self.scheduler, UniPCMultistepScheduler):
+ self.scheduler.set_timesteps(num_inference_steps, device=device) # , shift=math.exp(mu))
+ timesteps = self.scheduler.timesteps
+ else:
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ **scheduler_kwargs,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timesteps=timestep,
+ encoder_hidden_states_t5=prompt_embeds_t5,
+ encoder_hidden_states_llama3=prompt_embeds_llama3,
+ pooled_embeds=pooled_prompt_embeds,
+ return_dict=False,
+ )[0]
+ noise_pred = -noise_pred
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds_t5 = callback_outputs.pop("prompt_embeds_t5", prompt_embeds_t5)
+ prompt_embeds_llama3 = callback_outputs.pop("prompt_embeds_llama3", prompt_embeds_llama3)
+ pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return HiDreamImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/hidream_image/pipeline_output.py b/src/diffusers/pipelines/hidream_image/pipeline_output.py
new file mode 100644
index 000000000000..66f0f1260d18
--- /dev/null
+++ b/src/diffusers/pipelines/hidream_image/pipeline_output.py
@@ -0,0 +1,35 @@
+# Copyright 2025 HiDream-ai Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class HiDreamImagePipelineOutput(BaseOutput):
+ """
+ Output class for HiDreamImage pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/pipelines/hunyuan_image/__init__.py b/src/diffusers/pipelines/hunyuan_image/__init__.py
new file mode 100644
index 000000000000..7da72fa12b2c
--- /dev/null
+++ b/src/diffusers/pipelines/hunyuan_image/__init__.py
@@ -0,0 +1,50 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_hunyuanimage"] = ["HunyuanImagePipeline"]
+ _import_structure["pipeline_hunyuanimage_refiner"] = ["HunyuanImageRefinerPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_hunyuanimage import HunyuanImagePipeline
+ from .pipeline_hunyuanimage_refiner import HunyuanImageRefinerPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage.py b/src/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage.py
new file mode 100644
index 000000000000..658935ccd886
--- /dev/null
+++ b/src/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage.py
@@ -0,0 +1,866 @@
+# Copyright 2025 Hunyuan-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import re
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import ByT5Tokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, T5EncoderModel
+
+from ...guiders import AdaptiveProjectedMixGuidance
+from ...image_processor import VaeImageProcessor
+from ...models import AutoencoderKLHunyuanImage, HunyuanImageTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import HunyuanImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import HunyuanImagePipeline
+
+ >>> pipe = HunyuanImagePipeline.from_pretrained(
+ ... "hunyuanvideo-community/HunyuanImage-2.1-Diffusers", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+ >>> prompt = "A cat holding a sign that says hello world"
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(prompt, negative_prompt="", num_inference_steps=50).images[0]
+ >>> image.save("hunyuanimage.png")
+ ```
+"""
+
+
+def extract_glyph_text(prompt: str):
+ """
+ Extract text enclosed in quotes for glyph rendering.
+
+ Finds text in single quotes, double quotes, and Chinese quotes, then formats it for byT5 processing.
+
+ Args:
+ prompt: Input text prompt
+
+ Returns:
+ Formatted glyph text string or None if no quoted text found
+ """
+ text_prompt_texts = []
+ pattern_quote_single = r"\'(.*?)\'"
+ pattern_quote_double = r"\"(.*?)\""
+ pattern_quote_chinese_single = r"‘(.*?)’"
+ pattern_quote_chinese_double = r"“(.*?)”"
+
+ matches_quote_single = re.findall(pattern_quote_single, prompt)
+ matches_quote_double = re.findall(pattern_quote_double, prompt)
+ matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, prompt)
+ matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, prompt)
+
+ text_prompt_texts.extend(matches_quote_single)
+ text_prompt_texts.extend(matches_quote_double)
+ text_prompt_texts.extend(matches_quote_chinese_single)
+ text_prompt_texts.extend(matches_quote_chinese_double)
+
+ if text_prompt_texts:
+ glyph_text_formatted = ". ".join([f'Text "{text}"' for text in text_prompt_texts]) + ". "
+ else:
+ glyph_text_formatted = None
+
+ return glyph_text_formatted
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class HunyuanImagePipeline(DiffusionPipeline):
+ r"""
+ The HunyuanImage pipeline for text-to-image generation.
+
+ Args:
+ transformer ([`HunyuanImageTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLHunyuanImage`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
+ tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer].
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel)
+ variant.
+ tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer]
+ guider ([`AdaptiveProjectedMixGuidance`]):
+ [AdaptiveProjectedMixGuidance]to be used to guide the image generation.
+ ocr_guider ([`AdaptiveProjectedMixGuidance`], *optional*):
+ [AdaptiveProjectedMixGuidance] to be used to guide the image generation when text rendering is needed.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+ _optional_components = ["ocr_guider", "guider"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLHunyuanImage,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2Tokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: ByT5Tokenizer,
+ transformer: HunyuanImageTransformer2DModel,
+ guider: Optional[AdaptiveProjectedMixGuidance] = None,
+ ocr_guider: Optional[AdaptiveProjectedMixGuidance] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ guider=guider,
+ ocr_guider=ocr_guider,
+ )
+
+ self.vae_scale_factor = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 32
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.tokenizer_max_length = 1000
+ self.tokenizer_2_max_length = 128
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
+ self.prompt_template_encode_start_idx = 34
+ self.default_sample_size = 64
+
+ def _get_qwen_prompt_embeds(
+ self,
+ tokenizer: Qwen2Tokenizer,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ prompt: Union[str, List[str]] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ tokenizer_max_length: int = 1000,
+ template: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>",
+ drop_idx: int = 34,
+ hidden_state_skip_layer: int = 2,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ txt = [template.format(e) for e in prompt]
+ txt_tokens = tokenizer(
+ txt, max_length=tokenizer_max_length + drop_idx, padding="max_length", truncation=True, return_tensors="pt"
+ ).to(device)
+
+ encoder_hidden_states = text_encoder(
+ input_ids=txt_tokens.input_ids,
+ attention_mask=txt_tokens.attention_mask,
+ output_hidden_states=True,
+ )
+ prompt_embeds = encoder_hidden_states.hidden_states[-(hidden_state_skip_layer + 1)]
+
+ prompt_embeds = prompt_embeds[:, drop_idx:]
+ encoder_attention_mask = txt_tokens.attention_mask[:, drop_idx:]
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ encoder_attention_mask = encoder_attention_mask.to(device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+ def _get_byt5_prompt_embeds(
+ self,
+ tokenizer: ByT5Tokenizer,
+ text_encoder: T5EncoderModel,
+ prompt: str,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ tokenizer_max_length: int = 128,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or text_encoder.dtype
+
+ if isinstance(prompt, list):
+ raise ValueError("byt5 prompt should be a string")
+ elif prompt is None:
+ raise ValueError("byt5 prompt should not be None")
+
+ txt_tokens = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer_max_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ ).to(device)
+
+ prompt_embeds = text_encoder(
+ input_ids=txt_tokens.input_ids,
+ attention_mask=txt_tokens.attention_mask.float(),
+ )[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ encoder_attention_mask = txt_tokens.attention_mask.to(device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ batch_size: int = 1,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ prompt_embeds_2: Optional[torch.Tensor] = None,
+ prompt_embeds_mask_2: Optional[torch.Tensor] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ batch_size (`int`):
+ batch size of prompts, defaults to 1
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input
+ argument.
+ prompt_embeds_mask (`torch.Tensor`, *optional*):
+ Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument.
+ prompt_embeds_2 (`torch.Tensor`, *optional*):
+ Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input
+ argument using self.tokenizer_2 and self.text_encoder_2.
+ prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
+ Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input
+ argument using self.tokenizer_2 and self.text_encoder_2.
+ """
+ device = device or self._execution_device
+
+ if prompt is None:
+ prompt = [""] * batch_size
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
+ tokenizer=self.tokenizer,
+ text_encoder=self.text_encoder,
+ prompt=prompt,
+ device=device,
+ tokenizer_max_length=self.tokenizer_max_length,
+ template=self.prompt_template_encode,
+ drop_idx=self.prompt_template_encode_start_idx,
+ )
+
+ if prompt_embeds_2 is None:
+ prompt_embeds_2_list = []
+ prompt_embeds_mask_2_list = []
+
+ glyph_texts = [extract_glyph_text(p) for p in prompt]
+ for glyph_text in glyph_texts:
+ if glyph_text is None:
+ glyph_text_embeds = torch.zeros(
+ (1, self.tokenizer_2_max_length, self.text_encoder_2.config.d_model), device=device
+ )
+ glyph_text_embeds_mask = torch.zeros(
+ (1, self.tokenizer_2_max_length), device=device, dtype=torch.int64
+ )
+ else:
+ glyph_text_embeds, glyph_text_embeds_mask = self._get_byt5_prompt_embeds(
+ tokenizer=self.tokenizer_2,
+ text_encoder=self.text_encoder_2,
+ prompt=glyph_text,
+ device=device,
+ tokenizer_max_length=self.tokenizer_2_max_length,
+ )
+
+ prompt_embeds_2_list.append(glyph_text_embeds)
+ prompt_embeds_mask_2_list.append(glyph_text_embeds_mask)
+
+ prompt_embeds_2 = torch.cat(prompt_embeds_2_list, dim=0)
+ prompt_embeds_mask_2 = torch.cat(prompt_embeds_mask_2_list, dim=0)
+
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ _, seq_len_2, _ = prompt_embeds_2.shape
+ prompt_embeds_2 = prompt_embeds_2.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_2 = prompt_embeds_2.view(batch_size * num_images_per_prompt, seq_len_2, -1)
+ prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_images_per_prompt, seq_len_2)
+
+ return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_embeds_mask=None,
+ negative_prompt_embeds_mask=None,
+ prompt_embeds_2=None,
+ prompt_embeds_mask_2=None,
+ negative_prompt_embeds_2=None,
+ negative_prompt_embeds_mask_2=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_embeds_mask is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if prompt is None and prompt_embeds_2 is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
+ )
+
+ if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None:
+ raise ValueError(
+ "If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`."
+ )
+ if negative_prompt_embeds_2 is not None and negative_prompt_embeds_mask_2 is None:
+ raise ValueError(
+ "If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`."
+ )
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ height = int(height) // self.vae_scale_factor
+ width = int(width) // self.vae_scale_factor
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ return latents
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ distilled_guidance_scale: Optional[float] = 3.25,
+ sigmas: Optional[List[float]] = None,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+ prompt_embeds_2: Optional[torch.Tensor] = None,
+ prompt_embeds_mask_2: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_2: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask_2: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined and negative_prompt_embeds is
+ not provided, will use an empty negative prompt. Ignored when not using guidance. ).
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ distilled_guidance_scale (`float`, *optional*, defaults to None):
+ A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
+ where the guidance scale is applied during inference through noise prediction rescaling, guidance
+ distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
+ is enabled by setting `distilled_guidance_scale > 1`. Higher guidance scale encourages to generate
+ images that are closely linked to the text `prompt`, usually at the expense of lower image quality. For
+ guidance distilled models, this parameter is required. For non-distilled models, this parameter will be
+ ignored.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_embeds_mask (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings mask. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, text embeddings mask will be generated from `prompt` input argument.
+ prompt_embeds_2 (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings for ocr. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, text embeddings for ocr will be generated from `prompt` input argument.
+ prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings mask for ocr. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, text embeddings mask for ocr will be generated from `prompt` input
+ argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ negative_prompt_embeds_mask (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings mask. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative text embeddings mask will be generated from `negative_prompt`
+ input argument.
+ negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings for ocr. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative text embeddings for ocr will be generated from `negative_prompt`
+ input argument.
+ negative_prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings mask for ocr. Can be used to easily tweak text inputs, *e.g.*
+ prompt weighting. If not provided, negative text embeddings mask for ocr will be generated from
+ `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] or `tuple`:
+ [`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ prompt_embeds_2=prompt_embeds_2,
+ prompt_embeds_mask_2=prompt_embeds_mask_2,
+ negative_prompt_embeds_2=negative_prompt_embeds_2,
+ negative_prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
+ )
+
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. prepare prompt embeds
+
+ prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.encode_prompt(
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ device=device,
+ batch_size=batch_size,
+ num_images_per_prompt=num_images_per_prompt,
+ prompt_embeds_2=prompt_embeds_2,
+ prompt_embeds_mask_2=prompt_embeds_mask_2,
+ )
+
+ prompt_embeds = prompt_embeds.to(self.transformer.dtype)
+ prompt_embeds_2 = prompt_embeds_2.to(self.transformer.dtype)
+
+ # select guider
+ if not torch.all(prompt_embeds_2 == 0) and self.ocr_guider is not None:
+ # prompt contains ocr and pipeline has a guider for ocr
+ guider = self.ocr_guider
+ elif self.guider is not None:
+ guider = self.guider
+ # distilled model does not use guidance method, use default guider with enabled=False
+ else:
+ guider = AdaptiveProjectedMixGuidance(enabled=False)
+
+ if guider._enabled and guider.num_conditions > 1:
+ (
+ negative_prompt_embeds,
+ negative_prompt_embeds_mask,
+ negative_prompt_embeds_2,
+ negative_prompt_embeds_mask_2,
+ ) = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ device=device,
+ batch_size=batch_size,
+ num_images_per_prompt=num_images_per_prompt,
+ prompt_embeds_2=negative_prompt_embeds_2,
+ prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
+ )
+
+ negative_prompt_embeds = negative_prompt_embeds.to(self.transformer.dtype)
+ negative_prompt_embeds_2 = negative_prompt_embeds_2.to(self.transformer.dtype)
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size=batch_size * num_images_per_prompt,
+ num_channels_latents=num_channels_latents,
+ height=height,
+ width=width,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ generator=generator,
+ latents=latents,
+ )
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance (for guidance-distilled model)
+ if self.transformer.config.guidance_embeds and distilled_guidance_scale is None:
+ raise ValueError("`distilled_guidance_scale` is required for guidance-distilled model.")
+
+ if self.transformer.config.guidance_embeds:
+ guidance = (
+ torch.tensor(
+ [distilled_guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device
+ )
+ * 1000.0
+ )
+
+ else:
+ guidance = None
+
+ if self.attention_kwargs is None:
+ self._attention_kwargs = {}
+
+ # 6. Denoising loop
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ if self.transformer.config.use_meanflow:
+ if i == len(timesteps) - 1:
+ timestep_r = torch.tensor([0.0], device=device)
+ else:
+ timestep_r = timesteps[i + 1]
+ timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype)
+ else:
+ timestep_r = None
+
+ # Step 1: Collect model inputs needed for the guidance method
+ # conditional inputs should always be first element in the tuple
+ guider_inputs = {
+ "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
+ "encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask),
+ "encoder_hidden_states_2": (prompt_embeds_2, negative_prompt_embeds_2),
+ "encoder_attention_mask_2": (prompt_embeds_mask_2, negative_prompt_embeds_mask_2),
+ }
+
+ # Step 2: Update guider's internal state for this denoising step
+ guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
+
+ # Step 3: Prepare batched model inputs based on the guidance method
+ # The guider splits model inputs into separate batches for conditional/unconditional predictions.
+ # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
+ # you will get a guider_state with two batches:
+ # guider_state = [
+ # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
+ # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
+ # ]
+ # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
+ guider_state = guider.prepare_inputs(guider_inputs)
+ # Step 4: Run the denoiser for each batch
+ # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
+ # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
+ for guider_state_batch in guider_state:
+ guider.prepare_models(self.transformer)
+
+ # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
+ cond_kwargs = {
+ input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
+ }
+
+ # e.g. "pred_cond"/"pred_uncond"
+ context_name = getattr(guider_state_batch, guider._identifier_key)
+ with self.transformer.cache_context(context_name):
+ # Run denoiser and store noise prediction in this batch
+ guider_state_batch.noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep,
+ timestep_r=timestep_r,
+ guidance=guidance,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ **cond_kwargs,
+ )[0]
+
+ # Cleanup model (e.g., remove hooks)
+ guider.cleanup_models(self.transformer)
+
+ # Step 5: Combine predictions using the guidance method
+ # The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm.
+ # Continuing the CFG example, the guider receives:
+ # guider_state = [
+ # {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0
+ # {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1
+ # ]
+ # And extracts predictions using the __guidance_identifier__:
+ # pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond
+ # pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond
+ # Then applies CFG formula:
+ # noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
+ # Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
+ noise_pred = guider(guider_state)[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return HunyuanImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage_refiner.py b/src/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage_refiner.py
new file mode 100644
index 000000000000..f38f53d9a562
--- /dev/null
+++ b/src/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage_refiner.py
@@ -0,0 +1,752 @@
+# Copyright 2025 Hunyuan-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
+
+from ...guiders import AdaptiveProjectedMixGuidance
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...models import AutoencoderKLHunyuanImageRefiner, HunyuanImageTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import HunyuanImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import HunyuanImageRefinerPipeline
+
+ >>> pipe = HunyuanImageRefinerPipeline.from_pretrained(
+ ... "hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+ >>> prompt = "A cat holding a sign that says hello world"
+ >>> image = load_image("path/to/image.png")
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(prompt, image=image, num_inference_steps=4).images[0]
+ >>> image.save("hunyuanimage.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class HunyuanImageRefinerPipeline(DiffusionPipeline):
+ r"""
+ The HunyuanImage pipeline for text-to-image generation.
+
+ Args:
+ transformer ([`HunyuanImageTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLHunyuanImageRefiner`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
+ tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer].
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+ _optional_components = ["guider"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLHunyuanImageRefiner,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2Tokenizer,
+ transformer: HunyuanImageTransformer2DModel,
+ guider: Optional[AdaptiveProjectedMixGuidance] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ guider=guider,
+ )
+
+ self.vae_scale_factor = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 16
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.tokenizer_max_length = 256
+ self.prompt_template_encode = "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+ self.prompt_template_encode_start_idx = 36
+ self.default_sample_size = 64
+ self.latent_channels = self.transformer.config.in_channels // 2 if getattr(self, "transformer", None) else 64
+
+ # Copied from diffusers.pipelines.hunyuan_image.pipeline_hunyuanimage.HunyuanImagePipeline._get_qwen_prompt_embeds
+ def _get_qwen_prompt_embeds(
+ self,
+ tokenizer: Qwen2Tokenizer,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ prompt: Union[str, List[str]] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ tokenizer_max_length: int = 1000,
+ template: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>",
+ drop_idx: int = 34,
+ hidden_state_skip_layer: int = 2,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ txt = [template.format(e) for e in prompt]
+ txt_tokens = tokenizer(
+ txt, max_length=tokenizer_max_length + drop_idx, padding="max_length", truncation=True, return_tensors="pt"
+ ).to(device)
+
+ encoder_hidden_states = text_encoder(
+ input_ids=txt_tokens.input_ids,
+ attention_mask=txt_tokens.attention_mask,
+ output_hidden_states=True,
+ )
+ prompt_embeds = encoder_hidden_states.hidden_states[-(hidden_state_skip_layer + 1)]
+
+ prompt_embeds = prompt_embeds[:, drop_idx:]
+ encoder_attention_mask = txt_tokens.attention_mask[:, drop_idx:]
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ encoder_attention_mask = encoder_attention_mask.to(device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+ def encode_prompt(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ device: Optional[torch.device] = None,
+ batch_size: int = 1,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ batch_size (`int`):
+ batch size of prompts, defaults to 1
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input
+ argument.
+ prompt_embeds_mask (`torch.Tensor`, *optional*):
+ Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument.
+ prompt_embeds_2 (`torch.Tensor`, *optional*):
+ Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input
+ argument using self.tokenizer_2 and self.text_encoder_2.
+ prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
+ Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input
+ argument using self.tokenizer_2 and self.text_encoder_2.
+ """
+ device = device or self._execution_device
+
+ if prompt is None:
+ prompt = [""] * batch_size
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
+ tokenizer=self.tokenizer,
+ text_encoder=self.text_encoder,
+ prompt=prompt,
+ device=device,
+ tokenizer_max_length=self.tokenizer_max_length,
+ template=self.prompt_template_encode,
+ drop_idx=self.prompt_template_encode_start_idx,
+ )
+
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ return prompt_embeds, prompt_embeds_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_embeds_mask=None,
+ negative_prompt_embeds_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_embeds_mask is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ def prepare_latents(
+ self,
+ image_latents,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ strength=0.25,
+ ):
+ height = int(height) // self.vae_scale_factor
+ width = int(width) // self.vae_scale_factor
+
+ shape = (batch_size, num_channels_latents, 1, height, width)
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ cond_latents = strength * noise + (1 - strength) * image_latents
+
+ return latents, cond_latents
+
+ @staticmethod
+ def _reorder_image_tokens(image_latents):
+ image_latents = torch.cat((image_latents[:, :, :1], image_latents), dim=2)
+ batch_size, num_latent_channels, num_latent_frames, latent_height, latent_width = image_latents.shape
+ image_latents = image_latents.permute(0, 2, 1, 3, 4)
+ image_latents = image_latents.reshape(
+ batch_size, num_latent_frames // 2, num_latent_channels * 2, latent_height, latent_width
+ )
+ image_latents = image_latents.permute(0, 2, 1, 3, 4).contiguous()
+
+ return image_latents
+
+ @staticmethod
+ def _restore_image_tokens_order(latents):
+ """Restore image tokens order by splitting channels and removing first frame slice."""
+ batch_size, num_latent_channels, num_latent_frames, latent_height, latent_width = latents.shape
+
+ latents = latents.permute(0, 2, 1, 3, 4) # B, F, C, H, W
+ latents = latents.reshape(
+ batch_size, num_latent_frames * 2, num_latent_channels // 2, latent_height, latent_width
+ ) # B, F*2, C//2, H, W
+
+ latents = latents.permute(0, 2, 1, 3, 4) # B, C//2, F*2, H, W
+ # Remove first frame slice
+ latents = latents[:, :, 1:]
+
+ return latents
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="sample")
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="sample")
+ image_latents = self._reorder_image_tokens(image_latents)
+
+ image_latents = image_latents * self.vae.config.scaling_factor
+
+ return image_latents
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ distilled_guidance_scale: Optional[float] = 3.25,
+ image: Optional[PipelineImageInput] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 4,
+ sigmas: Optional[List[float]] = None,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, will use an empty negative
+ prompt. Ignored when not using guidance.
+ distilled_guidance_scale (`float`, *optional*, defaults to None):
+ A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
+ where the guidance scale is applied during inference through noise prediction rescaling, guidance
+ distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
+ is enabled by setting `distilled_guidance_scale > 1`. Higher guidance scale encourages to generate
+ images that are closely linked to the text `prompt`, usually at the expense of lower image quality. For
+ guidance distilled models, this parameter is required. For non-distilled models, this parameter will be
+ ignored.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] or `tuple`:
+ [`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ )
+
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. process image
+ if image is not None and isinstance(image, torch.Tensor) and image.shape[1] == self.latent_channels:
+ image_latents = image
+ else:
+ image = self.image_processor.preprocess(image, height, width)
+ image = image.unsqueeze(2).to(device, dtype=self.vae.dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+
+ # 3.prepare prompt embeds
+
+ if self.guider is not None:
+ guider = self.guider
+ else:
+ # distilled model does not use guidance method, use default guider with enabled=False
+ guider = AdaptiveProjectedMixGuidance(enabled=False)
+
+ requires_unconditional_embeds = guider._enabled and guider.num_conditions > 1
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ device=device,
+ batch_size=batch_size,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+
+ prompt_embeds = prompt_embeds.to(self.transformer.dtype)
+
+ if requires_unconditional_embeds:
+ (
+ negative_prompt_embeds,
+ negative_prompt_embeds_mask,
+ ) = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ device=device,
+ batch_size=batch_size,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+
+ negative_prompt_embeds = negative_prompt_embeds.to(self.transformer.dtype)
+
+ # 4. Prepare latent variables
+ latents, cond_latents = self.prepare_latents(
+ image_latents=image_latents,
+ batch_size=batch_size * num_images_per_prompt,
+ num_channels_latents=self.latent_channels,
+ height=height,
+ width=width,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ generator=generator,
+ latents=latents,
+ )
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance (this pipeline only supports guidance-distilled models)
+ if distilled_guidance_scale is None:
+ raise ValueError("`distilled_guidance_scale` is required for guidance-distilled model.")
+ guidance = (
+ torch.tensor([distilled_guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device)
+ * 1000.0
+ )
+
+ if self.attention_kwargs is None:
+ self._attention_kwargs = {}
+
+ # 6. Denoising loop
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ latent_model_input = torch.cat([latents, cond_latents], dim=1).to(self.transformer.dtype)
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ # Step 1: Collect model inputs needed for the guidance method
+ # conditional inputs should always be first element in the tuple
+ guider_inputs = {
+ "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
+ "encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask),
+ }
+
+ # Step 2: Update guider's internal state for this denoising step
+ guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
+
+ # Step 3: Prepare batched model inputs based on the guidance method
+ # The guider splits model inputs into separate batches for conditional/unconditional predictions.
+ # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
+ # you will get a guider_state with two batches:
+ # guider_state = [
+ # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
+ # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
+ # ]
+ # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
+ guider_state = guider.prepare_inputs(guider_inputs)
+
+ # Step 4: Run the denoiser for each batch
+ # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
+ # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
+ for guider_state_batch in guider_state:
+ guider.prepare_models(self.transformer)
+
+ # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
+ cond_kwargs = {
+ input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
+ }
+
+ # e.g. "pred_cond"/"pred_uncond"
+ context_name = getattr(guider_state_batch, guider._identifier_key)
+ with self.transformer.cache_context(context_name):
+ # Run denoiser and store noise prediction in this batch
+ guider_state_batch.noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ guidance=guidance,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ **cond_kwargs,
+ )[0]
+
+ # Cleanup model (e.g., remove hooks)
+ guider.cleanup_models(self.transformer)
+
+ # Step 5: Combine predictions using the guidance method
+ # The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm.
+ # Continuing the CFG example, the guider receives:
+ # guider_state = [
+ # {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0
+ # {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1
+ # ]
+ # And extracts predictions using the __guidance_identifier__:
+ # pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond
+ # pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond
+ # Then applies CFG formula:
+ # noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
+ # Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
+ noise_pred = guider(guider_state)[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
+ latents = self._restore_image_tokens_order(latents)
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image.squeeze(2), output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return HunyuanImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/hunyuan_image/pipeline_output.py b/src/diffusers/pipelines/hunyuan_image/pipeline_output.py
new file mode 100644
index 000000000000..1e76892a0e81
--- /dev/null
+++ b/src/diffusers/pipelines/hunyuan_image/pipeline_output.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class HunyuanImagePipelineOutput(BaseOutput):
+ """
+ Output class for HunyuanImage pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/pipelines/hunyuan_video/__init__.py b/src/diffusers/pipelines/hunyuan_video/__init__.py
index d9cacad24f17..d42d38fac979 100644
--- a/src/diffusers/pipelines/hunyuan_video/__init__.py
+++ b/src/diffusers/pipelines/hunyuan_video/__init__.py
@@ -24,6 +24,7 @@
else:
_import_structure["pipeline_hunyuan_skyreels_image2video"] = ["HunyuanSkyreelsImageToVideoPipeline"]
_import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"]
+ _import_structure["pipeline_hunyuan_video_framepack"] = ["HunyuanVideoFramepackPipeline"]
_import_structure["pipeline_hunyuan_video_image2video"] = ["HunyuanVideoImageToVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -36,6 +37,7 @@
else:
from .pipeline_hunyuan_skyreels_image2video import HunyuanSkyreelsImageToVideoPipeline
from .pipeline_hunyuan_video import HunyuanVideoPipeline
+ from .pipeline_hunyuan_video_framepack import HunyuanVideoFramepackPipeline
from .pipeline_hunyuan_video_image2video import HunyuanVideoImageToVideoPipeline
else:
diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py
index 297d2a9c9396..b50a6ae3ed27 100644
--- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@
from ...loaders import HunyuanVideoLoraLoaderMixin
from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
@@ -463,6 +463,12 @@ def enable_vae_slicing(self):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -470,6 +476,12 @@ def disable_vae_slicing(self):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -478,6 +490,12 @@ def enable_vae_tiling(self):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -485,6 +503,12 @@ def disable_vae_tiling(self):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
@property
@@ -575,13 +599,13 @@ def __call__(
true_cfg_scale (`float`, *optional*, defaults to 1.0):
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
guidance_scale (`float`, defaults to `6.0`):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality. Note that the only available HunyuanVideo model is
- CFG-distilled, which means that traditional guidance between unconditional and conditional latent is
- not applied.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality. Note that the only available
+ HunyuanVideo model is CFG-distilled, which means that traditional guidance between unconditional and
+ conditional latent is not applied.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
index 3cb91b3782f2..5c8e295eaf4c 100644
--- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@
from ...loaders import HunyuanVideoLoraLoaderMixin
from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
@@ -420,6 +420,12 @@ def enable_vae_slicing(self):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -427,6 +433,12 @@ def disable_vae_slicing(self):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -435,6 +447,12 @@ def enable_vae_tiling(self):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -442,6 +460,12 @@ def disable_vae_tiling(self):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
@property
@@ -529,15 +553,14 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
true_cfg_scale (`float`, *optional*, defaults to 1.0):
- When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
+ True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and
+ `negative_prompt` is provided.
guidance_scale (`float`, defaults to `6.0`):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality. Note that the only available HunyuanVideo model is
- CFG-distilled, which means that traditional guidance between unconditional and conditional latent is
- not applied.
+ Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
+ a model to generate images more aligned with `prompt` at the expense of lower image quality.
+
+ Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -693,28 +716,30 @@ def __call__(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- timestep=timestep,
- encoder_hidden_states=prompt_embeds,
- encoder_attention_mask=prompt_attention_mask,
- pooled_projections=pooled_prompt_embeds,
- guidance=guidance,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
-
- if do_true_cfg:
- neg_noise_pred = self.transformer(
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
- encoder_hidden_states=negative_prompt_embeds,
- encoder_attention_mask=negative_prompt_attention_mask,
- pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ pooled_projections=pooled_prompt_embeds,
guidance=guidance,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
+
+ if do_true_cfg:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_attention_mask=negative_prompt_attention_mask,
+ pooled_projections=negative_pooled_prompt_embeds,
+ guidance=guidance,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py
new file mode 100644
index 000000000000..8006514f47ea
--- /dev/null
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py
@@ -0,0 +1,1138 @@
+# Copyright 2025 The Framepack Team, The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+from enum import Enum
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from transformers import (
+ CLIPTextModel,
+ CLIPTokenizer,
+ LlamaModel,
+ LlamaTokenizerFast,
+ SiglipImageProcessor,
+ SiglipVisionModel,
+)
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import HunyuanVideoLoraLoaderMixin
+from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoFramepackTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import HunyuanVideoFramepackPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# TODO(yiyi): We can pack the checkpoints nicely with modular loader
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ##### Image-to-Video
+
+ ```python
+ >>> import torch
+ >>> from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel
+ >>> from diffusers.utils import export_to_video, load_image
+ >>> from transformers import SiglipImageProcessor, SiglipVisionModel
+
+ >>> transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained(
+ ... "lllyasviel/FramePackI2V_HY", torch_dtype=torch.bfloat16
+ ... )
+ >>> feature_extractor = SiglipImageProcessor.from_pretrained(
+ ... "lllyasviel/flux_redux_bfl", subfolder="feature_extractor"
+ ... )
+ >>> image_encoder = SiglipVisionModel.from_pretrained(
+ ... "lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16
+ ... )
+ >>> pipe = HunyuanVideoFramepackPipeline.from_pretrained(
+ ... "hunyuanvideo-community/HunyuanVideo",
+ ... transformer=transformer,
+ ... feature_extractor=feature_extractor,
+ ... image_encoder=image_encoder,
+ ... torch_dtype=torch.float16,
+ ... )
+ >>> pipe.vae.enable_tiling()
+ >>> pipe.to("cuda")
+
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
+ ... )
+ >>> output = pipe(
+ ... image=image,
+ ... prompt="A penguin dancing in the snow",
+ ... height=832,
+ ... width=480,
+ ... num_frames=91,
+ ... num_inference_steps=30,
+ ... guidance_scale=9.0,
+ ... generator=torch.Generator().manual_seed(0),
+ ... sampling_type="inverted_anti_drifting",
+ ... ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=30)
+ ```
+
+ ##### First and Last Image-to-Video
+
+ ```python
+ >>> import torch
+ >>> from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel
+ >>> from diffusers.utils import export_to_video, load_image
+ >>> from transformers import SiglipImageProcessor, SiglipVisionModel
+
+ >>> transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained(
+ ... "lllyasviel/FramePackI2V_HY", torch_dtype=torch.bfloat16
+ ... )
+ >>> feature_extractor = SiglipImageProcessor.from_pretrained(
+ ... "lllyasviel/flux_redux_bfl", subfolder="feature_extractor"
+ ... )
+ >>> image_encoder = SiglipVisionModel.from_pretrained(
+ ... "lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16
+ ... )
+ >>> pipe = HunyuanVideoFramepackPipeline.from_pretrained(
+ ... "hunyuanvideo-community/HunyuanVideo",
+ ... transformer=transformer,
+ ... feature_extractor=feature_extractor,
+ ... image_encoder=image_encoder,
+ ... torch_dtype=torch.float16,
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
+ >>> first_image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png"
+ ... )
+ >>> last_image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png"
+ ... )
+ >>> output = pipe(
+ ... image=first_image,
+ ... last_image=last_image,
+ ... prompt=prompt,
+ ... height=512,
+ ... width=512,
+ ... num_frames=91,
+ ... num_inference_steps=30,
+ ... guidance_scale=9.0,
+ ... generator=torch.Generator().manual_seed(0),
+ ... sampling_type="inverted_anti_drifting",
+ ... ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=30)
+ ```
+"""
+
+
+DEFAULT_PROMPT_TEMPLATE = {
+ "template": (
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
+ "1. The main content and theme of the video."
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
+ "4. background environment, light, style and atmosphere."
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+ ),
+ "crop_start": 95,
+}
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class FramepackSamplingType(str, Enum):
+ VANILLA = "vanilla"
+ INVERTED_ANTI_DRIFTING = "inverted_anti_drifting"
+
+
+class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
+ r"""
+ Pipeline for text-to-video generation using HunyuanVideo.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ text_encoder ([`LlamaModel`]):
+ [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
+ tokenizer (`LlamaTokenizer`):
+ Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
+ transformer ([`HunyuanVideoTransformer3DModel`]):
+ Conditional Transformer to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLHunyuanVideo`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ text_encoder_2 ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer_2 (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ text_encoder: LlamaModel,
+ tokenizer: LlamaTokenizerFast,
+ transformer: HunyuanVideoFramepackTransformer3DModel,
+ vae: AutoencoderKLHunyuanVideo,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ text_encoder_2: CLIPTextModel,
+ tokenizer_2: CLIPTokenizer,
+ image_encoder: SiglipVisionModel,
+ feature_extractor: SiglipImageProcessor,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
+
+ self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline._get_llama_prompt_embeds
+ def _get_llama_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_template: Dict[str, Any],
+ num_videos_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ max_sequence_length: int = 256,
+ num_hidden_layers_to_skip: int = 2,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ prompt = [prompt_template["template"].format(p) for p in prompt]
+
+ crop_start = prompt_template.get("crop_start", None)
+ if crop_start is None:
+ prompt_template_input = self.tokenizer(
+ prompt_template["template"],
+ padding="max_length",
+ return_tensors="pt",
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_attention_mask=False,
+ )
+ crop_start = prompt_template_input["input_ids"].shape[-1]
+ # Remove <|eot_id|> token and placeholder {}
+ crop_start -= 2
+
+ max_sequence_length += crop_start
+ text_inputs = self.tokenizer(
+ prompt,
+ max_length=max_sequence_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt",
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_attention_mask=True,
+ )
+ text_input_ids = text_inputs.input_ids.to(device=device)
+ prompt_attention_mask = text_inputs.attention_mask.to(device=device)
+
+ prompt_embeds = self.text_encoder(
+ input_ids=text_input_ids,
+ attention_mask=prompt_attention_mask,
+ output_hidden_states=True,
+ ).hidden_states[-(num_hidden_layers_to_skip + 1)]
+ prompt_embeds = prompt_embeds.to(dtype=dtype)
+
+ if crop_start is not None and crop_start > 0:
+ prompt_embeds = prompt_embeds[:, crop_start:]
+ prompt_attention_mask = prompt_attention_mask[:, crop_start:]
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt)
+ prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len)
+
+ return prompt_embeds, prompt_attention_mask
+
+ # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_videos_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ max_sequence_length: int = 77,
+ ) -> torch.Tensor:
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder_2.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]] = None,
+ prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ max_sequence_length: int = 256,
+ ):
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
+ prompt,
+ prompt_template,
+ num_videos_per_prompt,
+ device=device,
+ dtype=dtype,
+ max_sequence_length=max_sequence_length,
+ )
+
+ if pooled_prompt_embeds is None:
+ if prompt_2 is None:
+ prompt_2 = prompt
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt,
+ num_videos_per_prompt,
+ device=device,
+ dtype=dtype,
+ max_sequence_length=77,
+ )
+
+ return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask
+
+ def encode_image(
+ self, image: torch.Tensor, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
+ ):
+ device = device or self._execution_device
+ image = (image + 1) / 2.0 # [-1, 1] -> [0, 1]
+ image = self.feature_extractor(images=image, return_tensors="pt", do_rescale=False).to(
+ device=device, dtype=self.image_encoder.dtype
+ )
+ image_embeds = self.image_encoder(**image).last_hidden_state
+ return image_embeds.to(dtype=dtype)
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ prompt_template=None,
+ image=None,
+ image_latents=None,
+ last_image=None,
+ last_image_latents=None,
+ sampling_type=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if prompt_template is not None:
+ if not isinstance(prompt_template, dict):
+ raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}")
+ if "template" not in prompt_template:
+ raise ValueError(
+ f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}"
+ )
+
+ sampling_types = [x.value for x in FramepackSamplingType.__members__.values()]
+ if sampling_type not in sampling_types:
+ raise ValueError(f"`sampling_type` has to be one of '{sampling_types}' but is '{sampling_type}'")
+
+ if image is not None and image_latents is not None:
+ raise ValueError("Only one of `image` or `image_latents` can be passed.")
+ if last_image is not None and last_image_latents is not None:
+ raise ValueError("Only one of `last_image` or `last_image_latents` can be passed.")
+ if sampling_type != FramepackSamplingType.INVERTED_ANTI_DRIFTING and (
+ last_image is not None or last_image_latents is not None
+ ):
+ raise ValueError(
+ 'Only `"inverted_anti_drifting"` inference type supports `last_image` or `last_image_latents`.'
+ )
+
+ def prepare_latents(
+ self,
+ batch_size: int = 1,
+ num_channels_latents: int = 16,
+ height: int = 720,
+ width: int = 1280,
+ num_frames: int = 129,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ def prepare_image_latents(
+ self,
+ image: torch.Tensor,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ device = device or self._execution_device
+ if latents is None:
+ image = image.unsqueeze(2).to(device=device, dtype=self.vae.dtype)
+ latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+ latents = latents * self.vae.config.scaling_factor
+ return latents.to(device=device, dtype=dtype)
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_tiling()
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ last_image: Optional[PipelineImageInput] = None,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Union[str, List[str]] = None,
+ height: int = 720,
+ width: int = 1280,
+ num_frames: int = 129,
+ latent_window_size: int = 9,
+ num_inference_steps: int = 50,
+ sigmas: List[float] = None,
+ true_cfg_scale: float = 1.0,
+ guidance_scale: float = 6.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ image_latents: Optional[torch.Tensor] = None,
+ last_image_latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
+ max_sequence_length: int = 256,
+ sampling_type: FramepackSamplingType = FramepackSamplingType.INVERTED_ANTI_DRIFTING,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
+ The image to be used as the starting point for the video generation.
+ last_image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`, *optional*):
+ The optional last image to be used as the ending point for the video generation. This is useful for
+ generating transitions between two images.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ will be used instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+ height (`int`, defaults to `720`):
+ The height in pixels of the generated image.
+ width (`int`, defaults to `1280`):
+ The width in pixels of the generated image.
+ num_frames (`int`, defaults to `129`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
+ guidance_scale (`float`, defaults to `6.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality. Note that the only available
+ HunyuanVideo model is CFG-distilled, which means that traditional guidance between unconditional and
+ conditional latent is not applied.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ image_latents (`torch.Tensor`, *optional*):
+ Pre-encoded image latents. If not provided, the image will be encoded using the VAE.
+ last_image_latents (`torch.Tensor`, *optional*):
+ Pre-encoded last image latents. If not provided, the last image will be encoded using the VAE.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`HunyuanVideoFramepackPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~HunyuanVideoFramepackPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`HunyuanVideoFramepackPipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images and the second element is a list
+ of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw)
+ content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ prompt_template,
+ image,
+ image_latents,
+ last_image,
+ last_image_latents,
+ sampling_type,
+ )
+
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
+ )
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+ transformer_dtype = self.transformer.dtype
+ vae_dtype = self.vae.dtype
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_template=prompt_template,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ )
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
+
+ if do_true_cfg:
+ negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_2=negative_prompt_2,
+ prompt_template=prompt_template,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ prompt_attention_mask=negative_prompt_attention_mask,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ )
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare image
+ image = self.video_processor.preprocess(image, height, width)
+ image_embeds = self.encode_image(image, device=device).to(transformer_dtype)
+ if last_image is not None:
+ # Credits: https://github.com/lllyasviel/FramePack/pull/167
+ # Users can modify the weighting strategy applied here
+ last_image = self.video_processor.preprocess(last_image, height, width)
+ last_image_embeds = self.encode_image(last_image, device=device).to(transformer_dtype)
+ last_image_embeds = (image_embeds + last_image_embeds) / 2
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ window_num_frames = (latent_window_size - 1) * self.vae_scale_factor_temporal + 1
+ num_latent_sections = max(1, (num_frames + window_num_frames - 1) // window_num_frames)
+ history_video = None
+ total_generated_latent_frames = 0
+
+ image_latents = self.prepare_image_latents(
+ image, dtype=torch.float32, device=device, generator=generator, latents=image_latents
+ )
+ if last_image is not None:
+ last_image_latents = self.prepare_image_latents(
+ last_image, dtype=torch.float32, device=device, generator=generator
+ )
+
+ # Specific to the released checkpoints:
+ # - https://huggingface.co/lllyasviel/FramePackI2V_HY
+ # - https://huggingface.co/lllyasviel/FramePack_F1_I2V_HY_20250503
+ # TODO: find a more generic way in future if there are more checkpoints
+ if sampling_type == FramepackSamplingType.INVERTED_ANTI_DRIFTING:
+ history_sizes = [1, 2, 16]
+ history_latents = torch.zeros(
+ batch_size,
+ num_channels_latents,
+ sum(history_sizes),
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ device=device,
+ dtype=torch.float32,
+ )
+
+ elif sampling_type == FramepackSamplingType.VANILLA:
+ history_sizes = [16, 2, 1]
+ history_latents = torch.zeros(
+ batch_size,
+ num_channels_latents,
+ sum(history_sizes),
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ device=device,
+ dtype=torch.float32,
+ )
+ history_latents = torch.cat([history_latents, image_latents], dim=2)
+ total_generated_latent_frames += 1
+
+ else:
+ assert False
+
+ # 6. Prepare guidance condition
+ guidance = torch.tensor([guidance_scale] * batch_size, dtype=transformer_dtype, device=device) * 1000.0
+
+ # 7. Denoising loop
+ for k in range(num_latent_sections):
+ if sampling_type == FramepackSamplingType.INVERTED_ANTI_DRIFTING:
+ latent_paddings = list(reversed(range(num_latent_sections)))
+ if num_latent_sections > 4:
+ latent_paddings = [3] + [2] * (num_latent_sections - 3) + [1, 0]
+
+ is_first_section = k == 0
+ is_last_section = k == num_latent_sections - 1
+ latent_padding_size = latent_paddings[k] * latent_window_size
+
+ indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, *history_sizes]))
+ (
+ indices_prefix,
+ indices_padding,
+ indices_latents,
+ indices_latents_history_1x,
+ indices_latents_history_2x,
+ indices_latents_history_4x,
+ ) = indices.split([1, latent_padding_size, latent_window_size, *history_sizes], dim=0)
+ # Inverted anti-drifting sampling: Figure 2(c) in the paper
+ indices_clean_latents = torch.cat([indices_prefix, indices_latents_history_1x], dim=0)
+
+ latents_prefix = image_latents
+ latents_history_1x, latents_history_2x, latents_history_4x = history_latents[
+ :, :, : sum(history_sizes)
+ ].split(history_sizes, dim=2)
+ if last_image is not None and is_first_section:
+ latents_history_1x = last_image_latents
+ latents_clean = torch.cat([latents_prefix, latents_history_1x], dim=2)
+
+ elif sampling_type == FramepackSamplingType.VANILLA:
+ indices = torch.arange(0, sum([1, *history_sizes, latent_window_size]))
+ (
+ indices_prefix,
+ indices_latents_history_4x,
+ indices_latents_history_2x,
+ indices_latents_history_1x,
+ indices_latents,
+ ) = indices.split([1, *history_sizes, latent_window_size], dim=0)
+ indices_clean_latents = torch.cat([indices_prefix, indices_latents_history_1x], dim=0)
+
+ latents_prefix = image_latents
+ latents_history_4x, latents_history_2x, latents_history_1x = history_latents[
+ :, :, -sum(history_sizes) :
+ ].split(history_sizes, dim=2)
+ latents_clean = torch.cat([latents_prefix, latents_history_1x], dim=2)
+
+ else:
+ assert False
+
+ latents = self.prepare_latents(
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ window_num_frames,
+ dtype=torch.float32,
+ device=device,
+ generator=generator,
+ latents=None,
+ )
+
+ sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
+ image_seq_len = (
+ latents.shape[2] * latents.shape[3] * latents.shape[4] / self.transformer.config.patch_size**2
+ )
+ exp_max = 7.0
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ mu = min(mu, math.log(exp_max))
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu
+ )
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ timestep = t.expand(latents.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latents.to(transformer_dtype),
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ pooled_projections=pooled_prompt_embeds,
+ image_embeds=image_embeds,
+ indices_latents=indices_latents,
+ guidance=guidance,
+ latents_clean=latents_clean.to(transformer_dtype),
+ indices_latents_clean=indices_clean_latents,
+ latents_history_2x=latents_history_2x.to(transformer_dtype),
+ indices_latents_history_2x=indices_latents_history_2x,
+ latents_history_4x=latents_history_4x.to(transformer_dtype),
+ indices_latents_history_4x=indices_latents_history_4x,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_true_cfg:
+ neg_noise_pred = self.transformer(
+ hidden_states=latents.to(transformer_dtype),
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_attention_mask=negative_prompt_attention_mask,
+ pooled_projections=negative_pooled_prompt_embeds,
+ image_embeds=image_embeds,
+ indices_latents=indices_latents,
+ guidance=guidance,
+ latents_clean=latents_clean.to(transformer_dtype),
+ indices_latents_clean=indices_clean_latents,
+ latents_history_2x=latents_history_2x.to(transformer_dtype),
+ indices_latents_history_2x=indices_latents_history_2x,
+ latents_history_4x=latents_history_4x.to(transformer_dtype),
+ indices_latents_history_4x=indices_latents_history_4x,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred.float(), t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if sampling_type == FramepackSamplingType.INVERTED_ANTI_DRIFTING:
+ if is_last_section:
+ latents = torch.cat([image_latents, latents], dim=2)
+ total_generated_latent_frames += latents.shape[2]
+ history_latents = torch.cat([latents, history_latents], dim=2)
+ real_history_latents = history_latents[:, :, :total_generated_latent_frames]
+ section_latent_frames = (
+ (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
+ )
+ index_slice = (slice(None), slice(None), slice(0, section_latent_frames))
+
+ elif sampling_type == FramepackSamplingType.VANILLA:
+ total_generated_latent_frames += latents.shape[2]
+ history_latents = torch.cat([history_latents, latents], dim=2)
+ real_history_latents = history_latents[:, :, -total_generated_latent_frames:]
+ section_latent_frames = latent_window_size * 2
+ index_slice = (slice(None), slice(None), slice(-section_latent_frames, None))
+
+ else:
+ assert False
+
+ if history_video is None:
+ if not output_type == "latent":
+ current_latents = real_history_latents.to(vae_dtype) / self.vae.config.scaling_factor
+ history_video = self.vae.decode(current_latents, return_dict=False)[0]
+ else:
+ history_video = [real_history_latents]
+ else:
+ if not output_type == "latent":
+ overlapped_frames = (latent_window_size - 1) * self.vae_scale_factor_temporal + 1
+ current_latents = (
+ real_history_latents[index_slice].to(vae_dtype) / self.vae.config.scaling_factor
+ )
+ current_video = self.vae.decode(current_latents, return_dict=False)[0]
+
+ if sampling_type == FramepackSamplingType.INVERTED_ANTI_DRIFTING:
+ history_video = self._soft_append(current_video, history_video, overlapped_frames)
+ elif sampling_type == FramepackSamplingType.VANILLA:
+ history_video = self._soft_append(history_video, current_video, overlapped_frames)
+ else:
+ assert False
+ else:
+ history_video.append(real_history_latents)
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ generated_frames = history_video.size(2)
+ generated_frames = (
+ generated_frames - 1
+ ) // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ history_video = history_video[:, :, :generated_frames]
+ video = self.video_processor.postprocess_video(history_video, output_type=output_type)
+ else:
+ video = history_video
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return HunyuanVideoFramepackPipelineOutput(frames=video)
+
+ def _soft_append(self, history: torch.Tensor, current: torch.Tensor, overlap: int = 0):
+ if overlap <= 0:
+ return torch.cat([history, current], dim=2)
+
+ assert history.shape[2] >= overlap, f"Current length ({history.shape[2]}) must be >= overlap ({overlap})"
+ assert current.shape[2] >= overlap, f"History length ({current.shape[2]}) must be >= overlap ({overlap})"
+
+ weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
+ blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap]
+ output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2)
+
+ return output.to(history)
diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py
index 774b72e6c7c1..aa04e6509730 100644
--- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -30,7 +30,7 @@
from ...loaders import HunyuanVideoLoraLoaderMixin
from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
@@ -100,6 +100,50 @@
}
+def _expand_input_ids_with_image_tokens(
+ text_input_ids,
+ prompt_attention_mask,
+ max_sequence_length,
+ image_token_index,
+ image_emb_len,
+ image_emb_start,
+ image_emb_end,
+ pad_token_id,
+):
+ special_image_token_mask = text_input_ids == image_token_index
+ num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
+ batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index)
+
+ max_expanded_length = max_sequence_length + (num_special_image_tokens.max() * (image_emb_len - 1))
+ new_token_positions = torch.cumsum((special_image_token_mask * (image_emb_len - 1) + 1), -1) - 1
+ text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
+
+ expanded_input_ids = torch.full(
+ (text_input_ids.shape[0], max_expanded_length),
+ pad_token_id,
+ dtype=text_input_ids.dtype,
+ device=text_input_ids.device,
+ )
+ expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices]
+ expanded_input_ids[batch_indices, image_emb_start:image_emb_end] = image_token_index
+
+ expanded_attention_mask = torch.zeros(
+ (text_input_ids.shape[0], max_expanded_length),
+ dtype=prompt_attention_mask.dtype,
+ device=prompt_attention_mask.device,
+ )
+ attn_batch_indices, attention_indices = torch.where(expanded_input_ids != pad_token_id)
+ expanded_attention_mask[attn_batch_indices, attention_indices] = 1.0
+ expanded_attention_mask = expanded_attention_mask.to(prompt_attention_mask.dtype)
+ position_ids = (expanded_attention_mask.cumsum(-1) - 1).masked_fill_((expanded_attention_mask == 0), 1)
+
+ return {
+ "input_ids": expanded_input_ids,
+ "attention_mask": expanded_attention_mask,
+ "position_ids": position_ids,
+ }
+
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
@@ -251,6 +295,12 @@ def _get_llama_prompt_embeds(
prompt = [prompt_template["template"].format(p) for p in prompt]
crop_start = prompt_template.get("crop_start", None)
+
+ image_emb_len = prompt_template.get("image_emb_len", 576)
+ image_emb_start = prompt_template.get("image_emb_start", 5)
+ image_emb_end = prompt_template.get("image_emb_end", 581)
+ double_return_token_id = prompt_template.get("double_return_token_id", 271)
+
if crop_start is None:
prompt_template_input = self.tokenizer(
prompt_template["template"],
@@ -280,19 +330,25 @@ def _get_llama_prompt_embeds(
image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)
+ image_token_index = self.text_encoder.config.image_token_index
+ pad_token_id = self.text_encoder.config.pad_token_id
+ expanded_inputs = _expand_input_ids_with_image_tokens(
+ text_input_ids,
+ prompt_attention_mask,
+ max_sequence_length,
+ image_token_index,
+ image_emb_len,
+ image_emb_start,
+ image_emb_end,
+ pad_token_id,
+ )
prompt_embeds = self.text_encoder(
- input_ids=text_input_ids,
- attention_mask=prompt_attention_mask,
+ **expanded_inputs,
pixel_values=image_embeds,
output_hidden_states=True,
).hidden_states[-(num_hidden_layers_to_skip + 1)]
prompt_embeds = prompt_embeds.to(dtype=dtype)
- image_emb_len = prompt_template.get("image_emb_len", 576)
- image_emb_start = prompt_template.get("image_emb_start", 5)
- image_emb_end = prompt_template.get("image_emb_end", 581)
- double_return_token_id = prompt_template.get("double_return_token_id", 271)
-
if crop_start is not None and crop_start > 0:
text_crop_start = crop_start - 1 + image_emb_len
batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id)
@@ -542,6 +598,12 @@ def enable_vae_slicing(self):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -549,6 +611,12 @@ def disable_vae_slicing(self):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -557,6 +625,12 @@ def enable_vae_tiling(self):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -564,6 +638,12 @@ def disable_vae_tiling(self):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
@property
@@ -655,13 +735,13 @@ def __call__(
true_cfg_scale (`float`, *optional*, defaults to 1.0):
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
guidance_scale (`float`, defaults to `1.0`):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality. Note that the only available HunyuanVideo model is
- CFG-distilled, which means that traditional guidance between unconditional and conditional latent is
- not applied.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality. Note that the only available
+ HunyuanVideo model is CFG-distilled, which means that traditional guidance between unconditional and
+ conditional latent is not applied.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_output.py b/src/diffusers/pipelines/hunyuan_video/pipeline_output.py
index c5cb853e3932..fae0370a53b7 100644
--- a/src/diffusers/pipelines/hunyuan_video/pipeline_output.py
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_output.py
@@ -1,5 +1,8 @@
from dataclasses import dataclass
+from typing import List, Union
+import numpy as np
+import PIL.Image
import torch
from diffusers.utils import BaseOutput
@@ -18,3 +21,19 @@ class HunyuanVideoPipelineOutput(BaseOutput):
"""
frames: torch.Tensor
+
+
+@dataclass
+class HunyuanVideoFramepackPipelineOutput(BaseOutput):
+ r"""
+ Output class for HunyuanVideo pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`. Or, a list of torch tensors where each tensor
+ corresponds to a latent that decodes to multiple frames.
+ """
+
+ frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]], List[torch.Tensor]]
diff --git a/src/diffusers/pipelines/hunyuan_video1_5/__init__.py b/src/diffusers/pipelines/hunyuan_video1_5/__init__.py
new file mode 100644
index 000000000000..846320f4ace0
--- /dev/null
+++ b/src/diffusers/pipelines/hunyuan_video1_5/__init__.py
@@ -0,0 +1,50 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_hunyuan_video1_5"] = ["HunyuanVideo15Pipeline"]
+ _import_structure["pipeline_hunyuan_video1_5_image2video"] = ["HunyuanVideo15ImageToVideoPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_hunyuan_video1_5 import HunyuanVideo15Pipeline
+ from .pipeline_hunyuan_video1_5_image2video import HunyuanVideo15ImageToVideoPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py b/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py
new file mode 100644
index 000000000000..82817365b6a5
--- /dev/null
+++ b/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py
@@ -0,0 +1,103 @@
+# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from ...configuration_utils import register_to_config
+from ...video_processor import VideoProcessor
+
+
+# copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L20
+def generate_crop_size_list(base_size=256, patch_size=16, max_ratio=4.0):
+ num_patches = round((base_size / patch_size) ** 2)
+ assert max_ratio >= 1.0
+ crop_size_list = []
+ wp, hp = num_patches, 1
+ while wp > 0:
+ if max(wp, hp) / min(wp, hp) <= max_ratio:
+ crop_size_list.append((wp * patch_size, hp * patch_size))
+ if (hp + 1) * wp <= num_patches:
+ hp += 1
+ else:
+ wp -= 1
+ return crop_size_list
+
+
+# copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L38
+def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
+ """
+ Get the closest ratio in the buckets.
+
+ Args:
+ height (float): video height
+ width (float): video width
+ ratios (list): video aspect ratio
+ buckets (list): buckets generated by `generate_crop_size_list`
+
+ Returns:
+ the closest size in the buckets and the corresponding ratio
+ """
+ aspect_ratio = float(height) / float(width)
+ diff_ratios = ratios - aspect_ratio
+
+ if aspect_ratio >= 1:
+ indices = [(index, x) for index, x in enumerate(diff_ratios) if x <= 0]
+ else:
+ indices = [(index, x) for index, x in enumerate(diff_ratios) if x >= 0]
+
+ closest_ratio_id = min(indices, key=lambda pair: abs(pair[1]))[0]
+ closest_size = buckets[closest_ratio_id]
+ closest_ratio = ratios[closest_ratio_id]
+
+ return closest_size, closest_ratio
+
+
+class HunyuanVideo15ImageProcessor(VideoProcessor):
+ r"""
+ Image/video processor to preproces/postprocess the reference image/generatedvideo for the HunyuanVideo1.5 model.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
+ `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
+ vae_scale_factor (`int`, *optional*, defaults to `16`):
+ VAE (spatial) scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of
+ this factor.
+ vae_latent_channels (`int`, *optional*, defaults to `32`):
+ VAE latent channels.
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
+ Whether to convert the image to RGB.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ do_resize: bool = True,
+ vae_scale_factor: int = 16,
+ vae_latent_channels: int = 32,
+ do_convert_rgb: bool = True,
+ ):
+ super().__init__(
+ do_resize=do_resize,
+ vae_scale_factor=vae_scale_factor,
+ vae_latent_channels=vae_latent_channels,
+ do_convert_rgb=do_convert_rgb,
+ )
+
+ def calculate_default_height_width(self, height: int, width: int, target_size: int):
+ crop_size_list = generate_crop_size_list(base_size=target_size, patch_size=self.config.vae_scale_factor)
+ aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
+ height, width = get_closest_ratio(height, width, aspect_ratios, crop_size_list)[0]
+
+ return height, width
diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py
new file mode 100644
index 000000000000..00a703939004
--- /dev/null
+++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py
@@ -0,0 +1,837 @@
+# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import re
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from transformers import ByT5Tokenizer, Qwen2_5_VLTextModel, Qwen2Tokenizer, T5EncoderModel
+
+from ...guiders import ClassifierFreeGuidance
+from ...models import AutoencoderKLHunyuanVideo15, HunyuanVideo15Transformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .image_processor import HunyuanVideo15ImageProcessor
+from .pipeline_output import HunyuanVideo15PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import HunyuanVideo15Pipeline
+ >>> from diffusers.utils import export_to_video
+
+ >>> model_id = "hunyuanvideo-community/HunyuanVideo-1.5-480p_t2v"
+ >>> pipe = HunyuanVideo15Pipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+ >>> pipe.vae.enable_tiling()
+ >>> pipe.to("cuda")
+
+ >>> output = pipe(
+ ... prompt="A cat walks on the grass, realistic",
+ ... num_inference_steps=50,
+ ... ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=15)
+ ```
+"""
+
+
+def format_text_input(prompt: List[str], system_message: str) -> List[Dict[str, Any]]:
+ """
+ Apply text to template.
+
+ Args:
+ prompt (List[str]): Input text.
+ system_message (str): System message.
+
+ Returns:
+ List[Dict[str, Any]]: List of chat conversation.
+ """
+
+ template = [
+ [{"role": "system", "content": system_message}, {"role": "user", "content": p if p else " "}] for p in prompt
+ ]
+
+ return template
+
+
+def extract_glyph_texts(prompt: str) -> List[str]:
+ """
+ Extract glyph texts from prompt using regex pattern.
+
+ Args:
+ prompt: Input prompt string
+
+ Returns:
+ List of extracted glyph texts
+ """
+ pattern = r"\"(.*?)\"|“(.*?)”"
+ matches = re.findall(pattern, prompt)
+ result = [match[0] or match[1] for match in matches]
+ result = list(dict.fromkeys(result)) if len(result) > 1 else result
+
+ if result:
+ formatted_result = ". ".join([f'Text "{text}"' for text in result]) + ". "
+ else:
+ formatted_result = None
+
+ return formatted_result
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class HunyuanVideo15Pipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-video generation using HunyuanVideo1.5.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ transformer ([`HunyuanVideo15Transformer3DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded video latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+ vae ([`AutoencoderKLHunyuanVideo15`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
+ tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer].
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel)
+ variant.
+ tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer]
+ guider ([`ClassifierFreeGuidance`]):
+ [ClassifierFreeGuidance]for classifier free guidance.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ def __init__(
+ self,
+ text_encoder: Qwen2_5_VLTextModel,
+ tokenizer: Qwen2Tokenizer,
+ transformer: HunyuanVideo15Transformer3DModel,
+ vae: AutoencoderKLHunyuanVideo15,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: ByT5Tokenizer,
+ guider: ClassifierFreeGuidance,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ guider=guider,
+ )
+
+ self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 16
+ self.video_processor = HunyuanVideo15ImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+ self.target_size = self.transformer.config.target_size if getattr(self, "transformer", None) else 640
+ self.vision_states_dim = (
+ self.transformer.config.image_embed_dim if getattr(self, "transformer", None) else 1152
+ )
+ self.num_channels_latents = self.vae.config.latent_channels if hasattr(self, "vae") else 32
+ # fmt: off
+ self.system_message = "You are a helpful assistant. Describe the video by detailing the following aspects: \
+ 1. The main content and theme of the video. \
+ 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \
+ 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \
+ 4. background environment, light, style and atmosphere. \
+ 5. camera angles, movements, and transitions used in the video."
+ # fmt: on
+ self.prompt_template_encode_start_idx = 108
+ self.tokenizer_max_length = 1000
+ self.tokenizer_2_max_length = 256
+ self.vision_num_semantic_tokens = 729
+ self.default_aspect_ratio = (16, 9) # (width: height)
+
+ @staticmethod
+ def _get_mllm_prompt_embeds(
+ text_encoder: Qwen2_5_VLTextModel,
+ tokenizer: Qwen2Tokenizer,
+ prompt: Union[str, List[str]],
+ device: torch.device,
+ tokenizer_max_length: int = 1000,
+ num_hidden_layers_to_skip: int = 2,
+ # fmt: off
+ system_message: str = "You are a helpful assistant. Describe the video by detailing the following aspects: \
+ 1. The main content and theme of the video. \
+ 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \
+ 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \
+ 4. background environment, light, style and atmosphere. \
+ 5. camera angles, movements, and transitions used in the video.",
+ # fmt: on
+ crop_start: int = 108,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ prompt = format_text_input(prompt, system_message)
+
+ text_inputs = tokenizer.apply_chat_template(
+ prompt,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ padding="max_length",
+ max_length=tokenizer_max_length + crop_start,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids.to(device=device)
+ prompt_attention_mask = text_inputs.attention_mask.to(device=device)
+
+ prompt_embeds = text_encoder(
+ input_ids=text_input_ids,
+ attention_mask=prompt_attention_mask,
+ output_hidden_states=True,
+ ).hidden_states[-(num_hidden_layers_to_skip + 1)]
+
+ if crop_start is not None and crop_start > 0:
+ prompt_embeds = prompt_embeds[:, crop_start:]
+ prompt_attention_mask = prompt_attention_mask[:, crop_start:]
+
+ return prompt_embeds, prompt_attention_mask
+
+ @staticmethod
+ def _get_byt5_prompt_embeds(
+ tokenizer: ByT5Tokenizer,
+ text_encoder: T5EncoderModel,
+ prompt: Union[str, List[str]],
+ device: torch.device,
+ tokenizer_max_length: int = 256,
+ ):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ glyph_texts = [extract_glyph_texts(p) for p in prompt]
+
+ prompt_embeds_list = []
+ prompt_embeds_mask_list = []
+
+ for glyph_text in glyph_texts:
+ if glyph_text is None:
+ glyph_text_embeds = torch.zeros(
+ (1, tokenizer_max_length, text_encoder.config.d_model), device=device, dtype=text_encoder.dtype
+ )
+ glyph_text_embeds_mask = torch.zeros((1, tokenizer_max_length), device=device, dtype=torch.int64)
+ else:
+ txt_tokens = tokenizer(
+ glyph_text,
+ padding="max_length",
+ max_length=tokenizer_max_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ ).to(device)
+
+ glyph_text_embeds = text_encoder(
+ input_ids=txt_tokens.input_ids,
+ attention_mask=txt_tokens.attention_mask.float(),
+ )[0]
+ glyph_text_embeds = glyph_text_embeds.to(device=device)
+ glyph_text_embeds_mask = txt_tokens.attention_mask.to(device=device)
+
+ prompt_embeds_list.append(glyph_text_embeds)
+ prompt_embeds_mask_list.append(glyph_text_embeds_mask)
+
+ prompt_embeds = torch.cat(prompt_embeds_list, dim=0)
+ prompt_embeds_mask = torch.cat(prompt_embeds_mask_list, dim=0)
+
+ return prompt_embeds, prompt_embeds_mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ batch_size: int = 1,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ prompt_embeds_2: Optional[torch.Tensor] = None,
+ prompt_embeds_mask_2: Optional[torch.Tensor] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ batch_size (`int`):
+ batch size of prompts, defaults to 1
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input
+ argument.
+ prompt_embeds_mask (`torch.Tensor`, *optional*):
+ Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument.
+ prompt_embeds_2 (`torch.Tensor`, *optional*):
+ Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input
+ argument using self.tokenizer_2 and self.text_encoder_2.
+ prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
+ Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input
+ argument using self.tokenizer_2 and self.text_encoder_2.
+ """
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ if prompt is None:
+ prompt = [""] * batch_size
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_embeds_mask = self._get_mllm_prompt_embeds(
+ tokenizer=self.tokenizer,
+ text_encoder=self.text_encoder,
+ prompt=prompt,
+ device=device,
+ tokenizer_max_length=self.tokenizer_max_length,
+ system_message=self.system_message,
+ crop_start=self.prompt_template_encode_start_idx,
+ )
+
+ if prompt_embeds_2 is None:
+ prompt_embeds_2, prompt_embeds_mask_2 = self._get_byt5_prompt_embeds(
+ tokenizer=self.tokenizer_2,
+ text_encoder=self.text_encoder_2,
+ prompt=prompt,
+ device=device,
+ tokenizer_max_length=self.tokenizer_2_max_length,
+ )
+
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_videos_per_prompt, seq_len)
+
+ _, seq_len_2, _ = prompt_embeds_2.shape
+ prompt_embeds_2 = prompt_embeds_2.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds_2 = prompt_embeds_2.view(batch_size * num_videos_per_prompt, seq_len_2, -1)
+ prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_videos_per_prompt, seq_len_2)
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds_mask = prompt_embeds_mask.to(dtype=dtype, device=device)
+ prompt_embeds_2 = prompt_embeds_2.to(dtype=dtype, device=device)
+ prompt_embeds_mask_2 = prompt_embeds_mask_2.to(dtype=dtype, device=device)
+
+ return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_embeds_mask=None,
+ negative_prompt_embeds_mask=None,
+ prompt_embeds_2=None,
+ prompt_embeds_mask_2=None,
+ negative_prompt_embeds_2=None,
+ negative_prompt_embeds_mask_2=None,
+ ):
+ if height is None and width is not None:
+ raise ValueError("If `width` is provided, `height` also have to be provided.")
+ elif width is None and height is not None:
+ raise ValueError("If `height` is provided, `width` also have to be provided.")
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_embeds_mask is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if prompt is None and prompt_embeds_2 is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
+ )
+
+ if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None:
+ raise ValueError(
+ "If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`."
+ )
+ if negative_prompt_embeds_2 is not None and negative_prompt_embeds_mask_2 is None:
+ raise ValueError(
+ "If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`."
+ )
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 32,
+ height: int = 720,
+ width: int = 1280,
+ num_frames: int = 129,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ def prepare_cond_latents_and_mask(self, latents, dtype: Optional[torch.dtype], device: Optional[torch.device]):
+ """
+ Prepare conditional latents and mask for t2v generation.
+
+ Args:
+ latents: Main latents tensor (B, C, F, H, W)
+
+ Returns:
+ tuple: (cond_latents_concat, mask_concat) - both are zero tensors for t2v
+ """
+ batch, channels, frames, height, width = latents.shape
+
+ cond_latents_concat = torch.zeros(batch, channels, frames, height, width, dtype=dtype, device=device)
+
+ mask_concat = torch.zeros(batch, 1, frames, height, width, dtype=dtype, device=device)
+
+ return cond_latents_concat, mask_concat
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_frames: int = 121,
+ num_inference_steps: int = 50,
+ sigmas: List[float] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+ prompt_embeds_2: Optional[torch.Tensor] = None,
+ prompt_embeds_mask_2: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_2: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask_2: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead.
+ height (`int`, *optional*):
+ The height in pixels of the generated video.
+ width (`int`, *optional*):
+ The width in pixels of the generated video.
+ num_frames (`int`, defaults to `121`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ prompt_embeds_mask (`torch.Tensor`, *optional*):
+ Pre-generated mask for prompt embeddings.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ negative_prompt_embeds_mask (`torch.Tensor`, *optional*):
+ Pre-generated mask for negative prompt embeddings.
+ prompt_embeds_2 (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings from the second text encoder. Can be used to easily tweak text inputs.
+ prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
+ Pre-generated mask for prompt embeddings from the second text encoder.
+ negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings from the second text encoder.
+ negative_prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
+ Pre-generated mask for negative prompt embeddings from the second text encoder.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated video. Choose between "np", "pt", or "latent".
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`HunyuanVideo15PipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+ Examples:
+
+ Returns:
+ [`~HunyuanVideo15PipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`HunyuanVideo15PipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated videos.
+ """
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ prompt_embeds_2=prompt_embeds_2,
+ prompt_embeds_mask_2=prompt_embeds_mask_2,
+ negative_prompt_embeds_2=negative_prompt_embeds_2,
+ negative_prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
+ )
+
+ if height is None and width is None:
+ height, width = self.video_processor.calculate_default_height_width(
+ self.default_aspect_ratio[1], self.default_aspect_ratio[0], self.target_size
+ )
+
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ dtype=self.transformer.dtype,
+ batch_size=batch_size,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ prompt_embeds_2=prompt_embeds_2,
+ prompt_embeds_mask_2=prompt_embeds_mask_2,
+ )
+
+ if self.guider._enabled and self.guider.num_conditions > 1:
+ (
+ negative_prompt_embeds,
+ negative_prompt_embeds_mask,
+ negative_prompt_embeds_2,
+ negative_prompt_embeds_mask_2,
+ ) = self.encode_prompt(
+ prompt=negative_prompt,
+ device=device,
+ dtype=self.transformer.dtype,
+ batch_size=batch_size,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ prompt_embeds_2=negative_prompt_embeds_2,
+ prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
+ )
+
+ # 4. Prepare timesteps
+ sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
+
+ # 5. Prepare latent variables
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ self.num_channels_latents,
+ height,
+ width,
+ num_frames,
+ self.transformer.dtype,
+ device,
+ generator,
+ latents,
+ )
+ cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask(latents, self.transformer.dtype, device)
+ image_embeds = torch.zeros(
+ batch_size,
+ self.vision_num_semantic_tokens,
+ self.vision_states_dim,
+ dtype=self.transformer.dtype,
+ device=device,
+ )
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ latent_model_input = torch.cat([latents, cond_latents_concat, mask_concat], dim=1)
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
+
+ # Step 1: Collect model inputs needed for the guidance method
+ # conditional inputs should always be first element in the tuple
+ guider_inputs = {
+ "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
+ "encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask),
+ "encoder_hidden_states_2": (prompt_embeds_2, negative_prompt_embeds_2),
+ "encoder_attention_mask_2": (prompt_embeds_mask_2, negative_prompt_embeds_mask_2),
+ }
+
+ # Step 2: Update guider's internal state for this denoising step
+ self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
+
+ # Step 3: Prepare batched model inputs based on the guidance method
+ # The guider splits model inputs into separate batches for conditional/unconditional predictions.
+ # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
+ # you will get a guider_state with two batches:
+ # guider_state = [
+ # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
+ # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
+ # ]
+ # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
+ guider_state = self.guider.prepare_inputs(guider_inputs)
+ # Step 4: Run the denoiser for each batch
+ # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
+ # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
+ for guider_state_batch in guider_state:
+ self.guider.prepare_models(self.transformer)
+
+ # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
+ cond_kwargs = {
+ input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
+ }
+
+ # e.g. "pred_cond"/"pred_uncond"
+ context_name = getattr(guider_state_batch, self.guider._identifier_key)
+ with self.transformer.cache_context(context_name):
+ # Run denoiser and store noise prediction in this batch
+ guider_state_batch.noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ image_embeds=image_embeds,
+ timestep=timestep,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ **cond_kwargs,
+ )[0]
+
+ # Cleanup model (e.g., remove hooks)
+ self.guider.cleanup_models(self.transformer)
+
+ # Step 5: Combine predictions using the guidance method
+ # The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm.
+ # Continuing the CFG example, the guider receives:
+ # guider_state = [
+ # {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0
+ # {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1
+ # ]
+ # And extracts predictions using the __guidance_identifier__:
+ # pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond
+ # pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond
+ # Then applies CFG formula:
+ # noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
+ # Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
+ noise_pred = self.guider(guider_state)[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ # 8. decode the latents to video and postprocess
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return HunyuanVideo15PipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py
new file mode 100644
index 000000000000..8c555eabba11
--- /dev/null
+++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py
@@ -0,0 +1,960 @@
+# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import re
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+from transformers import (
+ ByT5Tokenizer,
+ Qwen2_5_VLTextModel,
+ Qwen2Tokenizer,
+ SiglipImageProcessor,
+ SiglipVisionModel,
+ T5EncoderModel,
+)
+
+from ...guiders import ClassifierFreeGuidance
+from ...models import AutoencoderKLHunyuanVideo15, HunyuanVideo15Transformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .image_processor import HunyuanVideo15ImageProcessor
+from .pipeline_output import HunyuanVideo15PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import HunyuanVideo15ImageToVideoPipeline
+ >>> from diffusers.utils import export_to_video
+
+ >>> model_id = "hunyuanvideo-community/HunyuanVideo-1.5-480p_i2v"
+ >>> pipe = HunyuanVideo15ImageToVideoPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+ >>> pipe.vae.enable_tiling()
+ >>> pipe.to("cuda")
+
+ >>> image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/wan_i2v_input.JPG")
+
+ >>> output = pipe(
+ ... prompt="Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
+ ... image=image,
+ ... num_inference_steps=50,
+ ... ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=24)
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.format_text_input
+def format_text_input(prompt: List[str], system_message: str) -> List[Dict[str, Any]]:
+ """
+ Apply text to template.
+
+ Args:
+ prompt (List[str]): Input text.
+ system_message (str): System message.
+
+ Returns:
+ List[Dict[str, Any]]: List of chat conversation.
+ """
+
+ template = [
+ [{"role": "system", "content": system_message}, {"role": "user", "content": p if p else " "}] for p in prompt
+ ]
+
+ return template
+
+
+# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.extract_glyph_texts
+def extract_glyph_texts(prompt: str) -> List[str]:
+ """
+ Extract glyph texts from prompt using regex pattern.
+
+ Args:
+ prompt: Input prompt string
+
+ Returns:
+ List of extracted glyph texts
+ """
+ pattern = r"\"(.*?)\"|“(.*?)”"
+ matches = re.findall(pattern, prompt)
+ result = [match[0] or match[1] for match in matches]
+ result = list(dict.fromkeys(result)) if len(result) > 1 else result
+
+ if result:
+ formatted_result = ". ".join([f'Text "{text}"' for text in result]) + ". "
+ else:
+ formatted_result = None
+
+ return formatted_result
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class HunyuanVideo15ImageToVideoPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for image-to-video generation using HunyuanVideo1.5.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ transformer ([`HunyuanVideo15Transformer3DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded video latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+ vae ([`AutoencoderKLHunyuanVideo15`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
+ tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer].
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel)
+ variant.
+ tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer]
+ guider ([`ClassifierFreeGuidance`]):
+ [ClassifierFreeGuidance]for classifier free guidance.
+ image_encoder ([`SiglipVisionModel`]):
+ [SiglipVisionModel](https://huggingface.co/docs/transformers/en/model_doc/siglip#transformers.SiglipVisionModel)
+ variant.
+ feature_extractor ([`SiglipImageProcessor`]):
+ [SiglipImageProcessor](https://huggingface.co/docs/transformers/en/model_doc/siglip#transformers.SiglipImageProcessor)
+ variant.
+ """
+
+ model_cpu_offload_seq = "image_encoder->text_encoder->transformer->vae"
+
+ def __init__(
+ self,
+ text_encoder: Qwen2_5_VLTextModel,
+ tokenizer: Qwen2Tokenizer,
+ transformer: HunyuanVideo15Transformer3DModel,
+ vae: AutoencoderKLHunyuanVideo15,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: ByT5Tokenizer,
+ guider: ClassifierFreeGuidance,
+ image_encoder: SiglipVisionModel,
+ feature_extractor: SiglipImageProcessor,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ guider=guider,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
+
+ self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 16
+ self.video_processor = HunyuanVideo15ImageProcessor(
+ vae_scale_factor=self.vae_scale_factor_spatial, do_resize=False, do_convert_rgb=True
+ )
+ self.target_size = self.transformer.config.target_size if getattr(self, "transformer", None) else 640
+ self.vision_states_dim = (
+ self.transformer.config.image_embed_dim if getattr(self, "transformer", None) else 1152
+ )
+ self.num_channels_latents = self.vae.config.latent_channels if hasattr(self, "vae") else 32
+ # fmt: off
+ self.system_message = "You are a helpful assistant. Describe the video by detailing the following aspects: \
+ 1. The main content and theme of the video. \
+ 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \
+ 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \
+ 4. background environment, light, style and atmosphere. \
+ 5. camera angles, movements, and transitions used in the video."
+ # fmt: on
+ self.prompt_template_encode_start_idx = 108
+ self.tokenizer_max_length = 1000
+ self.tokenizer_2_max_length = 256
+ self.vision_num_semantic_tokens = 729
+
+ @staticmethod
+ # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline._get_mllm_prompt_embeds
+ def _get_mllm_prompt_embeds(
+ text_encoder: Qwen2_5_VLTextModel,
+ tokenizer: Qwen2Tokenizer,
+ prompt: Union[str, List[str]],
+ device: torch.device,
+ tokenizer_max_length: int = 1000,
+ num_hidden_layers_to_skip: int = 2,
+ # fmt: off
+ system_message: str = "You are a helpful assistant. Describe the video by detailing the following aspects: \
+ 1. The main content and theme of the video. \
+ 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \
+ 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \
+ 4. background environment, light, style and atmosphere. \
+ 5. camera angles, movements, and transitions used in the video.",
+ # fmt: on
+ crop_start: int = 108,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ prompt = format_text_input(prompt, system_message)
+
+ text_inputs = tokenizer.apply_chat_template(
+ prompt,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ padding="max_length",
+ max_length=tokenizer_max_length + crop_start,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids.to(device=device)
+ prompt_attention_mask = text_inputs.attention_mask.to(device=device)
+
+ prompt_embeds = text_encoder(
+ input_ids=text_input_ids,
+ attention_mask=prompt_attention_mask,
+ output_hidden_states=True,
+ ).hidden_states[-(num_hidden_layers_to_skip + 1)]
+
+ if crop_start is not None and crop_start > 0:
+ prompt_embeds = prompt_embeds[:, crop_start:]
+ prompt_attention_mask = prompt_attention_mask[:, crop_start:]
+
+ return prompt_embeds, prompt_attention_mask
+
+ @staticmethod
+ # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline._get_byt5_prompt_embeds
+ def _get_byt5_prompt_embeds(
+ tokenizer: ByT5Tokenizer,
+ text_encoder: T5EncoderModel,
+ prompt: Union[str, List[str]],
+ device: torch.device,
+ tokenizer_max_length: int = 256,
+ ):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ glyph_texts = [extract_glyph_texts(p) for p in prompt]
+
+ prompt_embeds_list = []
+ prompt_embeds_mask_list = []
+
+ for glyph_text in glyph_texts:
+ if glyph_text is None:
+ glyph_text_embeds = torch.zeros(
+ (1, tokenizer_max_length, text_encoder.config.d_model), device=device, dtype=text_encoder.dtype
+ )
+ glyph_text_embeds_mask = torch.zeros((1, tokenizer_max_length), device=device, dtype=torch.int64)
+ else:
+ txt_tokens = tokenizer(
+ glyph_text,
+ padding="max_length",
+ max_length=tokenizer_max_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ ).to(device)
+
+ glyph_text_embeds = text_encoder(
+ input_ids=txt_tokens.input_ids,
+ attention_mask=txt_tokens.attention_mask.float(),
+ )[0]
+ glyph_text_embeds = glyph_text_embeds.to(device=device)
+ glyph_text_embeds_mask = txt_tokens.attention_mask.to(device=device)
+
+ prompt_embeds_list.append(glyph_text_embeds)
+ prompt_embeds_mask_list.append(glyph_text_embeds_mask)
+
+ prompt_embeds = torch.cat(prompt_embeds_list, dim=0)
+ prompt_embeds_mask = torch.cat(prompt_embeds_mask_list, dim=0)
+
+ return prompt_embeds, prompt_embeds_mask
+
+ @staticmethod
+ def _get_image_latents(
+ vae: AutoencoderKLHunyuanVideo15,
+ image_processor: HunyuanVideo15ImageProcessor,
+ image: PIL.Image.Image,
+ height: int,
+ width: int,
+ device: torch.device,
+ ) -> torch.Tensor:
+ vae_dtype = vae.dtype
+ image_tensor = image_processor.preprocess(image, height=height, width=width).to(device, dtype=vae_dtype)
+ image_tensor = image_tensor.unsqueeze(2)
+ image_latents = retrieve_latents(vae.encode(image_tensor), sample_mode="argmax")
+ image_latents = image_latents * vae.config.scaling_factor
+ return image_latents
+
+ @staticmethod
+ def _get_image_embeds(
+ image_encoder: SiglipVisionModel,
+ feature_extractor: SiglipImageProcessor,
+ image: PIL.Image.Image,
+ device: torch.device,
+ ) -> torch.Tensor:
+ image_encoder_dtype = next(image_encoder.parameters()).dtype
+ image = feature_extractor.preprocess(images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True)
+ image = image.to(device=device, dtype=image_encoder_dtype)
+ image_enc_hidden_states = image_encoder(**image).last_hidden_state
+
+ return image_enc_hidden_states
+
+ def encode_image(
+ self,
+ image: PIL.Image.Image,
+ batch_size: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ ) -> torch.Tensor:
+ image_embeds = self._get_image_embeds(
+ image_encoder=self.image_encoder,
+ feature_extractor=self.feature_extractor,
+ image=image,
+ device=device,
+ )
+ image_embeds = image_embeds.repeat(batch_size, 1, 1)
+ image_embeds = image_embeds.to(device=device, dtype=dtype)
+ return image_embeds
+
+ # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ batch_size: int = 1,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ prompt_embeds_2: Optional[torch.Tensor] = None,
+ prompt_embeds_mask_2: Optional[torch.Tensor] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ batch_size (`int`):
+ batch size of prompts, defaults to 1
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input
+ argument.
+ prompt_embeds_mask (`torch.Tensor`, *optional*):
+ Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument.
+ prompt_embeds_2 (`torch.Tensor`, *optional*):
+ Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input
+ argument using self.tokenizer_2 and self.text_encoder_2.
+ prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
+ Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input
+ argument using self.tokenizer_2 and self.text_encoder_2.
+ """
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ if prompt is None:
+ prompt = [""] * batch_size
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_embeds_mask = self._get_mllm_prompt_embeds(
+ tokenizer=self.tokenizer,
+ text_encoder=self.text_encoder,
+ prompt=prompt,
+ device=device,
+ tokenizer_max_length=self.tokenizer_max_length,
+ system_message=self.system_message,
+ crop_start=self.prompt_template_encode_start_idx,
+ )
+
+ if prompt_embeds_2 is None:
+ prompt_embeds_2, prompt_embeds_mask_2 = self._get_byt5_prompt_embeds(
+ tokenizer=self.tokenizer_2,
+ text_encoder=self.text_encoder_2,
+ prompt=prompt,
+ device=device,
+ tokenizer_max_length=self.tokenizer_2_max_length,
+ )
+
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_videos_per_prompt, seq_len)
+
+ _, seq_len_2, _ = prompt_embeds_2.shape
+ prompt_embeds_2 = prompt_embeds_2.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds_2 = prompt_embeds_2.view(batch_size * num_videos_per_prompt, seq_len_2, -1)
+ prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_videos_per_prompt, seq_len_2)
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds_mask = prompt_embeds_mask.to(dtype=dtype, device=device)
+ prompt_embeds_2 = prompt_embeds_2.to(dtype=dtype, device=device)
+ prompt_embeds_mask_2 = prompt_embeds_mask_2.to(dtype=dtype, device=device)
+
+ return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2
+
+ def check_inputs(
+ self,
+ prompt,
+ image: PIL.Image.Image,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_embeds_mask=None,
+ negative_prompt_embeds_mask=None,
+ prompt_embeds_2=None,
+ prompt_embeds_mask_2=None,
+ negative_prompt_embeds_2=None,
+ negative_prompt_embeds_mask_2=None,
+ ):
+ if not isinstance(image, PIL.Image.Image):
+ raise ValueError(f"`image` has to be of type `PIL.Image.Image` but is {type(image)}")
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_embeds_mask is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if prompt is None and prompt_embeds_2 is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
+ )
+
+ if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None:
+ raise ValueError(
+ "If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`."
+ )
+ if negative_prompt_embeds_2 is not None and negative_prompt_embeds_mask_2 is None:
+ raise ValueError(
+ "If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`."
+ )
+
+ # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 32,
+ height: int = 720,
+ width: int = 1280,
+ num_frames: int = 129,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ def prepare_cond_latents_and_mask(
+ self,
+ latents: torch.Tensor,
+ image: PIL.Image.Image,
+ batch_size: int,
+ height: int,
+ width: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ ):
+ """
+ Prepare conditional latents and mask for t2v generation.
+
+ Args:
+ latents: Main latents tensor (B, C, F, H, W)
+
+ Returns:
+ tuple: (cond_latents_concat, mask_concat) - both are zero tensors for t2v
+ """
+
+ batch, channels, frames, height, width = latents.shape
+
+ image_latents = self._get_image_latents(
+ vae=self.vae,
+ image_processor=self.video_processor,
+ image=image,
+ height=height,
+ width=width,
+ device=device,
+ )
+
+ latent_condition = image_latents.repeat(batch_size, 1, frames, 1, 1)
+ latent_condition[:, :, 1:, :, :] = 0
+ latent_condition = latent_condition.to(device=device, dtype=dtype)
+
+ latent_mask = torch.zeros(batch, 1, frames, height, width, dtype=dtype, device=device)
+ latent_mask[:, :, 0, :, :] = 1.0
+
+ return latent_condition, latent_mask
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PIL.Image.Image,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ num_frames: int = 121,
+ num_inference_steps: int = 50,
+ sigmas: List[float] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+ prompt_embeds_2: Optional[torch.Tensor] = None,
+ prompt_embeds_mask_2: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_2: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask_2: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ image (`PIL.Image.Image`):
+ The input image to condition video generation on.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead.
+ num_frames (`int`, defaults to `121`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ prompt_embeds_mask (`torch.Tensor`, *optional*):
+ Pre-generated mask for prompt embeddings.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ negative_prompt_embeds_mask (`torch.Tensor`, *optional*):
+ Pre-generated mask for negative prompt embeddings.
+ prompt_embeds_2 (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings from the second text encoder. Can be used to easily tweak text inputs.
+ prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
+ Pre-generated mask for prompt embeddings from the second text encoder.
+ negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings from the second text encoder.
+ negative_prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
+ Pre-generated mask for negative prompt embeddings from the second text encoder.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated video. Choose between "np", "pt", or "latent".
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`HunyuanVideo15PipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+ Examples:
+
+ Returns:
+ [`~HunyuanVideo15PipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`HunyuanVideo15PipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated videos.
+ """
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ image=image,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ prompt_embeds_2=prompt_embeds_2,
+ prompt_embeds_mask_2=prompt_embeds_mask_2,
+ negative_prompt_embeds_2=negative_prompt_embeds_2,
+ negative_prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
+ )
+
+ height, width = self.video_processor.calculate_default_height_width(
+ height=image.size[1], width=image.size[0], target_size=self.target_size
+ )
+ image = self.video_processor.resize(image, height=height, width=width, resize_mode="crop")
+
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode image
+ image_embeds = self.encode_image(
+ image=image,
+ batch_size=batch_size * num_videos_per_prompt,
+ device=device,
+ dtype=self.transformer.dtype,
+ )
+
+ # 4. Encode input prompt
+ prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ dtype=self.transformer.dtype,
+ batch_size=batch_size,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ prompt_embeds_2=prompt_embeds_2,
+ prompt_embeds_mask_2=prompt_embeds_mask_2,
+ )
+
+ if self.guider._enabled and self.guider.num_conditions > 1:
+ (
+ negative_prompt_embeds,
+ negative_prompt_embeds_mask,
+ negative_prompt_embeds_2,
+ negative_prompt_embeds_mask_2,
+ ) = self.encode_prompt(
+ prompt=negative_prompt,
+ device=device,
+ dtype=self.transformer.dtype,
+ batch_size=batch_size,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ prompt_embeds_2=negative_prompt_embeds_2,
+ prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
+ )
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
+
+ # 6. Prepare latent variables
+ latents = self.prepare_latents(
+ batch_size=batch_size * num_videos_per_prompt,
+ num_channels_latents=self.num_channels_latents,
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ dtype=self.transformer.dtype,
+ device=device,
+ generator=generator,
+ latents=latents,
+ )
+
+ cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask(
+ latents=latents,
+ image=image,
+ batch_size=batch_size * num_videos_per_prompt,
+ height=height,
+ width=width,
+ dtype=self.transformer.dtype,
+ device=device,
+ )
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ latent_model_input = torch.cat([latents, cond_latents_concat, mask_concat], dim=1)
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
+
+ if self.transformer.config.use_meanflow:
+ if i == len(timesteps) - 1:
+ timestep_r = torch.tensor([0.0], device=device)
+ else:
+ timestep_r = timesteps[i + 1]
+ timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype)
+ else:
+ timestep_r = None
+
+ # Step 1: Collect model inputs needed for the guidance method
+ # conditional inputs should always be first element in the tuple
+ guider_inputs = {
+ "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
+ "encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask),
+ "encoder_hidden_states_2": (prompt_embeds_2, negative_prompt_embeds_2),
+ "encoder_attention_mask_2": (prompt_embeds_mask_2, negative_prompt_embeds_mask_2),
+ }
+
+ # Step 2: Update guider's internal state for this denoising step
+ self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
+
+ # Step 3: Prepare batched model inputs based on the guidance method
+ # The guider splits model inputs into separate batches for conditional/unconditional predictions.
+ # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
+ # you will get a guider_state with two batches:
+ # guider_state = [
+ # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
+ # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
+ # ]
+ # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
+ guider_state = self.guider.prepare_inputs(guider_inputs)
+ # Step 4: Run the denoiser for each batch
+ # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
+ # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
+ for guider_state_batch in guider_state:
+ self.guider.prepare_models(self.transformer)
+
+ # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
+ cond_kwargs = {
+ input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
+ }
+
+ # e.g. "pred_cond"/"pred_uncond"
+ context_name = getattr(guider_state_batch, self.guider._identifier_key)
+ with self.transformer.cache_context(context_name):
+ # Run denoiser and store noise prediction in this batch
+ guider_state_batch.noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ image_embeds=image_embeds,
+ timestep=timestep,
+ timestep_r=timestep_r,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ **cond_kwargs,
+ )[0]
+
+ # Cleanup model (e.g., remove hooks)
+ self.guider.cleanup_models(self.transformer)
+
+ # Step 5: Combine predictions using the guidance method
+ # The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm.
+ # Continuing the CFG example, the guider receives:
+ # guider_state = [
+ # {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0
+ # {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1
+ # ]
+ # And extracts predictions using the __guidance_identifier__:
+ # pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond
+ # pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond
+ # Then applies CFG formula:
+ # noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
+ # Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
+ noise_pred = self.guider(guider_state)[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return HunyuanVideo15PipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_output.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_output.py
new file mode 100644
index 000000000000..441164db5a09
--- /dev/null
+++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class HunyuanVideo15PipelineOutput(BaseOutput):
+ r"""
+ Output class for HunyuanVideo1.5 pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
index febf2b0392cc..e2f935aaf4b9 100644
--- a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
+++ b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 HunyuanDiT Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -27,11 +27,7 @@
from ...models.embeddings import get_2d_rotary_pos_embed
from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from ...schedulers import DDPMScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
@@ -128,7 +124,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -433,7 +429,7 @@ def run_safety_checker(self, image, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -555,7 +551,7 @@ def guidance_rescale(self):
return self._guidance_rescale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -625,8 +621,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -662,7 +658,7 @@ def __call__(
inputs will be passed.
guidance_rescale (`float`, *optional*, defaults to 0.0):
Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise
- Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). See Section 3.4
original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
The original size of the image. Used to calculate the time ids.
target_size (`Tuple[int, int]`, *optional*):
@@ -865,7 +861,7 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
index 58d65a190d5b..c6cc724a71f0 100644
--- a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
+++ b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,7 +33,7 @@
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
-from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
if is_torch_xla_available():
@@ -97,9 +97,11 @@ class I2VGenXLPipelineOutput(BaseOutput):
class I2VGenXLPipeline(
+ DeprecatedPipelineMixin,
DiffusionPipeline,
StableDiffusionMixin,
):
+ _last_supported_version = "0.33.1"
r"""
Pipeline for image-to-video generation as proposed in [I2VGenXL](https://i2vgen-xl.github.io/).
@@ -151,7 +153,7 @@ def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -384,7 +386,7 @@ def decode_latents(self, latents, decode_chunk_size=None):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -462,7 +464,7 @@ def prepare_image_latents(
image_latents = image_latents.unsqueeze(2)
# Append a position mask for each subsequent frame
- # after the intial image latent frame
+ # after the initial image latent frame
frame_position_mask = []
for frame_idx in range(num_frames - 1):
scale = (frame_idx + 1) / (num_frames - 1)
@@ -557,8 +559,8 @@ def __call__(
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
eta (`float`, *optional*):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
num_videos_per_prompt (`int`, *optional*):
The number of images to generate per prompt.
decode_chunk_size (`int`, *optional*):
@@ -614,7 +616,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
self._guidance_scale = guidance_scale
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
index b5f4acf5c05a..33529f5d0954 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,11 +21,7 @@
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDIMScheduler, DDPMScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_encoder import MultilingualCLIP
@@ -278,11 +274,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -291,7 +287,7 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
index e653b8266f19..7286bcbee17b 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -193,7 +193,7 @@ def __init__(
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
- def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
@@ -251,27 +251,27 @@ def __call__(
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
prior_guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
prior_num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
@@ -360,7 +360,7 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
"""
_load_connected_pipes = True
- model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->" "text_encoder->unet->movq"
+ model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq"
_exclude_from_cpu_offload = ["prior_prior"]
def __init__(
@@ -411,7 +411,7 @@ def __init__(
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
- def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -482,27 +482,27 @@ def __call__(
be maximum and the denoising process will run for the full number of iterations specified in
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
prior_guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
prior_num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
@@ -652,7 +652,7 @@ def __init__(
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
- def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -722,27 +722,27 @@ def __call__(
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
prior_guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
prior_num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
index 5d56efef9287..f5e41d499dc3 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,21 +14,16 @@
from typing import Callable, List, Optional, Union
-import numpy as np
import PIL.Image
import torch
-from PIL import Image
from transformers import (
XLMRobertaTokenizer,
)
+from ...image_processor import VaeImageProcessor
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDIMScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_encoder import MultilingualCLIP
@@ -95,15 +90,6 @@ def get_new_h_w(h, w, scale_factor=8):
return new_h * scale_factor, new_w * scale_factor
-def prepare_image(pil_image, w=512, h=512):
- pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
- arr = np.array(pil_image.convert("RGB"))
- arr = arr.astype(np.float32) / 127.5 - 1
- arr = np.transpose(arr, [2, 0, 1])
- image = torch.from_numpy(arr).unsqueeze(0)
- return image
-
-
class KandinskyImg2ImgPipeline(DiffusionPipeline):
"""
Pipeline for image-to-image generation using Kandinsky
@@ -143,7 +129,16 @@ def __init__(
scheduler=scheduler,
movq=movq,
)
- self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
+ self.movq_scale_factor = (
+ 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
+ )
+ movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
+ self.image_processor = VaeImageProcessor(
+ vae_scale_factor=self.movq_scale_factor,
+ vae_latent_channels=movq_latent_channels,
+ resample="bicubic",
+ reducing_gap=1,
+ )
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
@@ -350,11 +345,11 @@ def __call__(
be maximum and the denoising process will run for the full number of iterations specified in
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -417,7 +412,7 @@ def __call__(
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
)
- image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
+ image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0)
image = image.to(dtype=prompt_embeds.dtype, device=device)
latents = self.movq.encode(image)["latents"]
@@ -498,13 +493,7 @@ def __call__(
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
- if output_type in ["np", "pil"]:
- image = image * 0.5 + 0.5
- image = image.clamp(0, 1)
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
- if output_type == "pil":
- image = self.numpy_to_pil(image)
+ image = self.image_processor.postprocess(image, output_type)
if not return_dict:
return (image,)
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
index cce5f0b3d5bc..731fce499859 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -28,11 +28,7 @@
from ... import __version__
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDIMScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_encoder import MultilingualCLIP
@@ -456,11 +452,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -469,7 +465,7 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
@@ -496,7 +492,7 @@ def __call__(
"As of diffusers==0.19.0 this behavior has been inverted. Now white pixels are repainted and black pixels are preserved. "
"This way, Kandinsky's masking behavior is aligned with Stable Diffusion. "
"THIS means that you HAVE to invert the input mask to have the same behavior as before as explained in https://github.com/huggingface/diffusers/pull/4207. "
- "This warning will be surpressed after the first inference call and will be removed in diffusers>0.23.0"
+ "This warning will be suppressed after the first inference call and will be removed in diffusers>0.23.0"
)
self._warn_has_been_called = True
@@ -579,7 +575,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
index a348deef8b29..10ea8005c90d 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -212,7 +212,7 @@ def interpolate(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
negative_prior_prompt (`str`, *optional*):
The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if
`guidance_scale` is less than `1`).
@@ -220,11 +220,11 @@ def interpolate(
The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if
`guidance_scale` is less than `1`).
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
Examples:
@@ -437,13 +437,13 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
output_type (`str`, *optional*, defaults to `"pt"`):
The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"`
(`torch.Tensor`).
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
index a584674540d8..429253e99898 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -162,11 +162,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -175,7 +175,7 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
index 68334fef3811..fc2083247bb0 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -179,7 +179,7 @@ def __init__(
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
- def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -242,27 +242,27 @@ def __call__(
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
prior_guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
prior_num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
@@ -407,7 +407,7 @@ def __init__(
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
- def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
@@ -417,7 +417,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
- def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -479,11 +479,11 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
strength (`float`, *optional*, defaults to 0.3):
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
@@ -498,11 +498,11 @@ def __call__(
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
prior_guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
prior_num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
@@ -512,7 +512,7 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
@@ -656,7 +656,7 @@ def __init__(
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
- def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
@@ -722,11 +722,11 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
@@ -735,11 +735,11 @@ def __call__(
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
prior_guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
prior_num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
@@ -749,7 +749,7 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py
index bada59080c7b..c5faae82796b 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -198,11 +198,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -211,7 +211,7 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py
index 4f6c4188bd48..54154c6ec1f2 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,11 +14,10 @@
from typing import Callable, List, Optional, Union
-import numpy as np
import PIL.Image
import torch
-from PIL import Image
+from ...image_processor import VaeImageProcessor
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
from ...utils import (
@@ -105,27 +104,6 @@
"""
-# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
-def downscale_height_and_width(height, width, scale_factor=8):
- new_height = height // scale_factor**2
- if height % scale_factor**2 != 0:
- new_height += 1
- new_width = width // scale_factor**2
- if width % scale_factor**2 != 0:
- new_width += 1
- return new_height * scale_factor, new_width * scale_factor
-
-
-# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
-def prepare_image(pil_image, w=512, h=512):
- pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
- arr = np.array(pil_image.convert("RGB"))
- arr = arr.astype(np.float32) / 127.5 - 1
- arr = np.transpose(arr, [2, 0, 1])
- image = torch.from_numpy(arr).unsqueeze(0)
- return image
-
-
class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
"""
Pipeline for image-to-image generation using Kandinsky
@@ -157,7 +135,14 @@ def __init__(
scheduler=scheduler,
movq=movq,
)
- self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
+ movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
+ movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
+ self.image_processor = VaeImageProcessor(
+ vae_scale_factor=movq_scale_factor,
+ vae_latent_channels=movq_latent_channels,
+ resample="bicubic",
+ reducing_gap=1,
+ )
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
@@ -259,11 +244,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -316,7 +301,7 @@ def __call__(
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
)
- image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
+ image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0)
image = image.to(dtype=image_embeds.dtype, device=device)
latents = self.movq.encode(image)["latents"]
@@ -324,7 +309,6 @@ def __call__(
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
- height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
latents = self.prepare_latents(
latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
)
@@ -379,13 +363,7 @@ def __call__(
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
- if output_type in ["np", "pil"]:
- image = image * 0.5 + 0.5
- image = image.clamp(0, 1)
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
- if output_type == "pil":
- image = self.numpy_to_pil(image)
+ image = self.image_processor.postprocess(image, output_type)
if not return_dict:
return (image,)
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py
index 624748896911..3b2509098fd1 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,11 +14,10 @@
from typing import Callable, Dict, List, Optional, Union
-import numpy as np
import PIL.Image
import torch
-from PIL import Image
+from ...image_processor import VaeImageProcessor
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
from ...utils import deprecate, is_torch_xla_available, logging
@@ -76,27 +75,6 @@
"""
-# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
-def downscale_height_and_width(height, width, scale_factor=8):
- new_height = height // scale_factor**2
- if height % scale_factor**2 != 0:
- new_height += 1
- new_width = width // scale_factor**2
- if width % scale_factor**2 != 0:
- new_width += 1
- return new_height * scale_factor, new_width * scale_factor
-
-
-# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
-def prepare_image(pil_image, w=512, h=512):
- pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
- arr = np.array(pil_image.convert("RGB"))
- arr = arr.astype(np.float32) / 127.5 - 1
- arr = np.transpose(arr, [2, 0, 1])
- image = torch.from_numpy(arr).unsqueeze(0)
- return image
-
-
class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
"""
Pipeline for image-to-image generation using Kandinsky
@@ -129,7 +107,14 @@ def __init__(
scheduler=scheduler,
movq=movq,
)
- self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
+ movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
+ movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
+ self.image_processor = VaeImageProcessor(
+ vae_scale_factor=movq_scale_factor,
+ vae_latent_channels=movq_latent_channels,
+ resample="bicubic",
+ reducing_gap=1,
+ )
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
@@ -240,11 +225,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -319,7 +304,7 @@ def __call__(
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
)
- image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
+ image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0)
image = image.to(dtype=image_embeds.dtype, device=device)
latents = self.movq.encode(image)["latents"]
@@ -327,7 +312,6 @@ def __call__(
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
- height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
latents = self.prepare_latents(
latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
)
@@ -383,21 +367,9 @@ def __call__(
if XLA_AVAILABLE:
xm.mark_step()
- if output_type not in ["pt", "np", "pil", "latent"]:
- raise ValueError(
- f"Only the output types `pt`, `pil` ,`np` and `latent` are supported not output_type={output_type}"
- )
-
if not output_type == "latent":
- # post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
- if output_type in ["np", "pil"]:
- image = image * 0.5 + 0.5
- image = image.clamp(0, 1)
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
- if output_type == "pil":
- image = self.numpy_to_pil(image)
+ image = self.image_processor.postprocess(image, output_type)
else:
image = latents
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
index 482093a4bb29..a61673293e1f 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -343,11 +343,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -356,7 +356,7 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
@@ -386,7 +386,7 @@ def __call__(
"As of diffusers==0.19.0 this behavior has been inverted. Now white pixels are repainted and black pixels are preserved. "
"This way, Kandinsky's masking behavior is aligned with Stable Diffusion. "
"THIS means that you HAVE to invert the input mask to have the same behavior as before as explained in https://github.com/huggingface/diffusers/pull/4207. "
- "This warning will be surpressed after the first inference call and will be removed in diffusers>0.23.0"
+ "This warning will be suppressed after the first inference call and will be removed in diffusers>0.23.0"
)
self._warn_has_been_called = True
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
index d05a7fbdb1b8..bc67847831a5 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
@@ -6,11 +6,7 @@
from ...models import PriorTransformer
from ...schedulers import UnCLIPScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..kandinsky import KandinskyPriorPipelineOutput
from ..pipeline_utils import DiffusionPipeline
@@ -171,7 +167,7 @@ def interpolate(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
negative_prior_prompt (`str`, *optional*):
The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if
`guidance_scale` is less than `1`).
@@ -179,11 +175,11 @@ def interpolate(
The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if
`guidance_scale` is less than `1`).
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
Examples:
@@ -412,13 +408,13 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
output_type (`str`, *optional*, defaults to `"pt"`):
The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"`
(`torch.Tensor`).
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
index 56d326e26e6e..b586d166118b 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
@@ -6,11 +6,7 @@
from ...models import PriorTransformer
from ...schedulers import UnCLIPScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..kandinsky import KandinskyPriorPipelineOutput
from ..pipeline_utils import DiffusionPipeline
@@ -195,7 +191,7 @@ def interpolate(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
negative_prior_prompt (`str`, *optional*):
The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if
`guidance_scale` is less than `1`).
@@ -203,11 +199,11 @@ def interpolate(
The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if
`guidance_scale` is less than `1`).
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
Examples:
@@ -441,11 +437,11 @@ def __call__(
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
output_type (`str`, *optional*, defaults to `"pt"`):
The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"`
(`torch.Tensor`).
diff --git a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py
index 5309f94a53c8..57cc0270442d 100644
--- a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py
+++ b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py
@@ -368,11 +368,11 @@ def __call__(
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 3.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -384,8 +384,8 @@ def __call__(
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
diff --git a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
index fbdad79db445..73c268897502 100644
--- a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
+++ b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
@@ -1,12 +1,12 @@
import inspect
from typing import Callable, Dict, List, Optional, Union
-import numpy as np
import PIL
import PIL.Image
import torch
from transformers import T5EncoderModel, T5Tokenizer
+from ...image_processor import VaeImageProcessor
from ...loaders import StableDiffusionLoraLoaderMixin
from ...models import Kandinsky3UNet, VQModel
from ...schedulers import DDPMScheduler
@@ -53,24 +53,6 @@
"""
-def downscale_height_and_width(height, width, scale_factor=8):
- new_height = height // scale_factor**2
- if height % scale_factor**2 != 0:
- new_height += 1
- new_width = width // scale_factor**2
- if width % scale_factor**2 != 0:
- new_width += 1
- return new_height * scale_factor, new_width * scale_factor
-
-
-def prepare_image(pil_image):
- arr = np.array(pil_image.convert("RGB"))
- arr = arr.astype(np.float32) / 127.5 - 1
- arr = np.transpose(arr, [2, 0, 1])
- image = torch.from_numpy(arr).unsqueeze(0)
- return image
-
-
class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
model_cpu_offload_seq = "text_encoder->movq->unet->movq"
_callback_tensor_inputs = [
@@ -94,6 +76,14 @@ def __init__(
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, movq=movq
)
+ movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
+ movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
+ self.image_processor = VaeImageProcessor(
+ vae_scale_factor=movq_scale_factor,
+ vae_latent_channels=movq_latent_channels,
+ resample="bicubic",
+ reducing_gap=1,
+ )
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
@@ -123,7 +113,7 @@ def encode_prompt(
negative_prompt=None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
- _cut_context=False,
+ _cut_context=True,
attention_mask: Optional[torch.Tensor] = None,
negative_attention_mask: Optional[torch.Tensor] = None,
):
@@ -309,7 +299,7 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -449,11 +439,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 3.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -566,7 +556,7 @@ def __call__(
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
)
- image = torch.cat([prepare_image(i) for i in image], dim=0)
+ image = torch.cat([self.image_processor.preprocess(i) for i in image], dim=0)
image = image.to(dtype=prompt_embeds.dtype, device=device)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -630,20 +620,9 @@ def __call__(
xm.mark_step()
# post-processing
- if output_type not in ["pt", "np", "pil", "latent"]:
- raise ValueError(
- f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}"
- )
if not output_type == "latent":
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
-
- if output_type in ["np", "pil"]:
- image = image * 0.5 + 0.5
- image = image.clamp(0, 1)
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
- if output_type == "pil":
- image = self.numpy_to_pil(image)
+ image = self.image_processor.postprocess(image, output_type)
else:
image = latents
diff --git a/src/diffusers/pipelines/kandinsky5/__init__.py b/src/diffusers/pipelines/kandinsky5/__init__.py
new file mode 100644
index 000000000000..d417ed932b92
--- /dev/null
+++ b/src/diffusers/pipelines/kandinsky5/__init__.py
@@ -0,0 +1,54 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_kandinsky"] = ["Kandinsky5T2VPipeline"]
+ _import_structure["pipeline_kandinsky_i2i"] = ["Kandinsky5I2IPipeline"]
+ _import_structure["pipeline_kandinsky_i2v"] = ["Kandinsky5I2VPipeline"]
+ _import_structure["pipeline_kandinsky_t2i"] = ["Kandinsky5T2IPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_kandinsky import Kandinsky5T2VPipeline
+ from .pipeline_kandinsky_i2i import Kandinsky5I2IPipeline
+ from .pipeline_kandinsky_i2v import Kandinsky5I2VPipeline
+ from .pipeline_kandinsky_t2i import Kandinsky5T2IPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
new file mode 100644
index 000000000000..2b666f0ec697
--- /dev/null
+++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
@@ -0,0 +1,970 @@
+# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from typing import Callable, Dict, List, Optional, Union
+
+import regex as re
+import torch
+from torch.nn import functional as F
+from transformers import CLIPTextModel, CLIPTokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import KandinskyLoraLoaderMixin
+from ...models import AutoencoderKLHunyuanVideo
+from ...models.transformers import Kandinsky5Transformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+
+# Add imports for offloading and tiling
+from ...utils import (
+ is_ftfy_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+)
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import KandinskyPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+logger = logging.get_logger(__name__)
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+
+ ```python
+ >>> import torch
+ >>> from diffusers import Kandinsky5T2VPipeline
+ >>> from diffusers.utils import export_to_video
+
+ >>> # Available models:
+ >>> # kandinskylab/Kandinsky-5.0-T2V-Pro-sft-5s-Diffusers
+ >>> # kandinskylab/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers
+ >>> # kandinskylab/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers
+ >>> # kandinskylab/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers
+ >>> # kandinskylab/Kandinsky-5.0-T2V-Lite-pretrain-5s-Diffusers
+ >>> # kandinskylab/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers
+ >>> # kandinskylab/Kandinsky-5.0-T2V-Lite-nocfg-10s-Diffusers
+ >>> # kandinskylab/Kandinsky-5.0-T2V-Lite-distilled16steps-10s-Diffusers
+ >>> # kandinskylab/Kandinsky-5.0-T2V-Lite-pretrain-10s-Diffusers
+
+ >>> model_id = "kandinskylab/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers"
+ >>> pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "A cat and a dog baking a cake together in a kitchen."
+ >>> negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
+
+ >>> output = pipe(
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... height=512,
+ ... width=768,
+ ... num_frames=121,
+ ... num_inference_steps=50,
+ ... guidance_scale=5.0,
+ ... ).frames[0]
+
+ >>> export_to_video(output, "output.mp4", fps=24, quality=9)
+ ```
+"""
+
+
+def basic_clean(text):
+ """
+ Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py
+
+ Clean text using ftfy if available and unescape HTML entities.
+ """
+ if is_ftfy_available():
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ """
+ Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py
+
+ Normalize whitespace in text by replacing multiple spaces with single space.
+ """
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ """
+ Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py
+
+ Apply both basic cleaning and whitespace normalization to prompts.
+ """
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
+ r"""
+ Pipeline for text-to-video generation using Kandinsky 5.0.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ transformer ([`Kandinsky5Transformer3DModel`]):
+ Conditional Transformer to denoise the encoded video latents.
+ vae ([`AutoencoderKLHunyuanVideo`]):
+ Variational Auto-Encoder Model [hunyuanvideo-community/HunyuanVideo
+ (vae)](https://huggingface.co/hunyuanvideo-community/HunyuanVideo) to encode and decode videos to and from
+ latent representations.
+ text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
+ Frozen text-encoder [Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct).
+ tokenizer ([`AutoProcessor`]):
+ Tokenizer for Qwen2.5-VL.
+ text_encoder_2 ([`CLIPTextModel`]):
+ Frozen [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer_2 ([`CLIPTokenizer`]):
+ Tokenizer for CLIP.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds_qwen",
+ "prompt_embeds_clip",
+ "negative_prompt_embeds_qwen",
+ "negative_prompt_embeds_clip",
+ ]
+
+ def __init__(
+ self,
+ transformer: Kandinsky5Transformer3DModel,
+ vae: AutoencoderKLHunyuanVideo,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2VLProcessor,
+ text_encoder_2: CLIPTextModel,
+ tokenizer_2: CLIPTokenizer,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ transformer=transformer,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ scheduler=scheduler,
+ )
+
+ self.prompt_template = "\n".join(
+ [
+ "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.",
+ "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.",
+ "Describe the location of the video, main characters or objects and their action.",
+ "Describe the dynamism of the video and presented actions.",
+ "Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.",
+ "Describe the visual effects, postprocessing and transitions if they are presented in the video.",
+ "Pay attention to the order of key actions shown in the scene.<|im_end|>",
+ "<|im_start|>user\n{}<|im_end|>",
+ ]
+ )
+ self.prompt_template_encode_start_idx = 129
+
+ self.vae_scale_factor_temporal = (
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
+ )
+ self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ def _get_scale_factor(self, height: int, width: int) -> tuple:
+ """
+ Calculate the scale factor based on resolution.
+
+ Args:
+ height (int): Video height
+ width (int): Video width
+
+ Returns:
+ tuple: Scale factor as (temporal_scale, height_scale, width_scale)
+ """
+
+ def between_480p(x):
+ return 480 <= x <= 854
+
+ if between_480p(height) and between_480p(width):
+ return (1, 2, 2)
+ else:
+ return (1, 3.16, 3.16)
+
+ @staticmethod
+ def fast_sta_nabla(T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda") -> torch.Tensor:
+ """
+ Create a sparse temporal attention (STA) mask for efficient video generation.
+
+ This method generates a mask that limits attention to nearby frames and spatial positions, reducing
+ computational complexity for video generation.
+
+ Args:
+ T (int): Number of temporal frames
+ H (int): Height in latent space
+ W (int): Width in latent space
+ wT (int): Temporal attention window size
+ wH (int): Height attention window size
+ wW (int): Width attention window size
+ device (str): Device to create tensor on
+
+ Returns:
+ torch.Tensor: Sparse attention mask of shape (T*H*W, T*H*W)
+ """
+ l = torch.Tensor([T, H, W]).amax()
+ r = torch.arange(0, l, 1, dtype=torch.int16, device=device)
+ mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
+ sta_t, sta_h, sta_w = (
+ mat[:T, :T].flatten(),
+ mat[:H, :H].flatten(),
+ mat[:W, :W].flatten(),
+ )
+ sta_t = sta_t <= wT // 2
+ sta_h = sta_h <= wH // 2
+ sta_w = sta_w <= wW // 2
+ sta_hw = (sta_h.unsqueeze(1) * sta_w.unsqueeze(0)).reshape(H, H, W, W).transpose(1, 2).flatten()
+ sta = (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0)).reshape(T, T, H * W, H * W).transpose(1, 2)
+ return sta.reshape(T * H * W, T * H * W)
+
+ def get_sparse_params(self, sample, device):
+ """
+ Generate sparse attention parameters for the transformer based on sample dimensions.
+
+ This method computes the sparse attention configuration needed for efficient video processing in the
+ transformer model.
+
+ Args:
+ sample (torch.Tensor): Input sample tensor
+ device (torch.device): Device to place tensors on
+
+ Returns:
+ Dict: Dictionary containing sparse attention parameters
+ """
+ assert self.transformer.config.patch_size[0] == 1
+ B, T, H, W, _ = sample.shape
+ T, H, W = (
+ T // self.transformer.config.patch_size[0],
+ H // self.transformer.config.patch_size[1],
+ W // self.transformer.config.patch_size[2],
+ )
+ if self.transformer.config.attention_type == "nabla":
+ sta_mask = self.fast_sta_nabla(
+ T,
+ H // 8,
+ W // 8,
+ self.transformer.config.attention_wT,
+ self.transformer.config.attention_wH,
+ self.transformer.config.attention_wW,
+ device=device,
+ )
+
+ sparse_params = {
+ "sta_mask": sta_mask.unsqueeze_(0).unsqueeze_(0),
+ "attention_type": self.transformer.config.attention_type,
+ "to_fractal": True,
+ "P": self.transformer.config.attention_P,
+ "wT": self.transformer.config.attention_wT,
+ "wW": self.transformer.config.attention_wW,
+ "wH": self.transformer.config.attention_wH,
+ "add_sta": self.transformer.config.attention_add_sta,
+ "visual_shape": (T, H, W),
+ "method": self.transformer.config.attention_method,
+ }
+ else:
+ sparse_params = None
+
+ return sparse_params
+
+ def _encode_prompt_qwen(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ max_sequence_length: int = 256,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ """
+ Encode prompt using Qwen2.5-VL text encoder.
+
+ This method processes the input prompt through the Qwen2.5-VL model to generate text embeddings suitable for
+ video generation.
+
+ Args:
+ prompt (Union[str, List[str]]): Input prompt or list of prompts
+ device (torch.device): Device to run encoding on
+ num_videos_per_prompt (int): Number of videos to generate per prompt
+ max_sequence_length (int): Maximum sequence length for tokenization
+ dtype (torch.dtype): Data type for embeddings
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Text embeddings and cumulative sequence lengths
+ """
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ full_texts = [self.prompt_template.format(p) for p in prompt]
+ max_allowed_len = self.prompt_template_encode_start_idx + max_sequence_length
+
+ untruncated_ids = self.tokenizer(
+ text=full_texts,
+ images=None,
+ videos=None,
+ return_tensors="pt",
+ padding="longest",
+ )["input_ids"]
+
+ if untruncated_ids.shape[-1] > max_allowed_len:
+ for i, text in enumerate(full_texts):
+ tokens = untruncated_ids[i][self.prompt_template_encode_start_idx : -2]
+ removed_text = self.tokenizer.decode(tokens[max_sequence_length - 2 :])
+ if len(removed_text) > 0:
+ full_texts[i] = text[: -len(removed_text)]
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ inputs = self.tokenizer(
+ text=full_texts,
+ images=None,
+ videos=None,
+ max_length=max_allowed_len,
+ truncation=True,
+ return_tensors="pt",
+ padding=True,
+ ).to(device)
+
+ embeds = self.text_encoder(
+ input_ids=inputs["input_ids"],
+ return_dict=True,
+ output_hidden_states=True,
+ )["hidden_states"][-1][:, self.prompt_template_encode_start_idx :]
+
+ attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx :]
+ cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32)
+
+ return embeds.to(dtype), cu_seqlens
+
+ def _encode_prompt_clip(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ """
+ Encode prompt using CLIP text encoder.
+
+ This method processes the input prompt through the CLIP model to generate pooled embeddings that capture
+ semantic information.
+
+ Args:
+ prompt (Union[str, List[str]]): Input prompt or list of prompts
+ device (torch.device): Device to run encoding on
+ num_videos_per_prompt (int): Number of videos to generate per prompt
+ dtype (torch.dtype): Data type for embeddings
+
+ Returns:
+ torch.Tensor: Pooled text embeddings from CLIP
+ """
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder_2.dtype
+
+ inputs = self.tokenizer_2(
+ prompt,
+ max_length=77,
+ truncation=True,
+ add_special_tokens=True,
+ padding="max_length",
+ return_tensors="pt",
+ ).to(device)
+
+ pooled_embed = self.text_encoder_2(**inputs)["pooler_output"]
+
+ return pooled_embed.to(dtype)
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes a single prompt (positive or negative) into text encoder hidden states.
+
+ This method combines embeddings from both Qwen2.5-VL and CLIP text encoders to create comprehensive text
+ representations for video generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ Prompt to be encoded.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos to generate per prompt.
+ max_sequence_length (`int`, *optional*, defaults to 512):
+ Maximum sequence length for text encoding.
+ device (`torch.device`, *optional*):
+ Torch device.
+ dtype (`torch.dtype`, *optional*):
+ Torch dtype.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ - Qwen text embeddings of shape (batch_size * num_videos_per_prompt, sequence_length, embedding_dim)
+ - CLIP pooled embeddings of shape (batch_size * num_videos_per_prompt, clip_embedding_dim)
+ - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings of shape (batch_size *
+ num_videos_per_prompt + 1,)
+ """
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ if not isinstance(prompt, list):
+ prompt = [prompt]
+
+ batch_size = len(prompt)
+
+ prompt = [prompt_clean(p) for p in prompt]
+
+ # Encode with Qwen2.5-VL
+ prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen(
+ prompt=prompt,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ dtype=dtype,
+ )
+ # prompt_embeds_qwen shape: [batch_size, seq_len, embed_dim]
+
+ # Encode with CLIP
+ prompt_embeds_clip = self._encode_prompt_clip(
+ prompt=prompt,
+ device=device,
+ dtype=dtype,
+ )
+ # prompt_embeds_clip shape: [batch_size, clip_embed_dim]
+
+ # Repeat embeddings for num_videos_per_prompt
+ # Qwen embeddings: repeat sequence for each video, then reshape
+ prompt_embeds_qwen = prompt_embeds_qwen.repeat(
+ 1, num_videos_per_prompt, 1
+ ) # [batch_size, seq_len * num_videos_per_prompt, embed_dim]
+ # Reshape to [batch_size * num_videos_per_prompt, seq_len, embed_dim]
+ prompt_embeds_qwen = prompt_embeds_qwen.view(
+ batch_size * num_videos_per_prompt, -1, prompt_embeds_qwen.shape[-1]
+ )
+
+ # CLIP embeddings: repeat for each video
+ prompt_embeds_clip = prompt_embeds_clip.repeat(
+ 1, num_videos_per_prompt, 1
+ ) # [batch_size, num_videos_per_prompt, clip_embed_dim]
+ # Reshape to [batch_size * num_videos_per_prompt, clip_embed_dim]
+ prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_videos_per_prompt, -1)
+
+ # Repeat cumulative sequence lengths for num_videos_per_prompt
+ # Original cu_seqlens: [0, len1, len1+len2, ...]
+ # Need to repeat the differences and reconstruct for repeated prompts
+ # Original differences (lengths) for each prompt in the batch
+ original_lengths = prompt_cu_seqlens.diff() # [len1, len2, ...]
+ # Repeat the lengths for num_videos_per_prompt
+ repeated_lengths = original_lengths.repeat_interleave(
+ num_videos_per_prompt
+ ) # [len1, len1, ..., len2, len2, ...]
+ # Reconstruct the cumulative lengths
+ repeated_cu_seqlens = torch.cat(
+ [torch.tensor([0], device=device, dtype=torch.int32), repeated_lengths.cumsum(0)]
+ )
+
+ return prompt_embeds_qwen, prompt_embeds_clip, repeated_cu_seqlens
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds_qwen=None,
+ prompt_embeds_clip=None,
+ negative_prompt_embeds_qwen=None,
+ negative_prompt_embeds_clip=None,
+ prompt_cu_seqlens=None,
+ negative_prompt_cu_seqlens=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ """
+ Validate input parameters for the pipeline.
+
+ Args:
+ prompt: Input prompt
+ negative_prompt: Negative prompt for guidance
+ height: Video height
+ width: Video width
+ prompt_embeds_qwen: Pre-computed Qwen prompt embeddings
+ prompt_embeds_clip: Pre-computed CLIP prompt embeddings
+ negative_prompt_embeds_qwen: Pre-computed Qwen negative prompt embeddings
+ negative_prompt_embeds_clip: Pre-computed CLIP negative prompt embeddings
+ prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen positive prompt
+ negative_prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen negative prompt
+ callback_on_step_end_tensor_inputs: Callback tensor inputs
+
+ Raises:
+ ValueError: If inputs are invalid
+ """
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError("max_sequence_length must be less than 1024")
+
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ # Check for consistency within positive prompt embeddings and sequence lengths
+ if prompt_embeds_qwen is not None or prompt_embeds_clip is not None or prompt_cu_seqlens is not None:
+ if prompt_embeds_qwen is None or prompt_embeds_clip is None or prompt_cu_seqlens is None:
+ raise ValueError(
+ "If any of `prompt_embeds_qwen`, `prompt_embeds_clip`, or `prompt_cu_seqlens` is provided, "
+ "all three must be provided."
+ )
+
+ # Check for consistency within negative prompt embeddings and sequence lengths
+ if (
+ negative_prompt_embeds_qwen is not None
+ or negative_prompt_embeds_clip is not None
+ or negative_prompt_cu_seqlens is not None
+ ):
+ if (
+ negative_prompt_embeds_qwen is None
+ or negative_prompt_embeds_clip is None
+ or negative_prompt_cu_seqlens is None
+ ):
+ raise ValueError(
+ "If any of `negative_prompt_embeds_qwen`, `negative_prompt_embeds_clip`, or `negative_prompt_cu_seqlens` is provided, "
+ "all three must be provided."
+ )
+
+ # Check if prompt or embeddings are provided (either prompt or all required embedding components for positive)
+ if prompt is None and prompt_embeds_qwen is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds_qwen` (and corresponding `prompt_embeds_clip` and `prompt_cu_seqlens`). Cannot leave all undefined."
+ )
+
+ # Validate types for prompt and negative_prompt if provided
+ if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ if negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ Prepare initial latent variables for video generation.
+
+ This method creates random noise latents or uses provided latents as starting point for the denoising process.
+
+ Args:
+ batch_size (int): Number of videos to generate
+ num_channels_latents (int): Number of channels in latent space
+ height (int): Height of generated video
+ width (int): Width of generated video
+ num_frames (int): Number of frames in video
+ dtype (torch.dtype): Data type for latents
+ device (torch.device): Device to create latents on
+ generator (torch.Generator): Random number generator
+ latents (torch.Tensor): Pre-existing latents to use
+
+ Returns:
+ torch.Tensor: Prepared latent tensor
+ """
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ shape = (
+ batch_size,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ num_channels_latents,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ if self.transformer.visual_cond:
+ # For visual conditioning, concatenate with zeros and mask
+ visual_cond = torch.zeros_like(latents)
+ visual_cond_mask = torch.zeros(
+ [
+ batch_size,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ 1,
+ ],
+ dtype=latents.dtype,
+ device=latents.device,
+ )
+ latents = torch.cat([latents, visual_cond, visual_cond_mask], dim=-1)
+
+ return latents
+
+ @property
+ def guidance_scale(self):
+ """Get the current guidance scale value."""
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self):
+ """Get the number of denoising timesteps."""
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ """Check if generation has been interrupted."""
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 512,
+ width: int = 768,
+ num_frames: int = 121,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds_qwen: Optional[torch.Tensor] = None,
+ prompt_embeds_clip: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_qwen: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_clip: Optional[torch.Tensor] = None,
+ prompt_cu_seqlens: Optional[torch.Tensor] = None,
+ negative_prompt_cu_seqlens: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (`guidance_scale` < `1`).
+ height (`int`, defaults to `512`):
+ The height in pixels of the generated video.
+ width (`int`, defaults to `768`):
+ The width in pixels of the generated video.
+ num_frames (`int`, defaults to `25`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in classifier-free guidance.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A torch generator to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated video.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`KandinskyPipelineOutput`].
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function that is called at the end of each denoising step.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum sequence length for text encoding.
+
+ Examples:
+
+ Returns:
+ [`~KandinskyPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`KandinskyPipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated images.
+ """
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=height,
+ width=width,
+ prompt_embeds_qwen=prompt_embeds_qwen,
+ prompt_embeds_clip=prompt_embeds_clip,
+ negative_prompt_embeds_qwen=negative_prompt_embeds_qwen,
+ negative_prompt_embeds_clip=negative_prompt_embeds_clip,
+ prompt_cu_seqlens=prompt_cu_seqlens,
+ negative_prompt_cu_seqlens=negative_prompt_cu_seqlens,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._interrupt = False
+
+ device = self._execution_device
+ dtype = self.transformer.dtype
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ prompt = [prompt]
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds_qwen.shape[0]
+
+ # 3. Encode input prompt
+ if prompt_embeds_qwen is None:
+ prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self.encode_prompt(
+ prompt=prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if self.guidance_scale > 1.0:
+ if negative_prompt is None:
+ negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
+
+ if isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt] * len(prompt) if prompt is not None else [negative_prompt]
+ elif len(negative_prompt) != len(prompt):
+ raise ValueError(
+ f"`negative_prompt` must have same length as `prompt`. Got {len(negative_prompt)} vs {len(prompt)}."
+ )
+
+ if negative_prompt_embeds_qwen is None:
+ negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_prompt_cu_seqlens = (
+ self.encode_prompt(
+ prompt=negative_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_visual_dim
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare rope positions for positional encoding
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ visual_rope_pos = [
+ torch.arange(num_latent_frames, device=device),
+ torch.arange(height // self.vae_scale_factor_spatial // 2, device=device),
+ torch.arange(width // self.vae_scale_factor_spatial // 2, device=device),
+ ]
+
+ text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device)
+
+ negative_text_rope_pos = (
+ torch.arange(negative_prompt_cu_seqlens.diff().max().item(), device=device)
+ if negative_prompt_cu_seqlens is not None
+ else None
+ )
+
+ # 7. Calculate dynamic scale factor based on resolution
+ scale_factor = self._get_scale_factor(height, width)
+
+ # 8. Sparse Params for efficient attention
+ sparse_params = self.get_sparse_params(latents, device)
+
+ # 9. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ timestep = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt)
+
+ # Predict noise residual
+ pred_velocity = self.transformer(
+ hidden_states=latents.to(dtype),
+ encoder_hidden_states=prompt_embeds_qwen.to(dtype),
+ pooled_projections=prompt_embeds_clip.to(dtype),
+ timestep=timestep.to(dtype),
+ visual_rope_pos=visual_rope_pos,
+ text_rope_pos=text_rope_pos,
+ scale_factor=scale_factor,
+ sparse_params=sparse_params,
+ return_dict=True,
+ ).sample
+
+ if self.guidance_scale > 1.0 and negative_prompt_embeds_qwen is not None:
+ uncond_pred_velocity = self.transformer(
+ hidden_states=latents.to(dtype),
+ encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype),
+ pooled_projections=negative_prompt_embeds_clip.to(dtype),
+ timestep=timestep.to(dtype),
+ visual_rope_pos=visual_rope_pos,
+ text_rope_pos=negative_text_rope_pos,
+ scale_factor=scale_factor,
+ sparse_params=sparse_params,
+ return_dict=True,
+ ).sample
+
+ pred_velocity = uncond_pred_velocity + guidance_scale * (pred_velocity - uncond_pred_velocity)
+ # Compute previous sample using the scheduler
+ latents[:, :, :, :, :num_channels_latents] = self.scheduler.step(
+ pred_velocity, t, latents[:, :, :, :, :num_channels_latents], return_dict=False
+ )[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds_qwen = callback_outputs.pop("prompt_embeds_qwen", prompt_embeds_qwen)
+ prompt_embeds_clip = callback_outputs.pop("prompt_embeds_clip", prompt_embeds_clip)
+ negative_prompt_embeds_qwen = callback_outputs.pop(
+ "negative_prompt_embeds_qwen", negative_prompt_embeds_qwen
+ )
+ negative_prompt_embeds_clip = callback_outputs.pop(
+ "negative_prompt_embeds_clip", negative_prompt_embeds_clip
+ )
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ # 10. Post-processing - extract main latents
+ latents = latents[:, :, :, :, :num_channels_latents]
+
+ # 11. Decode latents to video
+ if output_type != "latent":
+ latents = latents.to(self.vae.dtype)
+ # Reshape and normalize latents
+ video = latents.reshape(
+ batch_size,
+ num_videos_per_prompt,
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ num_channels_latents,
+ )
+ video = video.permute(0, 1, 5, 2, 3, 4) # [batch, num_videos, channels, frames, height, width]
+ video = video.reshape(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+
+ # Normalize and decode through VAE
+ video = video / self.vae.config.scaling_factor
+ video = self.vae.decode(video).sample
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return KandinskyPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_i2i.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_i2i.py
new file mode 100644
index 000000000000..f965cdad8f3e
--- /dev/null
+++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_i2i.py
@@ -0,0 +1,863 @@
+# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from typing import Callable, Dict, List, Optional, Union
+
+import numpy as np
+import regex as re
+import torch
+from torch.nn import functional as F
+from transformers import CLIPTextModel, CLIPTokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import KandinskyLoraLoaderMixin
+from ...models import AutoencoderKL
+from ...models.transformers import Kandinsky5Transformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+
+# Add imports for offloading and tiling
+from ...utils import (
+ is_ftfy_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import KandinskyImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+logger = logging.get_logger(__name__)
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+
+ ```python
+ >>> import torch
+ >>> from diffusers import Kandinsky5I2IPipeline
+
+ >>> # Available models:
+ >>> # kandinskylab/Kandinsky-5.0-I2I-Lite-sft-Diffusers
+ >>> # kandinskylab/Kandinsky-5.0-I2I-Lite-pretrain-Diffusers
+
+ >>> model_id = "kandinskylab/Kandinsky-5.0-I2I-Lite-sft-Diffusers"
+ >>> pipe = Kandinsky5I2IPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "A cat and a dog baking a cake together in a kitchen."
+
+ >>> output = pipe(
+ ... prompt=prompt,
+ ... negative_prompt="",
+ ... height=1024,
+ ... width=1024,
+ ... num_inference_steps=50,
+ ... guidance_scale=3.5,
+ ... ).frames[0]
+ ```
+"""
+
+
+def basic_clean(text):
+ """
+ Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py
+
+ Clean text using ftfy if available and unescape HTML entities.
+ """
+ if is_ftfy_available():
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ """
+ Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py
+
+ Normalize whitespace in text by replacing multiple spaces with single space.
+ """
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ """
+ Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py
+
+ Apply both basic cleaning and whitespace normalization to prompts.
+ """
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+class Kandinsky5I2IPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
+ r"""
+ Pipeline for image-to-image generation using Kandinsky 5.0.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ transformer ([`Kandinsky5Transformer3DModel`]):
+ Conditional Transformer to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder Model [black-forest-labs/FLUX.1-dev
+ (vae)](https://huggingface.co/black-forest-labs/FLUX.1-dev) to encode and decode videos to and from latent
+ representations.
+ text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
+ Frozen text-encoder [Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct).
+ tokenizer ([`AutoProcessor`]):
+ Tokenizer for Qwen2.5-VL.
+ text_encoder_2 ([`CLIPTextModel`]):
+ Frozen [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer_2 ([`CLIPTokenizer`]):
+ Tokenizer for CLIP.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds_qwen",
+ "prompt_embeds_clip",
+ "negative_prompt_embeds_qwen",
+ "negative_prompt_embeds_clip",
+ ]
+
+ def __init__(
+ self,
+ transformer: Kandinsky5Transformer3DModel,
+ vae: AutoencoderKL,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2VLProcessor,
+ text_encoder_2: CLIPTextModel,
+ tokenizer_2: CLIPTokenizer,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ transformer=transformer,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ scheduler=scheduler,
+ )
+ self.prompt_template = "<|im_start|>system\nYou are a promt engineer. Based on the provided source image (first image) and target image (second image), create an interesting text prompt that can be used together with the source image to create the target image:<|im_end|><|im_start|>user{}<|vision_start|><|image_pad|><|vision_end|><|im_end|>"
+ self.prompt_template_encode_start_idx = 55
+
+ self.vae_scale_factor_spatial = 8
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+ self.resolutions = [(1024, 1024), (640, 1408), (1408, 640), (768, 1280), (1280, 768), (896, 1152), (1152, 896)]
+
+ def _encode_prompt_qwen(
+ self,
+ prompt: List[str],
+ image: Optional[PipelineImageInput] = None,
+ device: Optional[torch.device] = None,
+ max_sequence_length: int = 1024,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ """
+ Encode prompt using Qwen2.5-VL text encoder.
+
+ This method processes the input prompt through the Qwen2.5-VL model to generate text embeddings suitable for
+ image generation.
+
+ Args:
+ prompt List[str]: Input list of prompts
+ image (PipelineImageInput): Input list of images to condition the generation on
+ device (torch.device): Device to run encoding on
+ max_sequence_length (int): Maximum sequence length for tokenization
+ dtype (torch.dtype): Data type for embeddings
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Text embeddings and cumulative sequence lengths
+ """
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+ if not isinstance(image, list):
+ image = [image]
+ image = [i.resize((i.size[0] // 2, i.size[1] // 2)) for i in image]
+ full_texts = [self.prompt_template.format(p) for p in prompt]
+ max_allowed_len = self.prompt_template_encode_start_idx + max_sequence_length
+
+ untruncated_ids = self.tokenizer(
+ text=full_texts,
+ images=image,
+ videos=None,
+ return_tensors="pt",
+ padding="longest",
+ )["input_ids"]
+
+ if untruncated_ids.shape[-1] > max_allowed_len:
+ for i, text in enumerate(full_texts):
+ tokens = untruncated_ids[i]
+ num_image_tokens = (tokens == self.tokenizer.image_token_id).sum()
+ tokens = tokens[tokens != self.tokenizer.image_token_id][self.prompt_template_encode_start_idx : -3]
+ removed_text = self.tokenizer.decode(tokens[max_sequence_length - num_image_tokens - 3 :])
+ if len(removed_text) > 0:
+ full_texts[i] = text[: -len(removed_text)]
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ inputs = self.tokenizer(
+ text=full_texts,
+ images=image,
+ videos=None,
+ max_length=max_allowed_len,
+ truncation=True,
+ return_tensors="pt",
+ padding=True,
+ ).to(device)
+
+ embeds = self.text_encoder(
+ **inputs,
+ return_dict=True,
+ output_hidden_states=True,
+ )["hidden_states"][-1][:, self.prompt_template_encode_start_idx :]
+
+ attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx :]
+ cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32)
+
+ return embeds.to(dtype), cu_seqlens
+
+ def _encode_prompt_clip(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ """
+ Encode prompt using CLIP text encoder.
+
+ This method processes the input prompt through the CLIP model to generate pooled embeddings that capture
+ semantic information.
+
+ Args:
+ prompt (Union[str, List[str]]): Input prompt or list of prompts
+ device (torch.device): Device to run encoding on
+ dtype (torch.dtype): Data type for embeddings
+
+ Returns:
+ torch.Tensor: Pooled text embeddings from CLIP
+ """
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder_2.dtype
+
+ inputs = self.tokenizer_2(
+ prompt,
+ max_length=77,
+ truncation=True,
+ add_special_tokens=True,
+ padding="max_length",
+ return_tensors="pt",
+ ).to(device)
+
+ pooled_embed = self.text_encoder_2(**inputs)["pooler_output"]
+
+ return pooled_embed.to(dtype)
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ image: torch.Tensor,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 1024,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes a single prompt (positive or negative) into text encoder hidden states.
+
+ This method combines embeddings from both Qwen2.5-VL and CLIP text encoders to create comprehensive text
+ representations for image generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ Prompt to be encoded.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ Number of images to generate per prompt.
+ max_sequence_length (`int`, *optional*, defaults to 1024):
+ Maximum sequence length for text encoding. Must be less than 1024
+ device (`torch.device`, *optional*):
+ Torch device.
+ dtype (`torch.dtype`, *optional*):
+ Torch dtype.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ - Qwen text embeddings of shape (batch_size * num_images_per_prompt, sequence_length, embedding_dim)
+ - CLIP pooled embeddings of shape (batch_size * num_images_per_prompt, clip_embedding_dim)
+ - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings of shape (batch_size *
+ num_images_per_prompt + 1,)
+ """
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ if not isinstance(prompt, list):
+ prompt = [prompt]
+
+ batch_size = len(prompt)
+
+ prompt = [prompt_clean(p) for p in prompt]
+
+ # Encode with Qwen2.5-VL
+ prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen(
+ prompt=prompt,
+ image=image,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ dtype=dtype,
+ )
+ # prompt_embeds_qwen shape: [batch_size, seq_len, embed_dim]
+
+ # Encode with CLIP
+ prompt_embeds_clip = self._encode_prompt_clip(
+ prompt=prompt,
+ device=device,
+ dtype=dtype,
+ )
+ # prompt_embeds_clip shape: [batch_size, clip_embed_dim]
+
+ # Repeat embeddings for num_images_per_prompt
+ # Qwen embeddings: repeat sequence for each image, then reshape
+ prompt_embeds_qwen = prompt_embeds_qwen.repeat(
+ 1, num_images_per_prompt, 1
+ ) # [batch_size, seq_len * num_images_per_prompt, embed_dim]
+ # Reshape to [batch_size * num_images_per_prompt, seq_len, embed_dim]
+ prompt_embeds_qwen = prompt_embeds_qwen.view(
+ batch_size * num_images_per_prompt, -1, prompt_embeds_qwen.shape[-1]
+ )
+
+ # CLIP embeddings: repeat for each image
+ prompt_embeds_clip = prompt_embeds_clip.repeat(
+ 1, num_images_per_prompt, 1
+ ) # [batch_size, num_images_per_prompt, clip_embed_dim]
+ # Reshape to [batch_size * num_images_per_prompt, clip_embed_dim]
+ prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_images_per_prompt, -1)
+
+ # Repeat cumulative sequence lengths for num_images_per_prompt
+ # Original differences (lengths) for each prompt in the batch
+ original_lengths = prompt_cu_seqlens.diff() # [len1, len2, ...]
+ # Repeat the lengths for num_images_per_prompt
+ repeated_lengths = original_lengths.repeat_interleave(
+ num_images_per_prompt
+ ) # [len1, len1, ..., len2, len2, ...]
+ # Reconstruct the cumulative lengths
+ repeated_cu_seqlens = torch.cat(
+ [torch.tensor([0], device=device, dtype=torch.int32), repeated_lengths.cumsum(0)]
+ )
+
+ return prompt_embeds_qwen, prompt_embeds_clip, repeated_cu_seqlens
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ image,
+ height,
+ width,
+ prompt_embeds_qwen=None,
+ prompt_embeds_clip=None,
+ negative_prompt_embeds_qwen=None,
+ negative_prompt_embeds_clip=None,
+ prompt_cu_seqlens=None,
+ negative_prompt_cu_seqlens=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ """
+ Validate input parameters for the pipeline.
+
+ Args:
+ prompt: Input prompt
+ negative_prompt: Negative prompt for guidance
+ image: Input image for conditioning
+ height: Image height
+ width: Image width
+ prompt_embeds_qwen: Pre-computed Qwen prompt embeddings
+ prompt_embeds_clip: Pre-computed CLIP prompt embeddings
+ negative_prompt_embeds_qwen: Pre-computed Qwen negative prompt embeddings
+ negative_prompt_embeds_clip: Pre-computed CLIP negative prompt embeddings
+ prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen positive prompt
+ negative_prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen negative prompt
+ callback_on_step_end_tensor_inputs: Callback tensor inputs
+
+ Raises:
+ ValueError: If inputs are invalid
+ """
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError("max_sequence_length must be less than 1024")
+
+ if image is None:
+ raise ValueError("`image` must be provided for image-to-image generation")
+
+ if (width, height) not in self.resolutions:
+ resolutions_str = ",".join([f"({w},{h})" for w, h in self.resolutions])
+ logger.warning(
+ f"`height` and `width` have to be one of {resolutions_str}, but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ # Check for consistency within positive prompt embeddings and sequence lengths
+ if prompt_embeds_qwen is not None or prompt_embeds_clip is not None or prompt_cu_seqlens is not None:
+ if prompt_embeds_qwen is None or prompt_embeds_clip is None or prompt_cu_seqlens is None:
+ raise ValueError(
+ "If any of `prompt_embeds_qwen`, `prompt_embeds_clip`, or `prompt_cu_seqlens` is provided, "
+ "all three must be provided."
+ )
+
+ # Check for consistency within negative prompt embeddings and sequence lengths
+ if (
+ negative_prompt_embeds_qwen is not None
+ or negative_prompt_embeds_clip is not None
+ or negative_prompt_cu_seqlens is not None
+ ):
+ if (
+ negative_prompt_embeds_qwen is None
+ or negative_prompt_embeds_clip is None
+ or negative_prompt_cu_seqlens is None
+ ):
+ raise ValueError(
+ "If any of `negative_prompt_embeds_qwen`, `negative_prompt_embeds_clip`, or `negative_prompt_cu_seqlens` is provided, "
+ "all three must be provided."
+ )
+
+ # Check if prompt or embeddings are provided (either prompt or all required embedding components for positive)
+ if prompt is None and prompt_embeds_qwen is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds_qwen` (and corresponding `prompt_embeds_clip` and `prompt_cu_seqlens`). Cannot leave all undefined."
+ )
+
+ # Validate types for prompt and negative_prompt if provided
+ if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ if negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ def prepare_latents(
+ self,
+ image: PipelineImageInput,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 1024,
+ width: int = 1024,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ Prepare initial latent variables for image-to-image generation.
+
+ This method creates random noise latents with encoded image,
+
+ Args:
+ image (PipelineImageInput): Input image to condition the generation on
+ batch_size (int): Number of images to generate
+ num_channels_latents (int): Number of channels in latent space
+ height (int): Height of generated image
+ width (int): Width of generated image
+ dtype (torch.dtype): Data type for latents
+ device (torch.device): Device to create latents on
+ generator (torch.Generator): Random number generator
+ latents (torch.Tensor): Pre-existing latents to use
+
+ Returns:
+ torch.Tensor: Prepared latent tensor with encoded image
+ """
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ shape = (
+ batch_size,
+ 1,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ num_channels_latents,
+ )
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ # Generate random noise for all frames
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # Encode the input image to use as first frame
+ # Preprocess image
+ image_tensor = self.image_processor.preprocess(image, height=height, width=width).to(device, dtype=dtype)
+ # Encode image to latents using VAE
+ with torch.no_grad():
+ image_latents = self.vae.encode(image_tensor).latent_dist.sample(generator=generator)
+ image_latents = image_latents.unsqueeze(2) # Add temporal dimension
+
+ # Normalize latents if needed
+ if hasattr(self.vae.config, "scaling_factor"):
+ image_latents = image_latents * self.vae.config.scaling_factor
+
+ # Reshape to match latent dimensions [batch, 1, height, width, channels]
+ image_latents = image_latents.permute(0, 2, 3, 4, 1) # [batch, 1, H, W, C]
+ latents = torch.cat([latents, image_latents, torch.ones_like(latents[..., :1])], -1)
+
+ return latents
+
+ @property
+ def guidance_scale(self):
+ """Get the current guidance scale value."""
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self):
+ """Get the number of denoising timesteps."""
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ """Check if generation has been interrupted."""
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds_qwen: Optional[torch.Tensor] = None,
+ prompt_embeds_clip: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_qwen: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_clip: Optional[torch.Tensor] = None,
+ prompt_cu_seqlens: Optional[torch.Tensor] = None,
+ negative_prompt_cu_seqlens: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 1024,
+ ):
+ r"""
+ The call function to the pipeline for image-to-image generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (`guidance_scale` < `1`).
+ height (`int`):
+ The height in pixels of the generated image.
+ width (`int`):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in classifier-free guidance.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A torch generator to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents.
+ prompt_embeds_qwen (`torch.Tensor`, *optional*):
+ Pre-generated Qwen text embeddings.
+ prompt_embeds_clip (`torch.Tensor`, *optional*):
+ Pre-generated CLIP text embeddings.
+ negative_prompt_embeds_qwen (`torch.Tensor`, *optional*):
+ Pre-generated Qwen negative text embeddings.
+ negative_prompt_embeds_clip (`torch.Tensor`, *optional*):
+ Pre-generated CLIP negative text embeddings.
+ prompt_cu_seqlens (`torch.Tensor`, *optional*):
+ Pre-generated cumulative sequence lengths for Qwen positive prompt.
+ negative_prompt_cu_seqlens (`torch.Tensor`, *optional*):
+ Pre-generated cumulative sequence lengths for Qwen negative prompt.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`KandinskyImagePipelineOutput`].
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function that is called at the end of each denoising step.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function.
+ max_sequence_length (`int`, defaults to `1024`):
+ The maximum sequence length for text and image qwen encoding. Must be less than 1024
+
+ Examples:
+
+ Returns:
+ [`~KandinskyImagePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`KandinskyImagePipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images.
+ """
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+ # 1. Check inputs. Raise error if not correct
+ if height is None and width is None:
+ width, height = image[0].size if isinstance(image, list) else image.size
+ self.check_inputs(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image=image,
+ height=height,
+ width=width,
+ prompt_embeds_qwen=prompt_embeds_qwen,
+ prompt_embeds_clip=prompt_embeds_clip,
+ negative_prompt_embeds_qwen=negative_prompt_embeds_qwen,
+ negative_prompt_embeds_clip=negative_prompt_embeds_clip,
+ prompt_cu_seqlens=prompt_cu_seqlens,
+ negative_prompt_cu_seqlens=negative_prompt_cu_seqlens,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+ if (width, height) not in self.resolutions:
+ width, height = self.resolutions[
+ np.argmin([abs((i[0] / i[1]) - (width / height)) for i in self.resolutions])
+ ]
+
+ self._guidance_scale = guidance_scale
+ self._interrupt = False
+
+ device = self._execution_device
+ dtype = self.transformer.dtype
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ prompt = [prompt]
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds_qwen.shape[0]
+
+ # 3. Encode input prompt
+ if prompt_embeds_qwen is None:
+ prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self.encode_prompt(
+ prompt=prompt,
+ image=image,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if self.guidance_scale > 1.0:
+ if negative_prompt is None:
+ negative_prompt = ""
+
+ if isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt] * len(prompt) if prompt is not None else [negative_prompt]
+ elif len(negative_prompt) != len(prompt):
+ raise ValueError(
+ f"`negative_prompt` must have same length as `prompt`. Got {len(negative_prompt)} vs {len(prompt)}."
+ )
+
+ if negative_prompt_embeds_qwen is None:
+ negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_prompt_cu_seqlens = (
+ self.encode_prompt(
+ prompt=negative_prompt,
+ image=image,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables with image conditioning
+ num_channels_latents = self.transformer.config.in_visual_dim
+ latents = self.prepare_latents(
+ image=image,
+ batch_size=batch_size * num_images_per_prompt,
+ num_channels_latents=num_channels_latents,
+ height=height,
+ width=width,
+ dtype=dtype,
+ device=device,
+ generator=generator,
+ latents=latents,
+ )
+
+ # 6. Prepare rope positions for positional encoding
+ visual_rope_pos = [
+ torch.arange(1, device=device),
+ torch.arange(height // self.vae_scale_factor_spatial // 2, device=device),
+ torch.arange(width // self.vae_scale_factor_spatial // 2, device=device),
+ ]
+
+ text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device)
+
+ negative_text_rope_pos = (
+ torch.arange(negative_prompt_cu_seqlens.diff().max().item(), device=device)
+ if negative_prompt_cu_seqlens is not None
+ else None
+ )
+
+ # 7. Calculate dynamic scale factor based on resolution
+ scale_factor = [1.0, 1.0, 1.0]
+
+ # 8. Sparse Params for efficient attention
+ sparse_params = None
+
+ # 9. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ timestep = t.unsqueeze(0).repeat(batch_size * num_images_per_prompt)
+
+ # Predict noise residual
+ pred_velocity = self.transformer(
+ hidden_states=latents.to(dtype),
+ encoder_hidden_states=prompt_embeds_qwen.to(dtype),
+ pooled_projections=prompt_embeds_clip.to(dtype),
+ timestep=timestep.to(dtype),
+ visual_rope_pos=visual_rope_pos,
+ text_rope_pos=text_rope_pos,
+ scale_factor=scale_factor,
+ sparse_params=sparse_params,
+ return_dict=True,
+ ).sample
+
+ if self.guidance_scale > 1.0 and negative_prompt_embeds_qwen is not None:
+ uncond_pred_velocity = self.transformer(
+ hidden_states=latents.to(dtype),
+ encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype),
+ pooled_projections=negative_prompt_embeds_clip.to(dtype),
+ timestep=timestep.to(dtype),
+ visual_rope_pos=visual_rope_pos,
+ text_rope_pos=negative_text_rope_pos,
+ scale_factor=scale_factor,
+ sparse_params=sparse_params,
+ return_dict=True,
+ ).sample
+
+ pred_velocity = uncond_pred_velocity + guidance_scale * (pred_velocity - uncond_pred_velocity)
+
+ latents[:, :, :, :, :num_channels_latents] = self.scheduler.step(
+ pred_velocity[:, :], t, latents[:, :, :, :, :num_channels_latents], return_dict=False
+ )[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds_qwen = callback_outputs.pop("prompt_embeds_qwen", prompt_embeds_qwen)
+ prompt_embeds_clip = callback_outputs.pop("prompt_embeds_clip", prompt_embeds_clip)
+ negative_prompt_embeds_qwen = callback_outputs.pop(
+ "negative_prompt_embeds_qwen", negative_prompt_embeds_qwen
+ )
+ negative_prompt_embeds_clip = callback_outputs.pop(
+ "negative_prompt_embeds_clip", negative_prompt_embeds_clip
+ )
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ # 9. Post-processing - extract main latents
+ latents = latents[:, :, :, :, :num_channels_latents]
+
+ # 10. Decode latents to image
+ if output_type != "latent":
+ latents = latents.to(self.vae.dtype)
+ # Reshape and normalize latents
+ latents = latents.reshape(
+ batch_size,
+ num_images_per_prompt,
+ 1,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ num_channels_latents,
+ )
+ latents = latents.permute(0, 1, 5, 2, 3, 4) # [batch, num_images, channels, 1, height, width]
+ latents = latents.reshape(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+
+ # Normalize and decode through VAE
+ latents = latents / self.vae.config.scaling_factor
+ image = self.vae.decode(latents).sample
+ image = self.image_processor.postprocess(image, output_type=output_type)
+ else:
+ image = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return KandinskyImagePipelineOutput(image=image)
diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_i2v.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_i2v.py
new file mode 100644
index 000000000000..d457c9b69657
--- /dev/null
+++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_i2v.py
@@ -0,0 +1,1054 @@
+# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from typing import Callable, Dict, List, Optional, Union
+
+import regex as re
+import torch
+from torch.nn import functional as F
+from transformers import CLIPTextModel, CLIPTokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import KandinskyLoraLoaderMixin
+from ...models import AutoencoderKLHunyuanVideo
+from ...models.transformers import Kandinsky5Transformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+
+# Add imports for offloading and tiling
+from ...utils import (
+ is_ftfy_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+)
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import KandinskyPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+logger = logging.get_logger(__name__)
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+
+ ```python
+ >>> import torch
+ >>> from diffusers import Kandinsky5I2VPipeline
+ >>> from diffusers.utils import export_to_video, load_image
+
+ >>> # Available models:
+ >>> # kandinskylab/Kandinsky-5.0-I2V-Pro-sft-5s-Diffusers
+
+ >>> model_id = "kandinskylab/Kandinsky-5.0-I2V-Pro-sft-5s-Diffusers"
+ >>> pipe = Kandinsky5I2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+ >>> pipe = pipe.to("cuda")
+
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
+ ... )
+ >>> prompt = "An astronaut floating in space with Earth in the background, cinematic shot"
+ >>> negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
+
+ >>> output = pipe(
+ ... image=image,
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... height=512,
+ ... width=768,
+ ... num_frames=121,
+ ... num_inference_steps=50,
+ ... guidance_scale=5.0,
+ ... ).frames[0]
+
+ >>> export_to_video(output, "output.mp4", fps=24, quality=9)
+ ```
+"""
+
+
+def basic_clean(text):
+ """
+ Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py
+
+ Clean text using ftfy if available and unescape HTML entities.
+ """
+ if is_ftfy_available():
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ """
+ Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py
+
+ Normalize whitespace in text by replacing multiple spaces with single space.
+ """
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ """
+ Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py
+
+ Apply both basic cleaning and whitespace normalization to prompts.
+ """
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+class Kandinsky5I2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
+ r"""
+ Pipeline for image-to-video generation using Kandinsky 5.0.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ transformer ([`Kandinsky5Transformer3DModel`]):
+ Conditional Transformer to denoise the encoded video latents.
+ vae ([`AutoencoderKLHunyuanVideo`]):
+ Variational Auto-Encoder Model [hunyuanvideo-community/HunyuanVideo
+ (vae)](https://huggingface.co/hunyuanvideo-community/HunyuanVideo) to encode and decode videos to and from
+ latent representations.
+ text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
+ Frozen text-encoder [Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct).
+ tokenizer ([`AutoProcessor`]):
+ Tokenizer for Qwen2.5-VL.
+ text_encoder_2 ([`CLIPTextModel`]):
+ Frozen [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer_2 ([`CLIPTokenizer`]):
+ Tokenizer for CLIP.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds_qwen",
+ "prompt_embeds_clip",
+ "negative_prompt_embeds_qwen",
+ "negative_prompt_embeds_clip",
+ ]
+
+ def __init__(
+ self,
+ transformer: Kandinsky5Transformer3DModel,
+ vae: AutoencoderKLHunyuanVideo,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2VLProcessor,
+ text_encoder_2: CLIPTextModel,
+ tokenizer_2: CLIPTokenizer,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ transformer=transformer,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ scheduler=scheduler,
+ )
+
+ self.prompt_template = "\n".join(
+ [
+ "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.",
+ "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.",
+ "Describe the location of the video, main characters or objects and their action.",
+ "Describe the dynamism of the video and presented actions.",
+ "Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.",
+ "Describe the visual effects, postprocessing and transitions if they are presented in the video.",
+ "Pay attention to the order of key actions shown in the scene.<|im_end|>",
+ "<|im_start|>user\n{}<|im_end|>",
+ ]
+ )
+ self.prompt_template_encode_start_idx = 129
+
+ self.vae_scale_factor_temporal = (
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
+ )
+ self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ def _get_scale_factor(self, height: int, width: int) -> tuple:
+ """
+ Calculate the scale factor based on resolution.
+
+ Args:
+ height (int): Video height
+ width (int): Video width
+
+ Returns:
+ tuple: Scale factor as (temporal_scale, height_scale, width_scale)
+ """
+
+ def between_480p(x):
+ return 480 <= x <= 854
+
+ if between_480p(height) and between_480p(width):
+ return (1, 2, 2)
+ else:
+ return (1, 3.16, 3.16)
+
+ @staticmethod
+ def fast_sta_nabla(T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda") -> torch.Tensor:
+ """
+ Create a sparse temporal attention (STA) mask for efficient video generation.
+
+ This method generates a mask that limits attention to nearby frames and spatial positions, reducing
+ computational complexity for video generation.
+
+ Args:
+ T (int): Number of temporal frames
+ H (int): Height in latent space
+ W (int): Width in latent space
+ wT (int): Temporal attention window size
+ wH (int): Height attention window size
+ wW (int): Width attention window size
+ device (str): Device to create tensor on
+
+ Returns:
+ torch.Tensor: Sparse attention mask of shape (T*H*W, T*H*W)
+ """
+ l = torch.Tensor([T, H, W]).amax()
+ r = torch.arange(0, l, 1, dtype=torch.int16, device=device)
+ mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
+ sta_t, sta_h, sta_w = (
+ mat[:T, :T].flatten(),
+ mat[:H, :H].flatten(),
+ mat[:W, :W].flatten(),
+ )
+ sta_t = sta_t <= wT // 2
+ sta_h = sta_h <= wH // 2
+ sta_w = sta_w <= wW // 2
+ sta_hw = (sta_h.unsqueeze(1) * sta_w.unsqueeze(0)).reshape(H, H, W, W).transpose(1, 2).flatten()
+ sta = (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0)).reshape(T, T, H * W, H * W).transpose(1, 2)
+ return sta.reshape(T * H * W, T * H * W)
+
+ def get_sparse_params(self, sample, device):
+ """
+ Generate sparse attention parameters for the transformer based on sample dimensions.
+
+ This method computes the sparse attention configuration needed for efficient video processing in the
+ transformer model.
+
+ Args:
+ sample (torch.Tensor): Input sample tensor
+ device (torch.device): Device to place tensors on
+
+ Returns:
+ Dict: Dictionary containing sparse attention parameters
+ """
+ assert self.transformer.config.patch_size[0] == 1
+ B, T, H, W, _ = sample.shape
+ T, H, W = (
+ T // self.transformer.config.patch_size[0],
+ H // self.transformer.config.patch_size[1],
+ W // self.transformer.config.patch_size[2],
+ )
+ if self.transformer.config.attention_type == "nabla":
+ sta_mask = self.fast_sta_nabla(
+ T,
+ H // 8,
+ W // 8,
+ self.transformer.config.attention_wT,
+ self.transformer.config.attention_wH,
+ self.transformer.config.attention_wW,
+ device=device,
+ )
+
+ sparse_params = {
+ "sta_mask": sta_mask.unsqueeze_(0).unsqueeze_(0),
+ "attention_type": self.transformer.config.attention_type,
+ "to_fractal": True,
+ "P": self.transformer.config.attention_P,
+ "wT": self.transformer.config.attention_wT,
+ "wW": self.transformer.config.attention_wW,
+ "wH": self.transformer.config.attention_wH,
+ "add_sta": self.transformer.config.attention_add_sta,
+ "visual_shape": (T, H, W),
+ "method": self.transformer.config.attention_method,
+ }
+ else:
+ sparse_params = None
+
+ return sparse_params
+
+ def _encode_prompt_qwen(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ max_sequence_length: int = 256,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ """
+ Encode prompt using Qwen2.5-VL text encoder.
+
+ This method processes the input prompt through the Qwen2.5-VL model to generate text embeddings suitable for
+ video generation.
+
+ Args:
+ prompt (Union[str, List[str]]): Input prompt or list of prompts
+ device (torch.device): Device to run encoding on
+ max_sequence_length (int): Maximum sequence length for tokenization
+ dtype (torch.dtype): Data type for embeddings
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Text embeddings and cumulative sequence lengths
+ """
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ full_texts = [self.prompt_template.format(p) for p in prompt]
+ max_allowed_len = self.prompt_template_encode_start_idx + max_sequence_length
+
+ untruncated_ids = self.tokenizer(
+ text=full_texts,
+ images=None,
+ videos=None,
+ return_tensors="pt",
+ padding="longest",
+ )["input_ids"]
+
+ if untruncated_ids.shape[-1] > max_allowed_len:
+ for i, text in enumerate(full_texts):
+ tokens = untruncated_ids[i][self.prompt_template_encode_start_idx : -2]
+ removed_text = self.tokenizer.decode(tokens[max_sequence_length - 2 :])
+ if len(removed_text) > 0:
+ full_texts[i] = text[: -len(removed_text)]
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+ inputs = self.tokenizer(
+ text=full_texts,
+ images=None,
+ videos=None,
+ max_length=max_allowed_len,
+ truncation=True,
+ return_tensors="pt",
+ padding=True,
+ ).to(device)
+
+ embeds = self.text_encoder(
+ input_ids=inputs["input_ids"],
+ return_dict=True,
+ output_hidden_states=True,
+ )["hidden_states"][-1][:, self.prompt_template_encode_start_idx :]
+
+ attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx :]
+ cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32)
+
+ return embeds.to(dtype), cu_seqlens
+
+ def _encode_prompt_clip(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ """
+ Encode prompt using CLIP text encoder.
+
+ This method processes the input prompt through the CLIP model to generate pooled embeddings that capture
+ semantic information.
+
+ Args:
+ prompt (Union[str, List[str]]): Input prompt or list of prompts
+ device (torch.device): Device to run encoding on
+ dtype (torch.dtype): Data type for embeddings
+
+ Returns:
+ torch.Tensor: Pooled text embeddings from CLIP
+ """
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder_2.dtype
+
+ inputs = self.tokenizer_2(
+ prompt,
+ max_length=77,
+ truncation=True,
+ add_special_tokens=True,
+ padding="max_length",
+ return_tensors="pt",
+ ).to(device)
+
+ pooled_embed = self.text_encoder_2(**inputs)["pooler_output"]
+
+ return pooled_embed.to(dtype)
+
+ @staticmethod
+ def adaptive_mean_std_normalization(source, reference):
+ source_mean = source.mean(dim=(1, 2, 3, 4), keepdim=True)
+ source_std = source.std(dim=(1, 2, 3, 4), keepdim=True)
+ # magic constants - limit changes in latents
+ clump_mean_low = 0.05
+ clump_mean_high = 0.1
+ clump_std_low = 0.1
+ clump_std_high = 0.25
+
+ reference_mean = torch.clamp(reference.mean(), source_mean - clump_mean_low, source_mean + clump_mean_high)
+ reference_std = torch.clamp(reference.std(), source_std - clump_std_low, source_std + clump_std_high)
+
+ # normalization
+ normalized = (source - source_mean) / source_std
+ normalized = normalized * reference_std + reference_mean
+
+ return normalized
+
+ def normalize_first_frame(self, latents, reference_frames=5, clump_values=False):
+ latents_copy = latents.clone()
+ samples = latents_copy
+
+ if samples.shape[1] <= 1:
+ return (latents, "Only one frame, no normalization needed")
+
+ nFr = 4
+ first_frames = samples.clone()[:, :nFr]
+ reference_frames_data = samples[:, nFr : nFr + min(reference_frames, samples.shape[1] - 1)]
+
+ normalized_first = self.adaptive_mean_std_normalization(first_frames, reference_frames_data)
+ if clump_values:
+ min_val = reference_frames_data.min()
+ max_val = reference_frames_data.max()
+ normalized_first = torch.clamp(normalized_first, min_val, max_val)
+
+ samples[:, :nFr] = normalized_first
+
+ return samples
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes a single prompt (positive or negative) into text encoder hidden states.
+
+ This method combines embeddings from both Qwen2.5-VL and CLIP text encoders to create comprehensive text
+ representations for video generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ Prompt to be encoded.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos to generate per prompt.
+ max_sequence_length (`int`, *optional*, defaults to 512):
+ Maximum sequence length for text encoding.
+ device (`torch.device`, *optional*):
+ Torch device.
+ dtype (`torch.dtype`, *optional*):
+ Torch dtype.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ - Qwen text embeddings of shape (batch_size * num_videos_per_prompt, sequence_length, embedding_dim)
+ - CLIP pooled embeddings of shape (batch_size * num_videos_per_prompt, clip_embedding_dim)
+ - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings of shape (batch_size *
+ num_videos_per_prompt + 1,)
+ """
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ if not isinstance(prompt, list):
+ prompt = [prompt]
+
+ batch_size = len(prompt)
+
+ prompt = [prompt_clean(p) for p in prompt]
+
+ # Encode with Qwen2.5-VL
+ prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen(
+ prompt=prompt,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ dtype=dtype,
+ )
+ # prompt_embeds_qwen shape: [batch_size, seq_len, embed_dim]
+
+ # Encode with CLIP
+ prompt_embeds_clip = self._encode_prompt_clip(
+ prompt=prompt,
+ device=device,
+ dtype=dtype,
+ )
+ # prompt_embeds_clip shape: [batch_size, clip_embed_dim]
+
+ # Repeat embeddings for num_videos_per_prompt
+ # Qwen embeddings: repeat sequence for each video, then reshape
+ prompt_embeds_qwen = prompt_embeds_qwen.repeat(
+ 1, num_videos_per_prompt, 1
+ ) # [batch_size, seq_len * num_videos_per_prompt, embed_dim]
+ # Reshape to [batch_size * num_videos_per_prompt, seq_len, embed_dim]
+ prompt_embeds_qwen = prompt_embeds_qwen.view(
+ batch_size * num_videos_per_prompt, -1, prompt_embeds_qwen.shape[-1]
+ )
+
+ # CLIP embeddings: repeat for each video
+ prompt_embeds_clip = prompt_embeds_clip.repeat(
+ 1, num_videos_per_prompt, 1
+ ) # [batch_size, num_videos_per_prompt, clip_embed_dim]
+ # Reshape to [batch_size * num_videos_per_prompt, clip_embed_dim]
+ prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_videos_per_prompt, -1)
+
+ # Repeat cumulative sequence lengths for num_videos_per_prompt
+ # Original differences (lengths) for each prompt in the batch
+ original_lengths = prompt_cu_seqlens.diff() # [len1, len2, ...]
+ # Repeat the lengths for num_videos_per_prompt
+ repeated_lengths = original_lengths.repeat_interleave(
+ num_videos_per_prompt
+ ) # [len1, len1, ..., len2, len2, ...]
+ # Reconstruct the cumulative lengths
+ repeated_cu_seqlens = torch.cat(
+ [torch.tensor([0], device=device, dtype=torch.int32), repeated_lengths.cumsum(0)]
+ )
+
+ return prompt_embeds_qwen, prompt_embeds_clip, repeated_cu_seqlens
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ image,
+ height,
+ width,
+ prompt_embeds_qwen=None,
+ prompt_embeds_clip=None,
+ negative_prompt_embeds_qwen=None,
+ negative_prompt_embeds_clip=None,
+ prompt_cu_seqlens=None,
+ negative_prompt_cu_seqlens=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ """
+ Validate input parameters for the pipeline.
+
+ Args:
+ prompt: Input prompt
+ negative_prompt: Negative prompt for guidance
+ image: Input image for conditioning
+ height: Video height
+ width: Video width
+ prompt_embeds_qwen: Pre-computed Qwen prompt embeddings
+ prompt_embeds_clip: Pre-computed CLIP prompt embeddings
+ negative_prompt_embeds_qwen: Pre-computed Qwen negative prompt embeddings
+ negative_prompt_embeds_clip: Pre-computed CLIP negative prompt embeddings
+ prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen positive prompt
+ negative_prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen negative prompt
+ callback_on_step_end_tensor_inputs: Callback tensor inputs
+
+ Raises:
+ ValueError: If inputs are invalid
+ """
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError("max_sequence_length must be less than 1024")
+
+ if image is None:
+ raise ValueError("`image` must be provided for image-to-video generation")
+
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ # Check for consistency within positive prompt embeddings and sequence lengths
+ if prompt_embeds_qwen is not None or prompt_embeds_clip is not None or prompt_cu_seqlens is not None:
+ if prompt_embeds_qwen is None or prompt_embeds_clip is None or prompt_cu_seqlens is None:
+ raise ValueError(
+ "If any of `prompt_embeds_qwen`, `prompt_embeds_clip`, or `prompt_cu_seqlens` is provided, "
+ "all three must be provided."
+ )
+
+ # Check for consistency within negative prompt embeddings and sequence lengths
+ if (
+ negative_prompt_embeds_qwen is not None
+ or negative_prompt_embeds_clip is not None
+ or negative_prompt_cu_seqlens is not None
+ ):
+ if (
+ negative_prompt_embeds_qwen is None
+ or negative_prompt_embeds_clip is None
+ or negative_prompt_cu_seqlens is None
+ ):
+ raise ValueError(
+ "If any of `negative_prompt_embeds_qwen`, `negative_prompt_embeds_clip`, or `negative_prompt_cu_seqlens` is provided, "
+ "all three must be provided."
+ )
+
+ # Check if prompt or embeddings are provided (either prompt or all required embedding components for positive)
+ if prompt is None and prompt_embeds_qwen is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds_qwen` (and corresponding `prompt_embeds_clip` and `prompt_cu_seqlens`). Cannot leave all undefined."
+ )
+
+ # Validate types for prompt and negative_prompt if provided
+ if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ if negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ def prepare_latents(
+ self,
+ image: PipelineImageInput,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ Prepare initial latent variables for image-to-video generation.
+
+ This method creates random noise latents for all frames except the first frame, which is replaced with the
+ encoded input image.
+
+ Args:
+ image (PipelineImageInput): Input image to condition the generation on
+ batch_size (int): Number of videos to generate
+ num_channels_latents (int): Number of channels in latent space
+ height (int): Height of generated video
+ width (int): Width of generated video
+ num_frames (int): Number of frames in video
+ dtype (torch.dtype): Data type for latents
+ device (torch.device): Device to create latents on
+ generator (torch.Generator): Random number generator
+ latents (torch.Tensor): Pre-existing latents to use
+
+ Returns:
+ torch.Tensor: Prepared latent tensor with first frame as encoded image
+ """
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ shape = (
+ batch_size,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ num_channels_latents,
+ )
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ # Generate random noise for all frames
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ # Encode the input image to use as first frame
+ # Preprocess image
+ image_tensor = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=dtype)
+
+ # Encode image to latents using VAE
+ with torch.no_grad():
+ # Convert image to video format [batch, channels, 1, height, width]
+ image_video = image_tensor.unsqueeze(2) # Add temporal dimension
+ image_latents = self.vae.encode(image_video).latent_dist.sample(generator=generator)
+
+ # Normalize latents if needed
+ if hasattr(self.vae.config, "scaling_factor"):
+ image_latents = image_latents * self.vae.config.scaling_factor
+
+ # Reshape to match latent dimensions [batch, frames, height, width, channels]
+ image_latents = image_latents.permute(0, 2, 3, 4, 1) # [batch, 1, H, W, C]
+
+ # Replace first frame with encoded image
+ latents[:, 0:1] = image_latents
+
+ if self.transformer.visual_cond:
+ # For visual conditioning, concatenate with zeros and mask
+ visual_cond = torch.zeros_like(latents)
+ visual_cond_mask = torch.zeros(
+ [
+ batch_size,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ 1,
+ ],
+ dtype=latents.dtype,
+ device=latents.device,
+ )
+
+ visual_cond_mask[:, 0:1] = 1
+ visual_cond[:, 0:1] = image_latents
+
+ latents = torch.cat([latents, visual_cond, visual_cond_mask], dim=-1)
+
+ return latents
+
+ @property
+ def guidance_scale(self):
+ """Get the current guidance scale value."""
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self):
+ """Get the number of denoising timesteps."""
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ """Check if generation has been interrupted."""
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 512,
+ width: int = 768,
+ num_frames: int = 121,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds_qwen: Optional[torch.Tensor] = None,
+ prompt_embeds_clip: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_qwen: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_clip: Optional[torch.Tensor] = None,
+ prompt_cu_seqlens: Optional[torch.Tensor] = None,
+ negative_prompt_cu_seqlens: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for image-to-video generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (`guidance_scale` < `1`).
+ height (`int`, defaults to `512`):
+ The height in pixels of the generated video.
+ width (`int`, defaults to `768`):
+ The width in pixels of the generated video.
+ num_frames (`int`, defaults to `121`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in classifier-free guidance.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A torch generator to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents.
+ prompt_embeds_qwen (`torch.Tensor`, *optional*):
+ Pre-generated Qwen text embeddings.
+ prompt_embeds_clip (`torch.Tensor`, *optional*):
+ Pre-generated CLIP text embeddings.
+ negative_prompt_embeds_qwen (`torch.Tensor`, *optional*):
+ Pre-generated Qwen negative text embeddings.
+ negative_prompt_embeds_clip (`torch.Tensor`, *optional*):
+ Pre-generated CLIP negative text embeddings.
+ prompt_cu_seqlens (`torch.Tensor`, *optional*):
+ Pre-generated cumulative sequence lengths for Qwen positive prompt.
+ negative_prompt_cu_seqlens (`torch.Tensor`, *optional*):
+ Pre-generated cumulative sequence lengths for Qwen negative prompt.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated video.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`KandinskyPipelineOutput`].
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function that is called at the end of each denoising step.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum sequence length for text encoding.
+
+ Examples:
+
+ Returns:
+ [`~KandinskyPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`KandinskyPipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated videos.
+ """
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image=image,
+ height=height,
+ width=width,
+ prompt_embeds_qwen=prompt_embeds_qwen,
+ prompt_embeds_clip=prompt_embeds_clip,
+ negative_prompt_embeds_qwen=negative_prompt_embeds_qwen,
+ negative_prompt_embeds_clip=negative_prompt_embeds_clip,
+ prompt_cu_seqlens=prompt_cu_seqlens,
+ negative_prompt_cu_seqlens=negative_prompt_cu_seqlens,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._interrupt = False
+
+ device = self._execution_device
+ dtype = self.transformer.dtype
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ prompt = [prompt]
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds_qwen.shape[0]
+
+ # 3. Encode input prompt
+ if prompt_embeds_qwen is None:
+ prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self.encode_prompt(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if self.guidance_scale > 1.0:
+ if negative_prompt is None:
+ negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
+
+ if isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt] * len(prompt) if prompt is not None else [negative_prompt]
+ elif len(negative_prompt) != len(prompt):
+ raise ValueError(
+ f"`negative_prompt` must have same length as `prompt`. Got {len(negative_prompt)} vs {len(prompt)}."
+ )
+
+ if negative_prompt_embeds_qwen is None:
+ negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_prompt_cu_seqlens = (
+ self.encode_prompt(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables with image conditioning
+ num_channels_latents = self.transformer.config.in_visual_dim
+ latents = self.prepare_latents(
+ image=image,
+ batch_size=batch_size * num_videos_per_prompt,
+ num_channels_latents=num_channels_latents,
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ dtype=dtype,
+ device=device,
+ generator=generator,
+ latents=latents,
+ )
+
+ # 6. Prepare rope positions for positional encoding
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ visual_rope_pos = [
+ torch.arange(num_latent_frames, device=device),
+ torch.arange(height // self.vae_scale_factor_spatial // 2, device=device),
+ torch.arange(width // self.vae_scale_factor_spatial // 2, device=device),
+ ]
+
+ text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device)
+
+ negative_text_rope_pos = (
+ torch.arange(negative_prompt_cu_seqlens.diff().max().item(), device=device)
+ if negative_prompt_cu_seqlens is not None
+ else None
+ )
+
+ # 7. Calculate dynamic scale factor based on resolution
+ scale_factor = self._get_scale_factor(height, width)
+
+ # 8. Sparse Params for efficient attention
+ sparse_params = self.get_sparse_params(latents, device)
+
+ # 9. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ timestep = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt)
+
+ # Predict noise residual
+ pred_velocity = self.transformer(
+ hidden_states=latents.to(dtype),
+ encoder_hidden_states=prompt_embeds_qwen.to(dtype),
+ pooled_projections=prompt_embeds_clip.to(dtype),
+ timestep=timestep.to(dtype),
+ visual_rope_pos=visual_rope_pos,
+ text_rope_pos=text_rope_pos,
+ scale_factor=scale_factor,
+ sparse_params=sparse_params,
+ return_dict=True,
+ ).sample
+
+ if self.guidance_scale > 1.0 and negative_prompt_embeds_qwen is not None:
+ uncond_pred_velocity = self.transformer(
+ hidden_states=latents.to(dtype),
+ encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype),
+ pooled_projections=negative_prompt_embeds_clip.to(dtype),
+ timestep=timestep.to(dtype),
+ visual_rope_pos=visual_rope_pos,
+ text_rope_pos=negative_text_rope_pos,
+ scale_factor=scale_factor,
+ sparse_params=sparse_params,
+ return_dict=True,
+ ).sample
+
+ pred_velocity = uncond_pred_velocity + guidance_scale * (pred_velocity - uncond_pred_velocity)
+
+ latents[:, 1:, :, :, :num_channels_latents] = self.scheduler.step(
+ pred_velocity[:, 1:], t, latents[:, 1:, :, :, :num_channels_latents], return_dict=False
+ )[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds_qwen = callback_outputs.pop("prompt_embeds_qwen", prompt_embeds_qwen)
+ prompt_embeds_clip = callback_outputs.pop("prompt_embeds_clip", prompt_embeds_clip)
+ negative_prompt_embeds_qwen = callback_outputs.pop(
+ "negative_prompt_embeds_qwen", negative_prompt_embeds_qwen
+ )
+ negative_prompt_embeds_clip = callback_outputs.pop(
+ "negative_prompt_embeds_clip", negative_prompt_embeds_clip
+ )
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ # 9. Post-processing - extract main latents
+ latents = latents[:, :, :, :, :num_channels_latents]
+
+ # 10. fix mesh artifacts
+ latents = self.normalize_first_frame(latents)
+
+ # 11. Decode latents to video
+ if output_type != "latent":
+ latents = latents.to(self.vae.dtype)
+ # Reshape and normalize latents
+ video = latents.reshape(
+ batch_size,
+ num_videos_per_prompt,
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ num_channels_latents,
+ )
+ video = video.permute(0, 1, 5, 2, 3, 4) # [batch, num_videos, channels, frames, height, width]
+ video = video.reshape(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+
+ # Normalize and decode through VAE
+ video = video / self.vae.config.scaling_factor
+ video = self.vae.decode(video).sample
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return KandinskyPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_t2i.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_t2i.py
new file mode 100644
index 000000000000..bb5c40327b4e
--- /dev/null
+++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_t2i.py
@@ -0,0 +1,818 @@
+# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from typing import Callable, Dict, List, Optional, Union
+
+import numpy as np
+import regex as re
+import torch
+from torch.nn import functional as F
+from transformers import CLIPTextModel, CLIPTokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import VaeImageProcessor
+from ...loaders import KandinskyLoraLoaderMixin
+from ...models import AutoencoderKL
+from ...models.transformers import Kandinsky5Transformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+
+# Add imports for offloading and tiling
+from ...utils import (
+ is_ftfy_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import KandinskyImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+logger = logging.get_logger(__name__)
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+
+ ```python
+ >>> import torch
+ >>> from diffusers import Kandinsky5T2IPipeline
+
+ >>> # Available models:
+ >>> # kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers
+ >>> # kandinskylab/Kandinsky-5.0-T2I-Lite-pretrain-Diffusers
+
+ >>> model_id = "kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers"
+ >>> pipe = Kandinsky5T2IPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "A cat and a dog baking a cake together in a kitchen."
+
+ >>> output = pipe(
+ ... prompt=prompt,
+ ... negative_prompt="",
+ ... height=1024,
+ ... width=1024,
+ ... num_inference_steps=50,
+ ... guidance_scale=3.5,
+ ... ).frames[0]
+ ```
+"""
+
+
+def basic_clean(text):
+ """
+ Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py
+
+ Clean text using ftfy if available and unescape HTML entities.
+ """
+ if is_ftfy_available():
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ """
+ Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py
+
+ Normalize whitespace in text by replacing multiple spaces with single space.
+ """
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ """
+ Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py
+
+ Apply both basic cleaning and whitespace normalization to prompts.
+ """
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+class Kandinsky5T2IPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
+ r"""
+ Pipeline for text-to-image generation using Kandinsky 5.0.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ transformer ([`Kandinsky5Transformer3DModel`]):
+ Conditional Transformer to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder Model [black-forest-labs/FLUX.1-dev
+ (vae)](https://huggingface.co/black-forest-labs/FLUX.1-dev) to encode and decode videos to and from latent
+ representations.
+ text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
+ Frozen text-encoder [Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct).
+ tokenizer ([`AutoProcessor`]):
+ Tokenizer for Qwen2.5-VL.
+ text_encoder_2 ([`CLIPTextModel`]):
+ Frozen [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel),
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer_2 ([`CLIPTokenizer`]):
+ Tokenizer for CLIP.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds_qwen",
+ "prompt_embeds_clip",
+ "negative_prompt_embeds_qwen",
+ "negative_prompt_embeds_clip",
+ ]
+
+ def __init__(
+ self,
+ transformer: Kandinsky5Transformer3DModel,
+ vae: AutoencoderKL,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2VLProcessor,
+ text_encoder_2: CLIPTextModel,
+ tokenizer_2: CLIPTokenizer,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ transformer=transformer,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ scheduler=scheduler,
+ )
+
+ self.prompt_template = "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
+ self.prompt_template_encode_start_idx = 41
+
+ self.vae_scale_factor_spatial = 8
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+ self.resolutions = [(1024, 1024), (640, 1408), (1408, 640), (768, 1280), (1280, 768), (896, 1152), (1152, 896)]
+
+ def _encode_prompt_qwen(
+ self,
+ prompt: List[str],
+ device: Optional[torch.device] = None,
+ max_sequence_length: int = 512,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ """
+ Encode prompt using Qwen2.5-VL text encoder.
+
+ This method processes the input prompt through the Qwen2.5-VL model to generate text embeddings suitable for
+ image generation.
+
+ Args:
+ prompt List[str]: Input list of prompts
+ device (torch.device): Device to run encoding on
+ max_sequence_length (int): Maximum sequence length for tokenization
+ dtype (torch.dtype): Data type for embeddings
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Text embeddings and cumulative sequence lengths
+ """
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ full_texts = [self.prompt_template.format(p) for p in prompt]
+ max_allowed_len = self.prompt_template_encode_start_idx + max_sequence_length
+
+ untruncated_ids = self.tokenizer(
+ text=full_texts,
+ images=None,
+ videos=None,
+ return_tensors="pt",
+ padding="longest",
+ )["input_ids"]
+
+ if untruncated_ids.shape[-1] > max_allowed_len:
+ for i, text in enumerate(full_texts):
+ tokens = untruncated_ids[i][self.prompt_template_encode_start_idx : -2]
+ removed_text = self.tokenizer.decode(tokens[max_sequence_length - 2 :])
+ if len(removed_text) > 0:
+ full_texts[i] = text[: -len(removed_text)]
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ inputs = self.tokenizer(
+ text=full_texts,
+ images=None,
+ videos=None,
+ max_length=max_allowed_len,
+ truncation=True,
+ return_tensors="pt",
+ padding=True,
+ ).to(device)
+
+ embeds = self.text_encoder(
+ input_ids=inputs["input_ids"],
+ return_dict=True,
+ output_hidden_states=True,
+ )["hidden_states"][-1][:, self.prompt_template_encode_start_idx :]
+ attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx :]
+ cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32)
+
+ return embeds.to(dtype), cu_seqlens
+
+ def _encode_prompt_clip(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ """
+ Encode prompt using CLIP text encoder.
+
+ This method processes the input prompt through the CLIP model to generate pooled embeddings that capture
+ semantic information.
+
+ Args:
+ prompt (Union[str, List[str]]): Input prompt or list of prompts
+ device (torch.device): Device to run encoding on
+ dtype (torch.dtype): Data type for embeddings
+
+ Returns:
+ torch.Tensor: Pooled text embeddings from CLIP
+ """
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder_2.dtype
+
+ inputs = self.tokenizer_2(
+ prompt,
+ max_length=77,
+ truncation=True,
+ add_special_tokens=True,
+ padding="max_length",
+ return_tensors="pt",
+ ).to(device)
+
+ pooled_embed = self.text_encoder_2(**inputs)["pooler_output"]
+
+ return pooled_embed.to(dtype)
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes a single prompt (positive or negative) into text encoder hidden states.
+
+ This method combines embeddings from both Qwen2.5-VL and CLIP text encoders to create comprehensive text
+ representations for image generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ Prompt to be encoded.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ Number of images to generate per prompt.
+ max_sequence_length (`int`, *optional*, defaults to 512):
+ Maximum sequence length for text encoding. Must be less than 1024
+ device (`torch.device`, *optional*):
+ Torch device.
+ dtype (`torch.dtype`, *optional*):
+ Torch dtype.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ - Qwen text embeddings of shape (batch_size * num_images_per_prompt, sequence_length, embedding_dim)
+ - CLIP pooled embeddings of shape (batch_size * num_images_per_prompt, clip_embedding_dim)
+ - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings of shape (batch_size *
+ num_images_per_prompt + 1,)
+ """
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ if not isinstance(prompt, list):
+ prompt = [prompt]
+
+ batch_size = len(prompt)
+
+ prompt = [prompt_clean(p) for p in prompt]
+
+ # Encode with Qwen2.5-VL
+ prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen(
+ prompt=prompt,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ dtype=dtype,
+ )
+ # prompt_embeds_qwen shape: [batch_size, seq_len, embed_dim]
+
+ # Encode with CLIP
+ prompt_embeds_clip = self._encode_prompt_clip(
+ prompt=prompt,
+ device=device,
+ dtype=dtype,
+ )
+ # prompt_embeds_clip shape: [batch_size, clip_embed_dim]
+
+ # Repeat embeddings for num_images_per_prompt
+ # Qwen embeddings: repeat sequence for each image, then reshape
+ prompt_embeds_qwen = prompt_embeds_qwen.repeat(
+ 1, num_images_per_prompt, 1
+ ) # [batch_size, seq_len * num_images_per_prompt, embed_dim]
+ # Reshape to [batch_size * num_images_per_prompt, seq_len, embed_dim]
+ prompt_embeds_qwen = prompt_embeds_qwen.view(
+ batch_size * num_images_per_prompt, -1, prompt_embeds_qwen.shape[-1]
+ )
+
+ # CLIP embeddings: repeat for each image
+ prompt_embeds_clip = prompt_embeds_clip.repeat(
+ 1, num_images_per_prompt, 1
+ ) # [batch_size, num_images_per_prompt, clip_embed_dim]
+ # Reshape to [batch_size * num_images_per_prompt, clip_embed_dim]
+ prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_images_per_prompt, -1)
+
+ # Repeat cumulative sequence lengths for num_images_per_prompt
+ # Original differences (lengths) for each prompt in the batch
+ original_lengths = prompt_cu_seqlens.diff() # [len1, len2, ...]
+ # Repeat the lengths for num_images_per_prompt
+ repeated_lengths = original_lengths.repeat_interleave(
+ num_images_per_prompt
+ ) # [len1, len1, ..., len2, len2, ...]
+ # Reconstruct the cumulative lengths
+ repeated_cu_seqlens = torch.cat(
+ [torch.tensor([0], device=device, dtype=torch.int32), repeated_lengths.cumsum(0)]
+ )
+
+ return prompt_embeds_qwen, prompt_embeds_clip, repeated_cu_seqlens
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds_qwen=None,
+ prompt_embeds_clip=None,
+ negative_prompt_embeds_qwen=None,
+ negative_prompt_embeds_clip=None,
+ prompt_cu_seqlens=None,
+ negative_prompt_cu_seqlens=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ """
+ Validate input parameters for the pipeline.
+
+ Args:
+ prompt: Input prompt
+ negative_prompt: Negative prompt for guidance
+ height: Image height
+ width: Image width
+ prompt_embeds_qwen: Pre-computed Qwen prompt embeddings
+ prompt_embeds_clip: Pre-computed CLIP prompt embeddings
+ negative_prompt_embeds_qwen: Pre-computed Qwen negative prompt embeddings
+ negative_prompt_embeds_clip: Pre-computed CLIP negative prompt embeddings
+ prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen positive prompt
+ negative_prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen negative prompt
+ callback_on_step_end_tensor_inputs: Callback tensor inputs
+
+ Raises:
+ ValueError: If inputs are invalid
+ """
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError("max_sequence_length must be less than 1024")
+
+ if (width, height) not in self.resolutions:
+ resolutions_str = ",".join([f"({w},{h})" for w, h in self.resolutions])
+ logger.warning(
+ f"`height` and `width` have to be one of {resolutions_str}, but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ # Check for consistency within positive prompt embeddings and sequence lengths
+ if prompt_embeds_qwen is not None or prompt_embeds_clip is not None or prompt_cu_seqlens is not None:
+ if prompt_embeds_qwen is None or prompt_embeds_clip is None or prompt_cu_seqlens is None:
+ raise ValueError(
+ "If any of `prompt_embeds_qwen`, `prompt_embeds_clip`, or `prompt_cu_seqlens` is provided, "
+ "all three must be provided."
+ )
+
+ # Check for consistency within negative prompt embeddings and sequence lengths
+ if (
+ negative_prompt_embeds_qwen is not None
+ or negative_prompt_embeds_clip is not None
+ or negative_prompt_cu_seqlens is not None
+ ):
+ if (
+ negative_prompt_embeds_qwen is None
+ or negative_prompt_embeds_clip is None
+ or negative_prompt_cu_seqlens is None
+ ):
+ raise ValueError(
+ "If any of `negative_prompt_embeds_qwen`, `negative_prompt_embeds_clip`, or `negative_prompt_cu_seqlens` is provided, "
+ "all three must be provided."
+ )
+
+ # Check if prompt or embeddings are provided (either prompt or all required embedding components for positive)
+ if prompt is None and prompt_embeds_qwen is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds_qwen` (and corresponding `prompt_embeds_clip` and `prompt_cu_seqlens`). Cannot leave all undefined."
+ )
+
+ # Validate types for prompt and negative_prompt if provided
+ if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ if negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 1024,
+ width: int = 1024,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ Prepare initial latent variables for text-to-image generation.
+
+ This method creates random noise latents
+
+ Args:
+ batch_size (int): Number of images to generate
+ num_channels_latents (int): Number of channels in latent space
+ height (int): Height of generated image
+ width (int): Width of generated image
+ dtype (torch.dtype): Data type for latents
+ device (torch.device): Device to create latents on
+ generator (torch.Generator): Random number generator
+ latents (torch.Tensor): Pre-existing latents to use
+
+ Returns:
+ torch.Tensor: Prepared latent tensor
+ """
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ shape = (
+ batch_size,
+ 1,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ num_channels_latents,
+ )
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ # Generate random noise
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ @property
+ def guidance_scale(self):
+ """Get the current guidance scale value."""
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self):
+ """Get the number of denoising timesteps."""
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ """Check if generation has been interrupted."""
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 1024,
+ width: int = 1024,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds_qwen: Optional[torch.Tensor] = None,
+ prompt_embeds_clip: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_qwen: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_clip: Optional[torch.Tensor] = None,
+ prompt_cu_seqlens: Optional[torch.Tensor] = None,
+ negative_prompt_cu_seqlens: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for text-to-image generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (`guidance_scale` < `1`).
+ height (`int`, defaults to `1024`):
+ The height in pixels of the generated image.
+ width (`int`, defaults to `1024`):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in classifier-free guidance.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A torch generator to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents.
+ prompt_embeds_qwen (`torch.Tensor`, *optional*):
+ Pre-generated Qwen text embeddings.
+ prompt_embeds_clip (`torch.Tensor`, *optional*):
+ Pre-generated CLIP text embeddings.
+ negative_prompt_embeds_qwen (`torch.Tensor`, *optional*):
+ Pre-generated Qwen negative text embeddings.
+ negative_prompt_embeds_clip (`torch.Tensor`, *optional*):
+ Pre-generated CLIP negative text embeddings.
+ prompt_cu_seqlens (`torch.Tensor`, *optional*):
+ Pre-generated cumulative sequence lengths for Qwen positive prompt.
+ negative_prompt_cu_seqlens (`torch.Tensor`, *optional*):
+ Pre-generated cumulative sequence lengths for Qwen negative prompt.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`KandinskyImagePipelineOutput`].
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function that is called at the end of each denoising step.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum sequence length for text encoding.
+
+ Examples:
+
+ Returns:
+ [`~KandinskyImagePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`KandinskyImagePipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images.
+ """
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+ self.check_inputs(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=height,
+ width=width,
+ prompt_embeds_qwen=prompt_embeds_qwen,
+ prompt_embeds_clip=prompt_embeds_clip,
+ negative_prompt_embeds_qwen=negative_prompt_embeds_qwen,
+ negative_prompt_embeds_clip=negative_prompt_embeds_clip,
+ prompt_cu_seqlens=prompt_cu_seqlens,
+ negative_prompt_cu_seqlens=negative_prompt_cu_seqlens,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+ if (width, height) not in self.resolutions:
+ width, height = self.resolutions[
+ np.argmin([abs((i[0] / i[1]) - (width / height)) for i in self.resolutions])
+ ]
+
+ self._guidance_scale = guidance_scale
+ self._interrupt = False
+
+ device = self._execution_device
+ dtype = self.transformer.dtype
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ prompt = [prompt]
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds_qwen.shape[0]
+
+ # 3. Encode input prompt
+ if prompt_embeds_qwen is None:
+ prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self.encode_prompt(
+ prompt=prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if self.guidance_scale > 1.0:
+ if negative_prompt is None:
+ negative_prompt = ""
+
+ if isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt] * len(prompt) if prompt is not None else [negative_prompt]
+ elif len(negative_prompt) != len(prompt):
+ raise ValueError(
+ f"`negative_prompt` must have same length as `prompt`. Got {len(negative_prompt)} vs {len(prompt)}."
+ )
+
+ if negative_prompt_embeds_qwen is None:
+ negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_prompt_cu_seqlens = (
+ self.encode_prompt(
+ prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_visual_dim
+ latents = self.prepare_latents(
+ batch_size=batch_size * num_images_per_prompt,
+ num_channels_latents=num_channels_latents,
+ height=height,
+ width=width,
+ dtype=dtype,
+ device=device,
+ generator=generator,
+ latents=latents,
+ )
+
+ # 6. Prepare rope positions for positional encoding
+ visual_rope_pos = [
+ torch.arange(1, device=device),
+ torch.arange(height // self.vae_scale_factor_spatial // 2, device=device),
+ torch.arange(width // self.vae_scale_factor_spatial // 2, device=device),
+ ]
+
+ text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device)
+
+ negative_text_rope_pos = (
+ torch.arange(negative_prompt_cu_seqlens.diff().max().item(), device=device)
+ if negative_prompt_cu_seqlens is not None
+ else None
+ )
+
+ # 7. Calculate dynamic scale factor based on resolution
+ scale_factor = [1.0, 1.0, 1.0]
+
+ # 8. Sparse Params for efficient attention
+ sparse_params = None
+
+ # 9. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ timestep = t.unsqueeze(0).repeat(batch_size * num_images_per_prompt)
+
+ # Predict noise residual
+ pred_velocity = self.transformer(
+ hidden_states=latents.to(dtype),
+ encoder_hidden_states=prompt_embeds_qwen.to(dtype),
+ pooled_projections=prompt_embeds_clip.to(dtype),
+ timestep=timestep.to(dtype),
+ visual_rope_pos=visual_rope_pos,
+ text_rope_pos=text_rope_pos,
+ scale_factor=scale_factor,
+ sparse_params=sparse_params,
+ return_dict=True,
+ ).sample
+
+ if self.guidance_scale > 1.0 and negative_prompt_embeds_qwen is not None:
+ uncond_pred_velocity = self.transformer(
+ hidden_states=latents.to(dtype),
+ encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype),
+ pooled_projections=negative_prompt_embeds_clip.to(dtype),
+ timestep=timestep.to(dtype),
+ visual_rope_pos=visual_rope_pos,
+ text_rope_pos=negative_text_rope_pos,
+ scale_factor=scale_factor,
+ sparse_params=sparse_params,
+ return_dict=True,
+ ).sample
+
+ pred_velocity = uncond_pred_velocity + guidance_scale * (pred_velocity - uncond_pred_velocity)
+
+ latents = self.scheduler.step(pred_velocity[:, :], t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds_qwen = callback_outputs.pop("prompt_embeds_qwen", prompt_embeds_qwen)
+ prompt_embeds_clip = callback_outputs.pop("prompt_embeds_clip", prompt_embeds_clip)
+ negative_prompt_embeds_qwen = callback_outputs.pop(
+ "negative_prompt_embeds_qwen", negative_prompt_embeds_qwen
+ )
+ negative_prompt_embeds_clip = callback_outputs.pop(
+ "negative_prompt_embeds_clip", negative_prompt_embeds_clip
+ )
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ # 9. Post-processing - extract main latents
+ latents = latents[:, :, :, :, :num_channels_latents]
+
+ # 10. Decode latents to image
+ if output_type != "latent":
+ latents = latents.to(self.vae.dtype)
+ # Reshape and normalize latents
+ latents = latents.reshape(
+ batch_size,
+ num_images_per_prompt,
+ 1,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ num_channels_latents,
+ )
+ latents = latents.permute(0, 1, 5, 2, 3, 4) # [batch, num_images, channels, 1, height, width]
+ latents = latents.reshape(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+
+ # Normalize and decode through VAE
+ latents = latents / self.vae.config.scaling_factor
+ image = self.vae.decode(latents).sample
+ image = self.image_processor.postprocess(image, output_type=output_type)
+ else:
+ image = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return KandinskyImagePipelineOutput(image=image)
diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_output.py b/src/diffusers/pipelines/kandinsky5/pipeline_output.py
new file mode 100644
index 000000000000..2172ddff7e22
--- /dev/null
+++ b/src/diffusers/pipelines/kandinsky5/pipeline_output.py
@@ -0,0 +1,35 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class KandinskyPipelineOutput(BaseOutput):
+ r"""
+ Output class for kandinsky video pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
+
+
+@dataclass
+class KandinskyImagePipelineOutput(BaseOutput):
+ r"""
+ Output class for kandinsky image pipelines.
+
+ Args:
+ image (`torch.Tensor`, `np.ndarray`, or List[PIL.Image.Image]):
+ List of image outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image. It can also be a NumPy array or Torch tensor of shape `(batch_size, channels, height,
+ width)`.
+ """
+
+ image: torch.Tensor
diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors.py b/src/diffusers/pipelines/kolors/pipeline_kolors.py
index 1fc4c02cc43f..7c8468bcb109 100644
--- a/src/diffusers/pipelines/kolors/pipeline_kolors.py
+++ b/src/diffusers/pipelines/kolors/pipeline_kolors.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI, Kwai-Kolors Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Stability AI, Kwai-Kolors Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,9 +21,8 @@
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
-from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import KolorsPipelineOutput
@@ -436,7 +435,7 @@ def prepare_ip_adapter_image_embeds(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -580,22 +579,12 @@ def _get_add_time_ids(
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- FusedAttnProcessor2_0,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
@@ -633,7 +622,7 @@ def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -729,11 +718,11 @@ def __call__(
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
guidance_scale (`float`, *optional*, defaults to 5.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -741,15 +730,15 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
index df94ec3f0f24..10a7962c258c 100644
--- a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
+++ b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI, Kwai-Kolors Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Stability AI, Kwai-Kolors Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,10 +22,9 @@
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
-from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
-from ...utils.torch_utils import randn_tensor
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import empty_device_cache, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import KolorsPipelineOutput
from .text_encoder import ChatGLMModel
@@ -456,7 +455,7 @@ def prepare_ip_adapter_image_embeds(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -618,7 +617,7 @@ def prepare_latents(
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
- torch.cuda.empty_cache()
+ empty_device_cache()
image = image.to(device=device, dtype=dtype)
@@ -708,22 +707,12 @@ def _get_add_time_ids(
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- FusedAttnProcessor2_0,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
@@ -761,7 +750,7 @@ def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -880,11 +869,11 @@ def __call__(
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
guidance_scale (`float`, *optional*, defaults to 5.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -892,15 +881,15 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/kolors/text_encoder.py b/src/diffusers/pipelines/kolors/text_encoder.py
index 757569c880c0..6fd17156a116 100644
--- a/src/diffusers/pipelines/kolors/text_encoder.py
+++ b/src/diffusers/pipelines/kolors/text_encoder.py
@@ -1,4 +1,4 @@
-# Copyright 2024 ChatGLM3-6B Model Team, Kwai-Kolors Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 ChatGLM3-6B Model Team, Kwai-Kolors Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -434,7 +434,7 @@ def __init__(self, config: ChatGLMConfig, device=None):
self.add_bias = config.add_bias_linear
- # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
+ # Project to 4h. If using swiglu double the output width, see https://huggingface.co/papers/2002.05202
self.dense_h_to_4h = nn.Linear(
config.hidden_size,
config.ffn_hidden_size * 2,
@@ -668,7 +668,7 @@ def forward(self, input_ids):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
embeddings = words_embeddings
- # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
+ # Data format change to avoid explicit transposes : [b s h] --> [s b h].
embeddings = embeddings.transpose(0, 1).contiguous()
# If the input flag for fp32 residual connection is set, convert for float.
if self.fp32_residual_connection:
diff --git a/src/diffusers/pipelines/kolors/tokenizer.py b/src/diffusers/pipelines/kolors/tokenizer.py
index fa241b920c97..b824ba12e079 100644
--- a/src/diffusers/pipelines/kolors/tokenizer.py
+++ b/src/diffusers/pipelines/kolors/tokenizer.py
@@ -1,4 +1,4 @@
-# Copyright 2024 ChatGLM3-6B Model Team, Kwai-Kolors Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 ChatGLM3-6B Model Team, Kwai-Kolors Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
index 1c59ca7d6d7c..59f733a498ed 100644
--- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
+++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -186,8 +186,8 @@ class LatentConsistencyModelImg2ImgPipeline(
supports [`LCMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
requires_safety_checker (`bool`, *optional*, defaults to `True`):
@@ -607,7 +607,7 @@ def get_guidance_scale_embedding(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
index a3d9917d3376..e463884618f5 100644
--- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
+++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -165,8 +165,8 @@ class LatentConsistencyModelPipeline(
supports [`LCMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
requires_safety_checker (`bool`, *optional*, defaults to `True`):
@@ -548,7 +548,7 @@ def get_guidance_scale_embedding(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
index c7aa76a01fb8..f1bf4701e31f 100644
--- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@
import torch
import torch.nn as nn
-import torch.utils.checkpoint
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput
diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
index 879722e6a0e2..631539e5c667 100644
--- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
+++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
@@ -4,7 +4,6 @@
import numpy as np
import PIL.Image
import torch
-import torch.utils.checkpoint
from ...models import UNet2DModel, VQModel
from ...schedulers import (
@@ -95,8 +94,8 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -166,7 +165,7 @@ def __call__(
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_kwargs = {}
diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py
index e9a95e8be45c..4d42a7049ec9 100644
--- a/src/diffusers/pipelines/latte/pipeline_latte.py
+++ b/src/diffusers/pipelines/latte/pipeline_latte.py
@@ -1,4 +1,4 @@
-# Copyright 2024 the Latte Team and The HuggingFace Team.
+# Copyright 2025 the Latte Team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -356,7 +356,7 @@ def encode_prompt(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -501,7 +501,7 @@ def _clean_caption(self, caption):
# &
caption = re.sub(r"&", "", caption)
- # ip adresses:
+ # ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -592,7 +592,7 @@ def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -657,11 +657,11 @@ def __call__(
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`,
- usually at the expense of lower video quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to
+ the text `prompt`, usually at the expense of lower video quality.
video_length (`int`, *optional*, defaults to 16):
The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds
num_images_per_prompt (`int`, *optional*, defaults to 1):
@@ -671,15 +671,15 @@ def __call__(
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated video.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -747,7 +747,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
index bdac47c47ade..fbf4dc23d043 100644
--- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
+++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
@@ -49,7 +49,7 @@
>>> from diffusers.utils import load_image
>>> pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
... )
>>> pipe.enable_vae_tiling()
>>> pipe = pipe.to("cuda")
@@ -244,7 +244,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -381,8 +381,8 @@ def __init__(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -439,7 +439,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, eta, generator=None):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -722,6 +722,12 @@ def enable_vae_slicing(self):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -729,6 +735,12 @@ def disable_vae_slicing(self):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -737,6 +749,12 @@ def enable_vae_tiling(self):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -744,6 +762,12 @@ def disable_vae_tiling(self):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
@torch.no_grad()
@@ -808,7 +832,7 @@ def __call__(
edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5):
Guidance scale for guiding the image generation. If provided as list values should correspond to
`editing_prompt`. `edit_guidance_scale` is defined as `s_e` of equation 12 of [LEDITS++
- Paper](https://arxiv.org/abs/2301.12247).
+ Paper](https://huggingface.co/papers/2301.12247).
edit_warmup_steps (`float` or `List[float]`, *optional*, defaults to 10):
Number of diffusion steps (for each prompt) for which guidance will not be applied.
edit_cooldown_steps (`float` or `List[float]`, *optional*, defaults to `None`):
@@ -816,7 +840,7 @@ def __call__(
edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9):
Masking threshold of guidance. Threshold should be proportional to the image region that is modified.
'edit_threshold' is defined as 'λ' of equation 12 of [LEDITS++
- Paper](https://arxiv.org/abs/2301.12247).
+ Paper](https://huggingface.co/papers/2301.12247).
user_mask (`torch.Tensor`, *optional*):
User-provided mask for even better control over the editing process. This is helpful when LEDITS++'s
implicit masks do not meet user preferences.
@@ -826,11 +850,11 @@ def __call__(
use_cross_attn_mask (`bool`, defaults to `False`):
Whether cross-attention masks are used. Cross-attention masks are always used when use_intersect_mask
is set to true. Cross-attention masks are defined as 'M^1' of equation 12 of [LEDITS++
- paper](https://arxiv.org/pdf/2311.16711.pdf).
+ paper](https://huggingface.co/papers/2311.16711).
use_intersect_mask (`bool`, defaults to `True`):
Whether the masking term is calculated as intersection of cross-attention masks and masks derived from
the noise estimate. Cross-attention mask are defined as 'M^1' and masks derived from the noise estimate
- are defined as 'M^2' of equation 12 of [LEDITS++ paper](https://arxiv.org/pdf/2311.16711.pdf).
+ are defined as 'M^2' of equation 12 of [LEDITS++ paper](https://huggingface.co/papers/2311.16711).
attn_store_steps (`List[int]`, *optional*):
Steps for which the attention maps are stored in the AttentionStore. Just for visualization purposes.
store_averaged_over_steps (`bool`, defaults to `True`):
@@ -841,7 +865,7 @@ def __call__(
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
@@ -1191,7 +1215,7 @@ def __call__(
noise_pred = noise_pred_uncond + noise_guidance_edit
if enable_edit_guidance and self.guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(
noise_pred,
noise_pred_edit_concepts.mean(dim=0, keepdim=False),
@@ -1268,8 +1292,8 @@ def invert(
):
r"""
The function to the pipeline for image inversion as described by the [LEDITS++
- Paper](https://arxiv.org/abs/2301.12247). If the scheduler is set to [`~schedulers.DDIMScheduler`] the
- inversion proposed by [edit-friendly DPDM](https://arxiv.org/abs/2304.06140) will be performed instead.
+ Paper](https://huggingface.co/papers/2301.12247). If the scheduler is set to [`~schedulers.DDIMScheduler`] the
+ inversion proposed by [edit-friendly DPDM](https://huggingface.co/papers/2304.06140) will be performed instead.
Args:
image (`PipelineImageInput`):
@@ -1443,7 +1467,7 @@ def compute_noise_ddim(scheduler, prev_latents, latents, timestep, noise_pred, e
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
- # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502
pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
# 4. Clip "predicted x_0"
@@ -1455,10 +1479,10 @@ def compute_noise_ddim(scheduler, prev_latents, latents, timestep, noise_pred, e
variance = scheduler._get_variance(timestep, prev_timestep)
std_dev_t = eta * variance ** (0.5)
- # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # 6. compute "direction pointing to x_t" of formula (12) from https://huggingface.co/papers/2010.02502
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred
- # modifed so that updated xtm1 is returned as well (to avoid error accumulation)
+ # modified so that updated xtm1 is returned as well (to avoid error accumulation)
mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
if variance > 0.0:
noise = (prev_latents - mu_xt) / (variance ** (0.5) * eta)
diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py
index cad7d8a66a08..993957a052fc 100644
--- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py
@@ -37,13 +37,12 @@
from ...models.attention_processor import (
Attention,
AttnProcessor,
- AttnProcessor2_0,
- XFormersAttnProcessor,
)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import DDIMScheduler, DPMSolverMultistepScheduler
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_invisible_watermark_available,
is_torch_xla_available,
logging,
@@ -622,7 +621,7 @@ def encode_prompt(
def prepare_extra_step_kwargs(self, eta, generator=None):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -687,21 +686,12 @@ def _get_add_time_ids(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
@@ -747,7 +737,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -770,6 +760,12 @@ def enable_vae_slicing(self):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -777,6 +773,12 @@ def disable_vae_slicing(self):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -785,6 +787,12 @@ def enable_vae_tiling(self):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -792,6 +800,12 @@ def disable_vae_tiling(self):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
# Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.prepare_unet
@@ -901,9 +915,10 @@ def __call__(
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.7):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
- [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
- Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
@@ -929,7 +944,7 @@ def __call__(
edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5):
Guidance scale for guiding the image generation. If provided as list values should correspond to
`editing_prompt`. `edit_guidance_scale` is defined as `s_e` of equation 12 of [LEDITS++
- Paper](https://arxiv.org/abs/2301.12247).
+ Paper](https://huggingface.co/papers/2301.12247).
edit_warmup_steps (`float` or `List[float]`, *optional*, defaults to 10):
Number of diffusion steps (for each prompt) for which guidance is not applied.
edit_cooldown_steps (`float` or `List[float]`, *optional*, defaults to `None`):
@@ -937,18 +952,18 @@ def __call__(
edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9):
Masking threshold of guidance. Threshold should be proportional to the image region that is modified.
'edit_threshold' is defined as 'λ' of equation 12 of [LEDITS++
- Paper](https://arxiv.org/abs/2301.12247).
+ Paper](https://huggingface.co/papers/2301.12247).
sem_guidance (`List[torch.Tensor]`, *optional*):
List of pre-generated guidance vectors to be applied at generation. Length of the list has to
correspond to `num_inference_steps`.
use_cross_attn_mask:
Whether cross-attention masks are used. Cross-attention masks are always used when use_intersect_mask
is set to true. Cross-attention masks are defined as 'M^1' of equation 12 of [LEDITS++
- paper](https://arxiv.org/pdf/2311.16711.pdf).
+ paper](https://huggingface.co/papers/2311.16711).
use_intersect_mask:
Whether the masking term is calculated as intersection of cross-attention masks and masks derived from
the noise estimate. Cross-attention mask are defined as 'M^1' and masks derived from the noise estimate
- are defined as 'M^2' of equation 12 of [LEDITS++ paper](https://arxiv.org/pdf/2311.16711.pdf).
+ are defined as 'M^2' of equation 12 of [LEDITS++ paper](https://huggingface.co/papers/2311.16711).
user_mask:
User-provided mask for even better control over the editing process. This is helpful when LEDITS++'s
implicit masks do not meet user preferences.
@@ -1350,7 +1365,7 @@ def __call__(
# compute the previous noisy sample x_t -> x_t-1
if enable_edit_guidance and self.guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(
noise_pred,
noise_pred_edit_concepts.mean(dim=0, keepdim=False),
@@ -1478,8 +1493,8 @@ def invert(
):
r"""
The function to the pipeline for image inversion as described by the [LEDITS++
- Paper](https://arxiv.org/abs/2301.12247). If the scheduler is set to [`~schedulers.DDIMScheduler`] the
- inversion proposed by [edit-friendly DPDM](https://arxiv.org/abs/2304.06140) will be performed instead.
+ Paper](https://huggingface.co/papers/2301.12247). If the scheduler is set to [`~schedulers.DDIMScheduler`] the
+ inversion proposed by [edit-friendly DPDM](https://huggingface.co/papers/2304.06140) will be performed instead.
Args:
image (`PipelineImageInput`):
@@ -1691,7 +1706,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -1727,7 +1742,7 @@ def compute_noise_ddim(scheduler, prev_latents, latents, timestep, noise_pred, e
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
- # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502
pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
# 4. Clip "predicted x_0"
@@ -1739,10 +1754,10 @@ def compute_noise_ddim(scheduler, prev_latents, latents, timestep, noise_pred, e
variance = scheduler._get_variance(timestep, prev_timestep)
std_dev_t = eta * variance ** (0.5)
- # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # 6. compute "direction pointing to x_t" of formula (12) from https://huggingface.co/papers/2010.02502
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * noise_pred
- # modifed so that updated xtm1 is returned as well (to avoid error accumulation)
+ # modified so that updated xtm1 is returned as well (to avoid error accumulation)
mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
if variance > 0.0:
noise = (prev_latents - mu_xt) / (variance ** (0.5) * eta)
diff --git a/src/diffusers/pipelines/ltx/__init__.py b/src/diffusers/pipelines/ltx/__init__.py
index 199e730d9b4d..6001867916b3 100644
--- a/src/diffusers/pipelines/ltx/__init__.py
+++ b/src/diffusers/pipelines/ltx/__init__.py
@@ -22,9 +22,11 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
+ _import_structure["modeling_latent_upsampler"] = ["LTXLatentUpsamplerModel"]
_import_structure["pipeline_ltx"] = ["LTXPipeline"]
_import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
_import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
+ _import_structure["pipeline_ltx_latent_upsample"] = ["LTXLatentUpsamplePipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -34,9 +36,11 @@
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
+ from .modeling_latent_upsampler import LTXLatentUpsamplerModel
from .pipeline_ltx import LTXPipeline
from .pipeline_ltx_condition import LTXConditionPipeline
from .pipeline_ltx_image2video import LTXImageToVideoPipeline
+ from .pipeline_ltx_latent_upsample import LTXLatentUpsamplePipeline
else:
import sys
diff --git a/src/diffusers/pipelines/ltx/modeling_latent_upsampler.py b/src/diffusers/pipelines/ltx/modeling_latent_upsampler.py
new file mode 100644
index 000000000000..6dce792a2b43
--- /dev/null
+++ b/src/diffusers/pipelines/ltx/modeling_latent_upsampler.py
@@ -0,0 +1,188 @@
+# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models.modeling_utils import ModelMixin
+
+
+class ResBlock(torch.nn.Module):
+ def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
+ super().__init__()
+ if mid_channels is None:
+ mid_channels = channels
+
+ Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
+
+ self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
+ self.norm1 = torch.nn.GroupNorm(32, mid_channels)
+ self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
+ self.norm2 = torch.nn.GroupNorm(32, channels)
+ self.activation = torch.nn.SiLU()
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.conv1(hidden_states)
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.activation(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = self.activation(hidden_states + residual)
+ return hidden_states
+
+
+class PixelShuffleND(torch.nn.Module):
+ def __init__(self, dims, upscale_factors=(2, 2, 2)):
+ super().__init__()
+
+ self.dims = dims
+ self.upscale_factors = upscale_factors
+
+ if dims not in [1, 2, 3]:
+ raise ValueError("dims must be 1, 2, or 3")
+
+ def forward(self, x):
+ if self.dims == 3:
+ # spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)
+ return (
+ x.unflatten(1, (-1, *self.upscale_factors[:3]))
+ .permute(0, 1, 5, 2, 6, 3, 7, 4)
+ .flatten(6, 7)
+ .flatten(4, 5)
+ .flatten(2, 3)
+ )
+ elif self.dims == 2:
+ # spatial: b (c p1 p2) h w -> b c (h p1) (w p2)
+ return (
+ x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3)
+ )
+ elif self.dims == 1:
+ # temporal: b (c p1) f h w -> b c (f p1) h w
+ return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3)
+
+
+class LTXLatentUpsamplerModel(ModelMixin, ConfigMixin):
+ """
+ Model to spatially upsample VAE latents.
+
+ Args:
+ in_channels (`int`, defaults to `128`):
+ Number of channels in the input latent
+ mid_channels (`int`, defaults to `512`):
+ Number of channels in the middle layers
+ num_blocks_per_stage (`int`, defaults to `4`):
+ Number of ResBlocks to use in each stage (pre/post upsampling)
+ dims (`int`, defaults to `3`):
+ Number of dimensions for convolutions (2 or 3)
+ spatial_upsample (`bool`, defaults to `True`):
+ Whether to spatially upsample the latent
+ temporal_upsample (`bool`, defaults to `False`):
+ Whether to temporally upsample the latent
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 128,
+ mid_channels: int = 512,
+ num_blocks_per_stage: int = 4,
+ dims: int = 3,
+ spatial_upsample: bool = True,
+ temporal_upsample: bool = False,
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.mid_channels = mid_channels
+ self.num_blocks_per_stage = num_blocks_per_stage
+ self.dims = dims
+ self.spatial_upsample = spatial_upsample
+ self.temporal_upsample = temporal_upsample
+
+ ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
+
+ self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1)
+ self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
+ self.initial_activation = torch.nn.SiLU()
+
+ self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])
+
+ if spatial_upsample and temporal_upsample:
+ self.upsampler = torch.nn.Sequential(
+ torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
+ PixelShuffleND(3),
+ )
+ elif spatial_upsample:
+ self.upsampler = torch.nn.Sequential(
+ torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
+ PixelShuffleND(2),
+ )
+ elif temporal_upsample:
+ self.upsampler = torch.nn.Sequential(
+ torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
+ PixelShuffleND(1),
+ )
+ else:
+ raise ValueError("Either spatial_upsample or temporal_upsample must be True")
+
+ self.post_upsample_res_blocks = torch.nn.ModuleList(
+ [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
+ )
+
+ self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+
+ if self.dims == 2:
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
+ hidden_states = self.initial_conv(hidden_states)
+ hidden_states = self.initial_norm(hidden_states)
+ hidden_states = self.initial_activation(hidden_states)
+
+ for block in self.res_blocks:
+ hidden_states = block(hidden_states)
+
+ hidden_states = self.upsampler(hidden_states)
+
+ for block in self.post_upsample_res_blocks:
+ hidden_states = block(hidden_states)
+
+ hidden_states = self.final_conv(hidden_states)
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+ else:
+ hidden_states = self.initial_conv(hidden_states)
+ hidden_states = self.initial_norm(hidden_states)
+ hidden_states = self.initial_activation(hidden_states)
+
+ for block in self.res_blocks:
+ hidden_states = block(hidden_states)
+
+ if self.temporal_upsample:
+ hidden_states = self.upsampler(hidden_states)
+ hidden_states = hidden_states[:, :, 1:, :, :]
+ else:
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
+ hidden_states = self.upsampler(hidden_states)
+ hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
+
+ for block in self.post_upsample_res_blocks:
+ hidden_states = block(hidden_states)
+
+ hidden_states = self.final_conv(hidden_states)
+
+ return hidden_states
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py
index 6f3faed8ff72..8ca8b4419e18 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -140,6 +140,33 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
r"""
Pipeline for text-to-video generation.
@@ -481,6 +508,10 @@ def prepare_latents(
def guidance_scale(self):
return self._guidance_scale
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@@ -514,6 +545,7 @@ def __call__(
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 3,
+ guidance_rescale: float = 0.0,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
@@ -551,11 +583,17 @@ def __call__(
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, defaults to `3 `):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -564,7 +602,7 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -624,6 +662,7 @@ def __call__(
)
self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
self._attention_kwargs = attention_kwargs
self._interrupt = False
self._current_timestep = None
@@ -719,24 +758,31 @@ def __call__(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- encoder_attention_mask=prompt_attention_mask,
- num_frames=latent_num_frames,
- height=latent_height,
- width=latent_width,
- rope_interpolation_scale=rope_interpolation_scale,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ num_frames=latent_num_frames,
+ height=latent_height,
+ width=latent_width,
+ rope_interpolation_scale=rope_interpolation_scale,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ if self.guidance_rescale > 0:
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
+ noise_pred = rescale_noise_cfg(
+ noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
+ )
+
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
@@ -789,6 +835,7 @@ def __call__(
]
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
+ latents = latents.to(self.vae.dtype)
video = self.vae.decode(latents, timestep, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
index ef1fd568397f..48a6f0837c8d 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
import inspect
from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import PIL.Image
import torch
@@ -75,6 +75,7 @@
>>> # Generate video
>>> generator = torch.Generator("cuda").manual_seed(0)
+ >>> # Text-only conditioning is also supported without the need to pass `conditions`
>>> video = pipe(
... conditions=[condition1, condition2],
... prompt=prompt,
@@ -221,9 +222,36 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
r"""
- Pipeline for image-to-video generation.
+ Pipeline for text/image/video-to-video generation.
Reference: https://github.com/Lightricks/LTX-Video
@@ -429,6 +457,7 @@ def check_inputs(
video,
frame_index,
strength,
+ denoise_strength,
height,
width,
callback_on_step_end_tensor_inputs=None,
@@ -482,9 +511,6 @@ def check_inputs(
if conditions is not None and (image is not None or video is not None):
raise ValueError("If `conditions` is provided, `image` and `video` must not be provided.")
- if conditions is None and (image is None and video is None):
- raise ValueError("If `conditions` is not provided, `image` or `video` must be provided.")
-
if conditions is None:
if isinstance(image, list) and isinstance(frame_index, list) and len(image) != len(frame_index):
raise ValueError(
@@ -499,6 +525,9 @@ def check_inputs(
elif isinstance(video, list) and isinstance(strength, list) and len(video) != len(strength):
raise ValueError("If `conditions` is not provided, `video` and `strength` must be of the same length.")
+ if denoise_strength < 0 or denoise_strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {denoise_strength}")
+
@staticmethod
def _prepare_video_ids(
batch_size: int,
@@ -642,97 +671,113 @@ def add_noise_to_image_conditioning_latents(
def prepare_latents(
self,
- conditions: List[torch.Tensor],
- condition_strength: List[float],
- condition_frame_index: List[int],
+ conditions: Optional[List[torch.Tensor]] = None,
+ condition_strength: Optional[List[float]] = None,
+ condition_frame_index: Optional[List[int]] = None,
batch_size: int = 1,
num_channels_latents: int = 128,
height: int = 512,
width: int = 704,
num_frames: int = 161,
num_prefix_latent_frames: int = 2,
+ sigma: Optional[torch.Tensor] = None,
+ latents: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
- ) -> None:
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
num_latent_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-
- condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32)
-
- extra_conditioning_latents = []
- extra_conditioning_video_ids = []
- extra_conditioning_mask = []
- extra_conditioning_num_latents = 0
- for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index):
- condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
- condition_latents = self._normalize_latents(
- condition_latents, self.vae.latents_mean, self.vae.latents_std
- ).to(device, dtype=dtype)
-
- num_data_frames = data.size(2)
- num_cond_frames = condition_latents.size(2)
-
- if frame_index == 0:
- latents[:, :, :num_cond_frames] = torch.lerp(
- latents[:, :, :num_cond_frames], condition_latents, strength
+
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ if latents is not None and sigma is not None:
+ if latents.shape != shape:
+ raise ValueError(
+ f"Latents shape {latents.shape} does not match expected shape {shape}. Please check the input."
)
- condition_latent_frames_mask[:, :num_cond_frames] = strength
+ latents = latents.to(device=device, dtype=dtype)
+ sigma = sigma.to(device=device, dtype=dtype)
+ latents = sigma * noise + (1 - sigma) * latents
+ else:
+ latents = noise
- else:
- if num_data_frames > 1:
- if num_cond_frames < num_prefix_latent_frames:
- raise ValueError(
- f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}."
- )
+ if len(conditions) > 0:
+ condition_latent_frames_mask = torch.zeros(
+ (batch_size, num_latent_frames), device=device, dtype=torch.float32
+ )
- if num_cond_frames > num_prefix_latent_frames:
- start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames
- end_frame = start_frame + num_cond_frames - num_prefix_latent_frames
- latents[:, :, start_frame:end_frame] = torch.lerp(
- latents[:, :, start_frame:end_frame],
- condition_latents[:, :, num_prefix_latent_frames:],
- strength,
- )
- condition_latent_frames_mask[:, start_frame:end_frame] = strength
- condition_latents = condition_latents[:, :, :num_prefix_latent_frames]
-
- noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
- condition_latents = torch.lerp(noise, condition_latents, strength)
-
- condition_video_ids = self._prepare_video_ids(
- batch_size,
- condition_latents.size(2),
- latent_height,
- latent_width,
- patch_size=self.transformer_spatial_patch_size,
- patch_size_t=self.transformer_temporal_patch_size,
- device=device,
- )
- condition_video_ids = self._scale_video_ids(
- condition_video_ids,
- scale_factor=self.vae_spatial_compression_ratio,
- scale_factor_t=self.vae_temporal_compression_ratio,
- frame_index=frame_index,
- device=device,
- )
- condition_latents = self._pack_latents(
- condition_latents,
- self.transformer_spatial_patch_size,
- self.transformer_temporal_patch_size,
- )
- condition_conditioning_mask = torch.full(
- condition_latents.shape[:2], strength, device=device, dtype=dtype
- )
+ extra_conditioning_latents = []
+ extra_conditioning_video_ids = []
+ extra_conditioning_mask = []
+ extra_conditioning_num_latents = 0
+ for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index):
+ condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
+ condition_latents = self._normalize_latents(
+ condition_latents, self.vae.latents_mean, self.vae.latents_std
+ ).to(device, dtype=dtype)
+
+ num_data_frames = data.size(2)
+ num_cond_frames = condition_latents.size(2)
+
+ if frame_index == 0:
+ latents[:, :, :num_cond_frames] = torch.lerp(
+ latents[:, :, :num_cond_frames], condition_latents, strength
+ )
+ condition_latent_frames_mask[:, :num_cond_frames] = strength
+
+ else:
+ if num_data_frames > 1:
+ if num_cond_frames < num_prefix_latent_frames:
+ raise ValueError(
+ f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}."
+ )
+
+ if num_cond_frames > num_prefix_latent_frames:
+ start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames
+ end_frame = start_frame + num_cond_frames - num_prefix_latent_frames
+ latents[:, :, start_frame:end_frame] = torch.lerp(
+ latents[:, :, start_frame:end_frame],
+ condition_latents[:, :, num_prefix_latent_frames:],
+ strength,
+ )
+ condition_latent_frames_mask[:, start_frame:end_frame] = strength
+ condition_latents = condition_latents[:, :, :num_prefix_latent_frames]
+
+ noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
+ condition_latents = torch.lerp(noise, condition_latents, strength)
+
+ condition_video_ids = self._prepare_video_ids(
+ batch_size,
+ condition_latents.size(2),
+ latent_height,
+ latent_width,
+ patch_size=self.transformer_spatial_patch_size,
+ patch_size_t=self.transformer_temporal_patch_size,
+ device=device,
+ )
+ condition_video_ids = self._scale_video_ids(
+ condition_video_ids,
+ scale_factor=self.vae_spatial_compression_ratio,
+ scale_factor_t=self.vae_temporal_compression_ratio,
+ frame_index=frame_index,
+ device=device,
+ )
+ condition_latents = self._pack_latents(
+ condition_latents,
+ self.transformer_spatial_patch_size,
+ self.transformer_temporal_patch_size,
+ )
+ condition_conditioning_mask = torch.full(
+ condition_latents.shape[:2], strength, device=device, dtype=dtype
+ )
- extra_conditioning_latents.append(condition_latents)
- extra_conditioning_video_ids.append(condition_video_ids)
- extra_conditioning_mask.append(condition_conditioning_mask)
- extra_conditioning_num_latents += condition_latents.size(1)
+ extra_conditioning_latents.append(condition_latents)
+ extra_conditioning_video_ids.append(condition_video_ids)
+ extra_conditioning_mask.append(condition_conditioning_mask)
+ extra_conditioning_num_latents += condition_latents.size(1)
video_ids = self._prepare_video_ids(
batch_size,
@@ -743,7 +788,10 @@ def prepare_latents(
patch_size=self.transformer_spatial_patch_size,
device=device,
)
- conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
+ if len(conditions) > 0:
+ conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
+ else:
+ conditioning_mask, extra_conditioning_num_latents = None, 0
video_ids = self._scale_video_ids(
video_ids,
scale_factor=self.vae_spatial_compression_ratio,
@@ -755,17 +803,28 @@ def prepare_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
- if len(extra_conditioning_latents) > 0:
+ if len(conditions) > 0 and len(extra_conditioning_latents) > 0:
latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
video_ids = torch.cat([*extra_conditioning_video_ids, video_ids], dim=2)
conditioning_mask = torch.cat([*extra_conditioning_mask, conditioning_mask], dim=1)
return latents, conditioning_mask, video_ids, extra_conditioning_num_latents
+ def get_timesteps(self, sigmas, timesteps, num_inference_steps, strength):
+ num_steps = min(int(num_inference_steps * strength), num_inference_steps)
+ start_index = max(num_inference_steps - num_steps, 0)
+ sigmas = sigmas[start_index:]
+ timesteps = timesteps[start_index:]
+ return sigmas, timesteps, num_inference_steps - start_index
+
@property
def guidance_scale(self):
return self._guidance_scale
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@@ -795,6 +854,7 @@ def __call__(
video: List[PipelineImageInput] = None,
frame_index: Union[int, List[int]] = 0,
strength: Union[float, List[float]] = 1.0,
+ denoise_strength: float = 1.0,
prompt: Union[str, List[str]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 512,
@@ -804,6 +864,7 @@ def __call__(
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 3,
+ guidance_rescale: float = 0.0,
image_cond_noise_scale: float = 0.15,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -838,6 +899,10 @@ def __call__(
generation. If not provided, one has to pass `conditions`.
strength (`float` or `List[float]`, *optional*):
The strength or strengths of the conditioning effect. If not provided, one has to pass `conditions`.
+ denoise_strength (`float`, defaults to `1.0`):
+ The strength of the noise added to the latents for editing. Higher strength leads to more noise added
+ to the latents, therefore leading to more differences between original video and generated video. This
+ is useful for video-to-video editing.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
@@ -855,11 +920,17 @@ def __call__(
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, defaults to `3 `):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -868,7 +939,7 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -914,8 +985,6 @@ def __call__(
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
- if latents is not None:
- raise ValueError("Passing latents is not yet supported.")
# 1. Check inputs. Raise error if not correct
self.check_inputs(
@@ -925,6 +994,7 @@ def __call__(
video=video,
frame_index=frame_index,
strength=strength,
+ denoise_strength=denoise_strength,
height=height,
width=width,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
@@ -935,6 +1005,7 @@ def __call__(
)
self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
self._attention_kwargs = attention_kwargs
self._interrupt = False
self._current_timestep = None
@@ -955,7 +1026,7 @@ def __call__(
frame_index = [condition.frame_index for condition in conditions]
image = [condition.image for condition in conditions]
video = [condition.video for condition in conditions]
- else:
+ elif image is not None or video is not None:
if not isinstance(image, list):
image = [image]
num_conditions = 1
@@ -973,8 +1044,9 @@ def __call__(
strength = [strength] * num_conditions
device = self._execution_device
+ vae_dtype = self.vae.dtype
- # 3. Prepare text embeddings
+ # 3. Prepare text embeddings & conditioning image/video
(
prompt_embeds,
prompt_attention_mask,
@@ -996,37 +1068,57 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
- vae_dtype = self.vae.dtype
-
conditioning_tensors = []
- for condition_image, condition_video, condition_frame_index, condition_strength in zip(
- image, video, frame_index, strength
- ):
- if condition_image is not None:
- condition_tensor = (
- self.video_processor.preprocess(condition_image, height, width)
- .unsqueeze(2)
- .to(device, dtype=vae_dtype)
- )
- elif condition_video is not None:
- condition_tensor = self.video_processor.preprocess_video(condition_video, height, width)
- num_frames_input = condition_tensor.size(2)
- num_frames_output = self.trim_conditioning_sequence(
- condition_frame_index, num_frames_input, num_frames
- )
- condition_tensor = condition_tensor[:, :, :num_frames_output]
- condition_tensor = condition_tensor.to(device, dtype=vae_dtype)
- else:
- raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.")
+ is_conditioning_image_or_video = image is not None or video is not None
+ if is_conditioning_image_or_video:
+ for condition_image, condition_video, condition_frame_index, condition_strength in zip(
+ image, video, frame_index, strength
+ ):
+ if condition_image is not None:
+ condition_tensor = (
+ self.video_processor.preprocess(condition_image, height, width)
+ .unsqueeze(2)
+ .to(device, dtype=vae_dtype)
+ )
+ elif condition_video is not None:
+ condition_tensor = self.video_processor.preprocess_video(condition_video, height, width)
+ num_frames_input = condition_tensor.size(2)
+ num_frames_output = self.trim_conditioning_sequence(
+ condition_frame_index, num_frames_input, num_frames
+ )
+ condition_tensor = condition_tensor[:, :, :num_frames_output]
+ condition_tensor = condition_tensor.to(device, dtype=vae_dtype)
+ else:
+ raise ValueError("Either `image` or `video` must be provided for conditioning.")
+
+ if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1:
+ raise ValueError(
+ f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) "
+ f"but got {condition_tensor.size(2)} frames."
+ )
+ conditioning_tensors.append(condition_tensor)
- if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1:
- raise ValueError(
- f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) "
- f"but got {condition_tensor.size(2)} frames."
- )
- conditioning_tensors.append(condition_tensor)
+ # 4. Prepare timesteps
+ latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
+ latent_height = height // self.vae_spatial_compression_ratio
+ latent_width = width // self.vae_spatial_compression_ratio
+ if timesteps is None:
+ sigmas = linear_quadratic_schedule(num_inference_steps)
+ timesteps = sigmas * 1000
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ sigmas = self.scheduler.sigmas
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ latent_sigma = None
+ if denoise_strength < 1:
+ sigmas, timesteps, num_inference_steps = self.get_timesteps(
+ sigmas, timesteps, num_inference_steps, denoise_strength
+ )
+ latent_sigma = sigmas[:1].repeat(batch_size * num_videos_per_prompt)
+
+ self._num_timesteps = len(timesteps)
- # 4. Prepare latent variables
+ # 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents, conditioning_mask, video_coords, extra_conditioning_num_latents = self.prepare_latents(
conditioning_tensors,
@@ -1037,6 +1129,8 @@ def __call__(
height=height,
width=width,
num_frames=num_frames,
+ sigma=latent_sigma,
+ latents=latents,
generator=generator,
device=device,
dtype=torch.float32,
@@ -1045,27 +1139,12 @@ def __call__(
video_coords = video_coords.float()
video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate)
- init_latents = latents.clone()
+ init_latents = latents.clone() if is_conditioning_image_or_video else None
if self.do_classifier_free_guidance:
video_coords = torch.cat([video_coords, video_coords], dim=0)
- # 5. Prepare timesteps
- latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
- latent_height = height // self.vae_spatial_compression_ratio
- latent_width = width // self.vae_spatial_compression_ratio
- sigmas = linear_quadratic_schedule(num_inference_steps)
- timesteps = sigmas * 1000
- timesteps, num_inference_steps = retrieve_timesteps(
- self.scheduler,
- num_inference_steps,
- device,
- timesteps=timesteps,
- )
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
- self._num_timesteps = len(timesteps)
-
- # 7. Denoising loop
+ # 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
@@ -1073,7 +1152,7 @@ def __call__(
self._current_timestep = t
- if image_cond_noise_scale > 0:
+ if image_cond_noise_scale > 0 and init_latents is not None:
# Add timestep-dependent noise to the hard-conditioning latents
# This helps with motion continuity, especially when conditioned on a single frame
latents = self.add_noise_to_image_conditioning_latents(
@@ -1086,37 +1165,49 @@ def __call__(
)
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
- conditioning_mask_model_input = (
- torch.cat([conditioning_mask, conditioning_mask])
- if self.do_classifier_free_guidance
- else conditioning_mask
- )
+ if is_conditioning_image_or_video:
+ conditioning_mask_model_input = (
+ torch.cat([conditioning_mask, conditioning_mask])
+ if self.do_classifier_free_guidance
+ else conditioning_mask
+ )
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
- timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
-
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- encoder_attention_mask=prompt_attention_mask,
- video_coords=video_coords,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ if is_conditioning_image_or_video:
+ timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
+
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ video_coords=video_coords,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
timestep, _ = timestep.chunk(2)
+ if self.guidance_rescale > 0:
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
+ noise_pred = rescale_noise_cfg(
+ noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
+ )
+
denoised_latents = self.scheduler.step(
-noise_pred, t, latents, per_token_timesteps=timestep, return_dict=False
)[0]
- tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
- latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
+ if is_conditioning_image_or_video:
+ tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
+ latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
+ else:
+ latents = denoised_latents
if callback_on_step_end is not None:
callback_kwargs = {}
@@ -1134,7 +1225,9 @@ def __call__(
if XLA_AVAILABLE:
xm.mark_step()
- latents = latents[:, extra_conditioning_num_latents:]
+ if is_conditioning_image_or_video:
+ latents = latents[:, extra_conditioning_num_latents:]
+
latents = self._unpack_latents(
latents,
latent_num_frames,
@@ -1155,7 +1248,7 @@ def __call__(
if not self.vae.config.timestep_conditioning:
timestep = None
else:
- noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
+ noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
if not isinstance(decode_timestep, list):
decode_timestep = [decode_timestep] * batch_size
if decode_noise_scale is None:
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
index 1ae67967c6f5..f30f8a3dc8f6 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -159,6 +159,33 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
r"""
Pipeline for image-to-video generation.
@@ -542,6 +569,10 @@ def prepare_latents(
def guidance_scale(self):
return self._guidance_scale
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@@ -576,6 +607,7 @@ def __call__(
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 3,
+ guidance_rescale: float = 0.0,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
@@ -615,11 +647,17 @@ def __call__(
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, defaults to `3 `):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -628,7 +666,7 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -688,6 +726,7 @@ def __call__(
)
self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
self._attention_kwargs = attention_kwargs
self._interrupt = False
self._current_timestep = None
@@ -792,18 +831,19 @@ def __call__(
timestep = t.expand(latent_model_input.shape[0])
timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- encoder_attention_mask=prompt_attention_mask,
- num_frames=latent_num_frames,
- height=latent_height,
- width=latent_width,
- rope_interpolation_scale=rope_interpolation_scale,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ num_frames=latent_num_frames,
+ height=latent_height,
+ width=latent_width,
+ rope_interpolation_scale=rope_interpolation_scale,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
if self.do_classifier_free_guidance:
@@ -811,6 +851,12 @@ def __call__(
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
timestep, _ = timestep.chunk(2)
+ if self.guidance_rescale > 0:
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
+ noise_pred = rescale_noise_cfg(
+ noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
+ )
+
# compute the previous noisy sample x_t -> x_t-1
noise_pred = self._unpack_latents(
noise_pred,
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py b/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py
new file mode 100644
index 000000000000..9acff105e56d
--- /dev/null
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py
@@ -0,0 +1,341 @@
+# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Union
+
+import torch
+
+from ...image_processor import PipelineImageInput
+from ...models import AutoencoderKLLTXVideo
+from ...utils import deprecate, get_logger
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .modeling_latent_upsampler import LTXLatentUpsamplerModel
+from .pipeline_output import LTXPipelineOutput
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class LTXLatentUpsamplePipeline(DiffusionPipeline):
+ model_cpu_offload_seq = ""
+
+ def __init__(
+ self,
+ vae: AutoencoderKLLTXVideo,
+ latent_upsampler: LTXLatentUpsamplerModel,
+ ) -> None:
+ super().__init__()
+
+ self.register_modules(vae=vae, latent_upsampler=latent_upsampler)
+
+ self.vae_spatial_compression_ratio = (
+ self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
+ )
+ self.vae_temporal_compression_ratio = (
+ self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
+ )
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
+
+ def prepare_latents(
+ self,
+ video: Optional[torch.Tensor] = None,
+ batch_size: int = 1,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ video = video.to(device=device, dtype=self.vae.dtype)
+ if isinstance(generator, list):
+ if len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ init_latents = [
+ retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
+ ]
+ else:
+ init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
+
+ init_latents = torch.cat(init_latents, dim=0).to(dtype)
+ init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
+ return init_latents
+
+ def adain_filter_latent(self, latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0):
+ """
+ Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on statistics from a reference latent
+ tensor.
+
+ Args:
+ latent (`torch.Tensor`):
+ Input latents to normalize
+ reference_latents (`torch.Tensor`):
+ The reference latents providing style statistics.
+ factor (`float`):
+ Blending factor between original and transformed latent. Range: -10.0 to 10.0, Default: 1.0
+
+ Returns:
+ torch.Tensor: The transformed latent tensor
+ """
+ result = latents.clone()
+
+ for i in range(latents.size(0)):
+ for c in range(latents.size(1)):
+ r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None) # index by original dim order
+ i_sd, i_mean = torch.std_mean(result[i, c], dim=None)
+
+ result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean
+
+ result = torch.lerp(latents, result, factor)
+ return result
+
+ def tone_map_latents(self, latents: torch.Tensor, compression: float) -> torch.Tensor:
+ """
+ Applies a non-linear tone-mapping function to latent values to reduce their dynamic range in a perceptually
+ smooth way using a sigmoid-based compression.
+
+ This is useful for regularizing high-variance latents or for conditioning outputs during generation, especially
+ when controlling dynamic behavior with a `compression` factor.
+
+ Args:
+ latents : torch.Tensor
+ Input latent tensor with arbitrary shape. Expected to be roughly in [-1, 1] or [0, 1] range.
+ compression : float
+ Compression strength in the range [0, 1].
+ - 0.0: No tone-mapping (identity transform)
+ - 1.0: Full compression effect
+
+ Returns:
+ torch.Tensor
+ The tone-mapped latent tensor of the same shape as input.
+ """
+ # Remap [0-1] to [0-0.75] and apply sigmoid compression in one shot
+ scale_factor = compression * 0.75
+ abs_latents = torch.abs(latents)
+
+ # Sigmoid compression: sigmoid shifts large values toward 0.2, small values stay ~1.0
+ # When scale_factor=0, sigmoid term vanishes, when scale_factor=0.75, full effect
+ sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0))
+ scales = 1.0 - 0.8 * scale_factor * sigmoid_term
+
+ filtered = latents * scales
+ return filtered
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
+ def _normalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+ ) -> torch.Tensor:
+ # Normalize latents across the channel dimension [B, C, F, H, W]
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = (latents - latents_mean) * scaling_factor / latents_std
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents
+ def _denormalize_latents(
+ latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
+ ) -> torch.Tensor:
+ # Denormalize latents across the channel dimension [B, C, F, H, W]
+ latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = latents * latents_std / scaling_factor + latents_mean
+ return latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_tiling()
+
+ def check_inputs(self, video, height, width, latents, tone_map_compression_ratio):
+ if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+
+ if video is not None and latents is not None:
+ raise ValueError("Only one of `video` or `latents` can be provided.")
+ if video is None and latents is None:
+ raise ValueError("One of `video` or `latents` has to be provided.")
+
+ if not (0 <= tone_map_compression_ratio <= 1):
+ raise ValueError("`tone_map_compression_ratio` must be in the range [0, 1]")
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ video: Optional[List[PipelineImageInput]] = None,
+ height: int = 512,
+ width: int = 704,
+ latents: Optional[torch.Tensor] = None,
+ decode_timestep: Union[float, List[float]] = 0.0,
+ decode_noise_scale: Optional[Union[float, List[float]]] = None,
+ adain_factor: float = 0.0,
+ tone_map_compression_ratio: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ self.check_inputs(
+ video=video,
+ height=height,
+ width=width,
+ latents=latents,
+ tone_map_compression_ratio=tone_map_compression_ratio,
+ )
+
+ if video is not None:
+ # Batched video input is not yet tested/supported. TODO: take a look later
+ batch_size = 1
+ else:
+ batch_size = latents.shape[0]
+ device = self._execution_device
+
+ if video is not None:
+ num_frames = len(video)
+ if num_frames % self.vae_temporal_compression_ratio != 1:
+ num_frames = (
+ num_frames // self.vae_temporal_compression_ratio * self.vae_temporal_compression_ratio + 1
+ )
+ video = video[:num_frames]
+ logger.warning(
+ f"Video length expected to be of the form `k * {self.vae_temporal_compression_ratio} + 1` but is {len(video)}. Truncating to {num_frames} frames."
+ )
+ video = self.video_processor.preprocess_video(video, height=height, width=width)
+ video = video.to(device=device, dtype=torch.float32)
+
+ latents = self.prepare_latents(
+ video=video,
+ batch_size=batch_size,
+ dtype=torch.float32,
+ device=device,
+ generator=generator,
+ latents=latents,
+ )
+
+ latents = self._denormalize_latents(
+ latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
+ )
+ latents = latents.to(self.latent_upsampler.dtype)
+ latents_upsampled = self.latent_upsampler(latents)
+
+ if adain_factor > 0.0:
+ latents = self.adain_filter_latent(latents_upsampled, latents, adain_factor)
+ else:
+ latents = latents_upsampled
+
+ if tone_map_compression_ratio > 0.0:
+ latents = self.tone_map_latents(latents, tone_map_compression_ratio)
+
+ if output_type == "latent":
+ latents = self._normalize_latents(
+ latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
+ )
+ video = latents
+ else:
+ if not self.vae.config.timestep_conditioning:
+ timestep = None
+ else:
+ noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
+ if not isinstance(decode_timestep, list):
+ decode_timestep = [decode_timestep] * batch_size
+ if decode_noise_scale is None:
+ decode_noise_scale = decode_timestep
+ elif not isinstance(decode_noise_scale, list):
+ decode_noise_scale = [decode_noise_scale] * batch_size
+
+ timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
+ decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
+ :, None, None, None, None
+ ]
+ latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
+
+ video = self.vae.decode(latents, timestep, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return LTXPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/lucy/__init__.py b/src/diffusers/pipelines/lucy/__init__.py
new file mode 100644
index 000000000000..580e1f37f30a
--- /dev/null
+++ b/src/diffusers/pipelines/lucy/__init__.py
@@ -0,0 +1,47 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_lucy_edit"] = ["LucyEditPipeline"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_lucy_edit import LucyEditPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/lucy/pipeline_lucy_edit.py b/src/diffusers/pipelines/lucy/pipeline_lucy_edit.py
new file mode 100644
index 000000000000..69f69d5768a8
--- /dev/null
+++ b/src/diffusers/pipelines/lucy/pipeline_lucy_edit.py
@@ -0,0 +1,735 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 The Decart AI Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Modifications by Decart AI Team:
+# - Based on pipeline_wan.py, but with supports recieving a condition video appended to the channel dimension.
+
+import html
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import regex as re
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import WanLoraLoaderMixin
+from ...models import AutoencoderKLWan, WanTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import LucyPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> from typing import List
+
+ >>> import torch
+ >>> from PIL import Image
+
+ >>> from diffusers import AutoencoderKLWan, LucyEditPipeline
+ >>> from diffusers.utils import export_to_video, load_video
+
+ >>> # Arguments
+ >>> url = "https://d2drjpuinn46lb.cloudfront.net/painter_original_edit.mp4"
+ >>> prompt = "Change the apron and blouse to a classic clown costume: satin polka-dot jumpsuit in bright primary colors, ruffled white collar, oversized pom-pom buttons, white gloves, oversized red shoes, red foam nose; soft window light from left, eye-level medium shot, natural folds and fabric highlights."
+ >>> negative_prompt = ""
+ >>> num_frames = 81
+ >>> height = 480
+ >>> width = 832
+
+
+ >>> # Load video
+ >>> def convert_video(video: List[Image.Image]) -> List[Image.Image]:
+ ... video = load_video(url)[:num_frames]
+ ... video = [video[i].resize((width, height)) for i in range(num_frames)]
+ ... return video
+
+
+ >>> video = load_video(url, convert_method=convert_video)
+
+ >>> # Load model
+ >>> model_id = "decart-ai/Lucy-Edit-Dev"
+ >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+ >>> pipe = LucyEditPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> # Generate video
+ >>> output = pipe(
+ ... prompt=prompt,
+ ... video=video,
+ ... negative_prompt=negative_prompt,
+ ... height=480,
+ ... width=832,
+ ... num_frames=81,
+ ... guidance_scale=5.0,
+ ... ).frames[0]
+
+ >>> # Export video
+ >>> export_to_video(output, "output.mp4", fps=24)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class LucyEditPipeline(DiffusionPipeline, WanLoraLoaderMixin):
+ r"""
+ Pipeline for video-to-video generation using Lucy Edit.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ transformer ([`WanTransformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ transformer_2 ([`WanTransformer3DModel`], *optional*):
+ Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables
+ two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise
+ stages. If not provided, only `transformer` is used.
+ boundary_ratio (`float`, *optional*, defaults to `None`):
+ Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
+ The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
+ `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+ boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+ _optional_components = ["transformer", "transformer_2"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ vae: AutoencoderKLWan,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ transformer: Optional[WanTransformer3DModel] = None,
+ transformer_2: Optional[WanTransformer3DModel] = None,
+ boundary_ratio: Optional[float] = None,
+ expand_timesteps: bool = False, # Wan2.2 ti2v
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ transformer_2=transformer_2,
+ )
+ self.register_to_config(boundary_ratio=boundary_ratio)
+ self.register_to_config(expand_timesteps=expand_timesteps)
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ video,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ guidance_scale_2=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+ raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
+
+ if video is None:
+ raise ValueError("`video` is required, received None.")
+
+ def prepare_latents(
+ self,
+ video: Optional[torch.Tensor] = None,
+ batch_size: int = 1,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ num_latent_frames = (
+ (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1)
+ )
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+ # Prepare noise latents
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # Prepare condition latents
+ condition_latents = [
+ retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video
+ ]
+
+ condition_latents = torch.cat(condition_latents, dim=0).to(dtype)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ device, dtype
+ )
+
+ condition_latents = (condition_latents - latents_mean) * latents_std
+
+ # Check shapes
+ assert latents.shape == condition_latents.shape, (
+ f"Latents shape {latents.shape} does not match expected shape {condition_latents.shape}. Please check the input."
+ )
+
+ return latents, condition_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ video: List[Image.Image],
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ guidance_scale_2: Optional[float] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ video (`List[Image.Image]`):
+ The video to use as the condition for the video generation.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (`guidance_scale` < `1`).
+ height (`int`, defaults to `480`):
+ The height in pixels of the generated image.
+ width (`int`, defaults to `832`):
+ The width in pixels of the generated image.
+ num_frames (`int`, defaults to `81`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ guidance_scale_2 (`float`, *optional*, defaults to `None`):
+ Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
+ `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
+ and the pipeline's `boundary_ratio` are not None.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`LucyPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
+ truncated. If the prompt is shorter, it will be padded to this length.
+
+ Examples:
+
+ Returns:
+ [`~LucyPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`LucyPipelineOutput`] is returned, otherwise a `tuple` is returned where
+ the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ video,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ guidance_scale_2,
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+ guidance_scale_2 = guidance_scale
+
+ self._guidance_scale = guidance_scale
+ self._guidance_scale_2 = guidance_scale_2
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = (
+ self.transformer.config.out_channels
+ if self.transformer is not None
+ else self.transformer_2.config.out_channels
+ )
+ video = self.video_processor.preprocess_video(video, height=height, width=width).to(
+ device, dtype=torch.float32
+ )
+ latents, condition_latents = self.prepare_latents(
+ video,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ mask = torch.ones(latents.shape, dtype=torch.float32, device=device)
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ if self.config.boundary_ratio is not None:
+ boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+ else:
+ boundary_timestep = None
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ if boundary_timestep is None or t >= boundary_timestep:
+ # wan2.1 or high-noise stage in wan2.2
+ current_model = self.transformer
+ current_guidance_scale = guidance_scale
+ else:
+ # low-noise stage in wan2.2
+ current_model = self.transformer_2
+ current_guidance_scale = guidance_scale_2
+
+ # latent_model_input = latents.to(transformer_dtype)
+ latent_model_input = torch.cat([latents, condition_latents], dim=1).to(transformer_dtype)
+ # latent_model_input = torch.cat([latents, latents], dim=1).to(transformer_dtype)
+ if self.config.expand_timesteps:
+ # seq_len: num_latent_frames * latent_height//2 * latent_width//2
+ temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
+ # batch_size, seq_len
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+ else:
+ timestep = t.expand(latents.shape[0])
+
+ with current_model.cache_context("cond"):
+ noise_pred = current_model(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ with current_model.cache_context("uncond"):
+ noise_uncond = current_model(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return LucyPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/lucy/pipeline_output.py b/src/diffusers/pipelines/lucy/pipeline_output.py
new file mode 100644
index 000000000000..cf9ea91fd106
--- /dev/null
+++ b/src/diffusers/pipelines/lucy/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class LucyPipelineOutput(BaseOutput):
+ r"""
+ Output class for Lucy pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py
index 816213f105cb..b59c265646cd 100644
--- a/src/diffusers/pipelines/lumina/pipeline_lumina.py
+++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Alpha-VLLM and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Alpha-VLLM and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -372,7 +372,7 @@ def encode_prompt(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -534,7 +534,7 @@ def _clean_caption(self, caption):
# &
caption = re.sub(r"&", "", caption)
- # ip adresses:
+ # ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -619,7 +619,7 @@ def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -677,11 +677,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
@@ -689,15 +689,15 @@ def __call__(
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -771,7 +771,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
@@ -848,7 +848,7 @@ def __call__(
# prepare image_rotary_emb for positional encoding
# dynamic scaling_factor for different resolution.
# NOTE: For `Time-aware` denosing mechanism from Lumina-Next
- # https://arxiv.org/abs/2406.18583, Sec 2.3
+ # https://huggingface.co/papers/2406.18583, Sec 2.3
# NOTE: We should compute different image_rotary_emb with different timestep.
if current_timestep[0] < scaling_watershed:
linear_factor = scaling_factor
diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py
index e0905a2f131f..937803edbcbc 100644
--- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py
+++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Alpha-VLLM and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Alpha-VLLM and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -342,7 +342,7 @@ def encode_prompt(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -433,6 +433,12 @@ def enable_vae_slicing(self):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -440,6 +446,12 @@ def disable_vae_slicing(self):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -448,6 +460,12 @@ def enable_vae_tiling(self):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -455,6 +473,12 @@ def disable_vae_tiling(self):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
@@ -487,7 +511,7 @@ def attention_kwargs(self):
return self._attention_kwargs
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -544,11 +568,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
@@ -556,15 +580,15 @@ def __call__(
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/marigold/marigold_image_processing.py b/src/diffusers/pipelines/marigold/marigold_image_processing.py
index 0723014ad37b..5130a876606a 100644
--- a/src/diffusers/pipelines/marigold/marigold_image_processing.py
+++ b/src/diffusers/pipelines/marigold/marigold_image_processing.py
@@ -426,7 +426,7 @@ def visualize_depth_one(img, idx=None):
if isinstance(img, np.ndarray):
img = torch.from_numpy(img)
if not torch.is_floating_point(img):
- raise ValueError(f"{prefix}: unexected dtype={img.dtype}.")
+ raise ValueError(f"{prefix}: unexpected dtype={img.dtype}.")
else:
raise ValueError(f"{prefix}: unexpected type={type(img)}.")
if val_min != 0.0 or val_max != 1.0:
@@ -464,7 +464,7 @@ def export_depth_to_16bit_png_one(img, idx=None):
if torch.is_tensor(img):
img = img.cpu().numpy()
if not np.issubdtype(img.dtype, np.floating):
- raise ValueError(f"{prefix}: unexected dtype={img.dtype}.")
+ raise ValueError(f"{prefix}: unexpected dtype={img.dtype}.")
if val_min != 0.0 or val_max != 1.0:
img = (img - val_min) / (val_max - val_min)
img = (img * (2**16 - 1)).astype(np.uint16)
diff --git a/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py b/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py
index da991aefbd4a..92ec16fd455b 100644
--- a/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py
+++ b/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py
@@ -86,15 +86,14 @@ class MarigoldDepthOutput(BaseOutput):
Args:
prediction (`np.ndarray`, `torch.Tensor`):
- Predicted depth maps with values in the range [0, 1]. The shape is $numimages \times 1 \times height \times
- width$ for `torch.Tensor` or $numimages \times height \times width \times 1$ for `np.ndarray`.
+ Predicted depth maps with values in the range [0, 1]. The shape is `numimages × 1 × height × width` for
+ `torch.Tensor` or `numimages × height × width × 1` for `np.ndarray`.
uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
- Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages
- \times 1 \times height \times width$ for `torch.Tensor` or $numimages \times height \times width \times 1$
- for `np.ndarray`.
+ Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is `numimages × 1 ×
+ height × width` for `torch.Tensor` or `numimages × height × width × 1` for `np.ndarray`.
latent (`None`, `torch.Tensor`):
Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
- The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
+ The shape is `numimages * numensemble × 4 × latentheight × latentwidth`.
"""
prediction: Union[np.ndarray, torch.Tensor]
diff --git a/src/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py b/src/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py
index c809de18f469..bef9ca77c708 100644
--- a/src/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py
+++ b/src/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py
@@ -99,17 +99,17 @@ class MarigoldIntrinsicsOutput(BaseOutput):
Args:
prediction (`np.ndarray`, `torch.Tensor`):
- Predicted image intrinsics with values in the range [0, 1]. The shape is $(numimages * numtargets) \times 3
- \times height \times width$ for `torch.Tensor` or $(numimages * numtargets) \times height \times width
- \times 3$ for `np.ndarray`, where `numtargets` corresponds to the number of predicted target modalities of
- the intrinsic image decomposition.
+ Predicted image intrinsics with values in the range [0, 1]. The shape is `(numimages * numtargets) × 3 ×
+ height × width` for `torch.Tensor` or `(numimages * numtargets) × height × width × 3` for `np.ndarray`,
+ where `numtargets` corresponds to the number of predicted target modalities of the intrinsic image
+ decomposition.
uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
- Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $(numimages *
- numtargets) \times 3 \times height \times width$ for `torch.Tensor` or $(numimages * numtargets) \times
- height \times width \times 3$ for `np.ndarray`.
+ Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is `(numimages *
+ numtargets) × 3 × height × width` for `torch.Tensor` or `(numimages * numtargets) × height × width × 3` for
+ `np.ndarray`.
latent (`None`, `torch.Tensor`):
Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
- The shape is $(numimages * numensemble) \times (numtargets * 4) \times latentheight \times latentwidth$.
+ The shape is `(numimages * numensemble) × (numtargets * 4) × latentheight × latentwidth`.
"""
prediction: Union[np.ndarray, torch.Tensor]
diff --git a/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py b/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py
index 192ed590a489..485a39c995ec 100644
--- a/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py
+++ b/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py
@@ -81,15 +81,14 @@ class MarigoldNormalsOutput(BaseOutput):
Args:
prediction (`np.ndarray`, `torch.Tensor`):
- Predicted normals with values in the range [-1, 1]. The shape is $numimages \times 3 \times height \times
- width$ for `torch.Tensor` or $numimages \times height \times width \times 3$ for `np.ndarray`.
+ Predicted normals with values in the range [-1, 1]. The shape is `numimages × 3 × height × width` for
+ `torch.Tensor` or `numimages × height × width × 3` for `np.ndarray`.
uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
- Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages
- \times 1 \times height \times width$ for `torch.Tensor` or $numimages \times height \times width \times 1$
- for `np.ndarray`.
+ Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is `numimages × 1 ×
+ height × width` for `torch.Tensor` or `numimages × height × width × 1` for `np.ndarray`.
latent (`None`, `torch.Tensor`):
Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
- The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
+ The shape is `numimages * numensemble × 4 × latentheight × latentwidth`.
"""
prediction: Union[np.ndarray, torch.Tensor]
diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py
index d1f88b02c5cc..5874a92c6f2f 100644
--- a/src/diffusers/pipelines/mochi/pipeline_mochi.py
+++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Genmo and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Genmo and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -23,11 +23,7 @@
from ...loaders import Mochi1LoraLoaderMixin
from ...models import AutoencoderKLMochi, MochiTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
@@ -396,6 +392,12 @@ def enable_vae_slicing(self):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -403,6 +405,12 @@ def disable_vae_slicing(self):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -411,6 +419,12 @@ def enable_vae_tiling(self):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -418,6 +432,12 @@ def disable_vae_tiling(self):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def prepare_latents(
@@ -521,11 +541,11 @@ def __call__(
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, defaults to `4.5`):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -534,7 +554,7 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -671,14 +691,15 @@ def __call__(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- encoder_attention_mask=prompt_attention_mask,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
# Mochi CFG + Sampling runs in FP32
noise_pred = noise_pred.to(torch.float32)
diff --git a/src/diffusers/pipelines/musicldm/pipeline_musicldm.py b/src/diffusers/pipelines/musicldm/pipeline_musicldm.py
index 73837af7d429..c909e5eb0d26 100644
--- a/src/diffusers/pipelines/musicldm/pipeline_musicldm.py
+++ b/src/diffusers/pipelines/musicldm/pipeline_musicldm.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -35,8 +35,8 @@
logging,
replace_example_docstring,
)
-from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline, StableDiffusionMixin
+from ...utils.torch_utils import empty_device_cache, get_device, randn_tensor
+from ..pipeline_utils import AudioPipelineOutput, DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
if is_librosa_available():
@@ -76,7 +76,8 @@
"""
-class MusicLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
+class MusicLDMPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
+ _last_supported_version = "0.33.1"
r"""
Pipeline for text-to-audio generation using MusicLDM.
@@ -297,7 +298,7 @@ def score_waveforms(self, text, audio, num_waveforms_per_prompt, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -396,20 +397,22 @@ def prepare_latents(self, batch_size, num_channels_latents, height, dtype, devic
def enable_model_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
- to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
- method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
- `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the accelerator when its
+ `forward` method is called, and the model remains in accelerator until the next model runs. Memory savings are
+ lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution
+ of the `unet`.
"""
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
- device = torch.device(f"cuda:{gpu_id}")
+ device_type = get_device()
+ device = torch.device(f"{device_type}:{gpu_id}")
if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
- torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+ empty_device_cache() # otherwise we don't see the memory savings (but they probably exist)
model_sequence = [
self.text_encoder.text_model,
@@ -472,8 +475,8 @@ def __call__(
and the input text. This scoring ranks the generated waveforms based on their cosine similarity to text
input in the joint text-audio embedding space.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -548,7 +551,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
index 5fe5be3b26d2..090cb46aace4 100644
--- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
+++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
@@ -1,4 +1,4 @@
-# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 OmniGen team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -23,12 +23,14 @@
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import OmniGenTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import deprecate, is_torch_xla_available, is_torchvision_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
-from .processor_omnigen import OmniGenMultiModalProcessor
+if is_torchvision_available():
+ from .processor_omnigen import OmniGenMultiModalProcessor
+
if is_torch_xla_available():
XLA_AVAILABLE = True
else:
@@ -120,7 +122,7 @@ class OmniGenPipeline(
r"""
The OmniGen pipeline for multimodal-to-image generation.
- Reference: https://arxiv.org/pdf/2409.11340
+ Reference: https://huggingface.co/papers/2409.11340
Args:
transformer ([`OmniGenTransformer2DModel`]):
@@ -176,7 +178,7 @@ def encode_input_images(
get the continue embedding of input images by VAE
Args:
- input_pixel_values: normlized pixel of input images
+ input_pixel_values: normalized pixel of input images
device:
Returns: torch.Tensor
"""
@@ -233,6 +235,12 @@ def enable_vae_slicing(self):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -240,6 +248,12 @@ def disable_vae_slicing(self):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -248,6 +262,12 @@ def enable_vae_tiling(self):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -255,6 +275,12 @@ def disable_vae_tiling(self):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents
@@ -346,13 +372,13 @@ def __call__(
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 2.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
img_guidance_scale (`float`, *optional*, defaults to 1.6):
- Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800).
+ Defined as equation 3 in [Instrucpix2pix](https://huggingface.co/papers/2211.09800).
use_input_image_size_as_output (bool, defaults to False):
whether to use the input image size as the output image size, which can be used for single-image input,
e.g., image editing task
@@ -364,7 +390,7 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
diff --git a/src/diffusers/pipelines/omnigen/processor_omnigen.py b/src/diffusers/pipelines/omnigen/processor_omnigen.py
index 75d272ac5140..7ed11871bb2a 100644
--- a/src/diffusers/pipelines/omnigen/processor_omnigen.py
+++ b/src/diffusers/pipelines/omnigen/processor_omnigen.py
@@ -1,4 +1,4 @@
-# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 OmniGen team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,12 @@
import numpy as np
import torch
from PIL import Image
-from torchvision import transforms
+
+from ...utils import is_torchvision_available
+
+
+if is_torchvision_available():
+ from torchvision import transforms
def crop_image(pil_image, max_image_size):
@@ -95,13 +100,13 @@ def process_multi_modal_prompt(self, text, input_images):
image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
unique_image_ids = sorted(set(image_ids))
- assert unique_image_ids == list(
- range(1, len(unique_image_ids) + 1)
- ), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
+ assert unique_image_ids == list(range(1, len(unique_image_ids) + 1)), (
+ f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
+ )
# total images must be the same as the number of image tags
- assert (
- len(unique_image_ids) == len(input_images)
- ), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
+ assert len(unique_image_ids) == len(input_images), (
+ f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
+ )
input_images = [input_images[x - 1] for x in image_ids]
@@ -198,7 +203,7 @@ def create_position(self, attention_mask, num_tokens_for_output_images):
def create_mask(self, attention_mask, num_tokens_for_output_images):
"""
OmniGen applies causal attention to each element in the sequence, but applies bidirectional attention within
- each image sequence References: [OmniGen](https://arxiv.org/pdf/2409.11340)
+ each image sequence References: [OmniGen](https://huggingface.co/papers/2409.11340)
"""
extended_mask = []
padding_images = []
diff --git a/src/diffusers/pipelines/onnx_utils.py b/src/diffusers/pipelines/onnx_utils.py
index 0e12340f6895..74e9f0b97800 100644
--- a/src/diffusers/pipelines/onnx_utils.py
+++ b/src/diffusers/pipelines/onnx_utils.py
@@ -75,6 +75,11 @@ def load_model(path: Union[str, Path], provider=None, sess_options=None, provide
logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
provider = "CPUExecutionProvider"
+ if provider_options is None:
+ provider_options = []
+ elif not isinstance(provider_options, list):
+ provider_options = [provider_options]
+
return ort.InferenceSession(
path, providers=[provider], sess_options=sess_options, provider_options=provider_options
)
@@ -174,7 +179,10 @@ def _from_pretrained(
# load model from local directory
if os.path.isdir(model_id):
model = OnnxRuntimeModel.load_model(
- Path(model_id, model_file_name).as_posix(), provider=provider, sess_options=sess_options
+ Path(model_id, model_file_name).as_posix(),
+ provider=provider,
+ sess_options=sess_options,
+ provider_options=kwargs.pop("provider_options"),
)
kwargs["model_save_dir"] = Path(model_id)
# load model from hub
@@ -190,7 +198,12 @@ def _from_pretrained(
)
kwargs["model_save_dir"] = Path(model_cache_path).parent
kwargs["latest_model_name"] = Path(model_cache_path).name
- model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options)
+ model = OnnxRuntimeModel.load_model(
+ model_cache_path,
+ provider=provider,
+ sess_options=sess_options,
+ provider_options=kwargs.pop("provider_options"),
+ )
return cls(model=model, **kwargs)
@classmethod
diff --git a/src/diffusers/pipelines/ovis_image/__init__.py b/src/diffusers/pipelines/ovis_image/__init__.py
new file mode 100644
index 000000000000..275061b1f626
--- /dev/null
+++ b/src/diffusers/pipelines/ovis_image/__init__.py
@@ -0,0 +1,50 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa: F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_output"] = ["OvisImagePipelineOutput"]
+ _import_structure["pipeline_ovis_image"] = ["OvisImagePipeline"]
+
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_output import OvisImagePipelineOutput
+ from .pipeline_ovis_image import OvisImagePipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/ovis_image/pipeline_output.py b/src/diffusers/pipelines/ovis_image/pipeline_output.py
new file mode 100644
index 000000000000..160c5b73a917
--- /dev/null
+++ b/src/diffusers/pipelines/ovis_image/pipeline_output.py
@@ -0,0 +1,35 @@
+# Copyright 2025 Alibaba Ovis-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class OvisImagePipelineOutput(BaseOutput):
+ """
+ Output class for Ovis-Image pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py b/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py
new file mode 100644
index 000000000000..94d6cee93d7e
--- /dev/null
+++ b/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py
@@ -0,0 +1,668 @@
+# Copyright 2025 Alibaba Ovis-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import Qwen2TokenizerFast, Qwen3Model
+
+from ...image_processor import VaeImageProcessor
+from ...models import AutoencoderKL, OvisImageTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import OvisImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import OvisImagePipeline
+
+ >>> pipe = OvisImagePipeline.from_pretrained("AIDC-AI/Ovis-Image-7B", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> prompt = 'A creative 3D artistic render where the text "OVIS-IMAGE" is written in a bold, expressive handwritten brush style using thick, wet oil paint. The paint is a mix of vibrant rainbow colors (red, blue, yellow) swirling together like toothpaste or impasto art. You can see the ridges of the brush bristles and the glossy, wet texture of the paint. The background is a clean artist\'s canvas. Dynamic lighting creates soft shadows behind the floating paint strokes. Colorful, expressive, tactile texture, 4k detail.'
+ >>> image = pipe(prompt, negative_prompt="", num_inference_steps=50, guidance_scale=5.0).images[0]
+ >>> image.save("ovis_image.png")
+ ```
+"""
+
+
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class OvisImagePipeline(
+ DiffusionPipeline,
+):
+ r"""
+ The Ovis-Image pipeline for text-to-image generation.
+
+ Reference: https://github.com/AIDC-AI/Ovis-Image
+
+ Args:
+ transformer ([`OvisImageTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`Qwen3Model`]):
+ Text encoder of class
+ [Qwen3Model](https://huggingface.co/docs/transformers/en/model_doc/qwen3#transformers.Qwen3Model).
+ tokenizer (`Qwen2TokenizerFast`):
+ Tokenizer of class
+ [Qwen2TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: Qwen3Model,
+ tokenizer: Qwen2TokenizerFast,
+ transformer: OvisImageTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ # Ovis-Image latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.system_prompt = "Describe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: "
+ self.user_prompt_begin_id = 28
+ self.tokenizer_max_length = 256 + self.user_prompt_begin_id
+ self.default_sample_size = 128
+
+ def _get_messages(
+ self,
+ prompt: Union[str, List[str]] = None,
+ ):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ messages = []
+ for each_prompt in prompt:
+ message = [
+ {
+ "role": "user",
+ "content": self.system_prompt + each_prompt,
+ }
+ ]
+ message = self.tokenizer.apply_chat_template(
+ message, tokenize=False, add_generation_prompt=True, enable_thinking=False
+ )
+ messages.append(message)
+ return messages
+
+ def _get_ovis_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ messages = self._get_messages(prompt)
+ batch_size = len(messages)
+
+ tokens = self.tokenizer(
+ messages,
+ padding="max_length",
+ truncation=True,
+ max_length=self.tokenizer_max_length,
+ return_tensors="pt",
+ add_special_tokens=False,
+ )
+ input_ids = tokens.input_ids.to(device)
+ attention_mask = tokens.attention_mask.to(device)
+ outputs = self.text_encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = outputs.last_hidden_state
+ prompt_embeds = prompt_embeds * attention_mask[..., None]
+ prompt_embeds = prompt_embeds[:, self.user_prompt_begin_id :, :]
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ """
+ device = device or self._execution_device
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_ovis_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3)
+ text_ids[..., 1] = text_ids[..., 1] + torch.arange(prompt_embeds.shape[1])[None, :]
+ text_ids[..., 2] = text_ids[..., 2] + torch.arange(prompt_embeds.shape[1])[None, :]
+ text_ids = text_ids.to(device=device, dtype=dtype)
+ return prompt_embeds, text_ids
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 256:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 256 but is {max_sequence_length}")
+
+ @staticmethod
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ return latents
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+ return latents, latent_image_ids
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = "",
+ guidance_scale: float = 5.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 256,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ not greater than `1`).
+ guidance_scale (`float`, *optional*, defaults to 1.0):
+ True classifier-free guidance (guidance scale) is enabled when `guidance_scale` > 1 and
+ `negative_prompt` is provided.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ovis_image.OvisImagePipelineOutput`] or `tuple`:
+ [`~pipelines.ovis_image.OvisImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ do_classifier_free_guidance = guidance_scale > 1
+ (
+ prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ if do_classifier_free_guidance:
+ (
+ negative_prompt_embeds,
+ negative_text_ids,
+ ) = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, latent_image_ids = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
+ sigmas = None
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {}
+
+ # 6. Denoising loop
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ return_dict=False,
+ )[0]
+
+ if do_classifier_free_guidance:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=negative_text_ids,
+ img_ids=latent_image_ids,
+ return_dict=False,
+ )[0]
+ noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return OvisImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py
index 4cd2fe4cb79f..8a56961f321c 100644
--- a/src/diffusers/pipelines/pag/pag_utils.py
+++ b/src/diffusers/pipelines/pag/pag_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,7 +31,7 @@
class PAGMixin:
- r"""Mixin class for [Pertubed Attention Guidance](https://arxiv.org/abs/2403.17377v1)."""
+ r"""Mixin class for [Pertubed Attention Guidance](https://huggingface.co/papers/2403.17377v1)."""
def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance):
r"""
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
index bc90073cba77..1abef014301a 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -36,7 +36,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -80,7 +80,10 @@
>>> # load control net and stable diffusion v1-5
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
>>> pipe = AutoPipelineForText2Image.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, enable_pag=True
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ ... controlnet=controlnet,
+ ... torch_dtype=torch.float16,
+ ... enable_pag=True,
... )
>>> # speed up diffusion process with faster scheduler and memory optimization
@@ -202,8 +205,8 @@ class StableDiffusionControlNetPAGPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -541,7 +544,7 @@ def run_safety_checker(self, image, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -843,7 +846,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -933,8 +936,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -1228,7 +1231,11 @@ def __call__(
for i, t in enumerate(timesteps):
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
- if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+ if (
+ torch.cuda.is_available()
+ and (is_unet_compiled and is_controlnet_compiled)
+ and is_torch_higher_equal_2_1
+ ):
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
@@ -1309,7 +1316,7 @@ def __call__(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
- torch.cuda.empty_cache()
+ empty_device_cache()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
index bc7a4b57affd..2781af789018 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -37,7 +37,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -93,7 +93,10 @@
... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
... )
>>> pipe = AutoPipelineForInpainting.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, enable_pag=True
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ ... controlnet=controlnet,
+ ... torch_dtype=torch.float16,
+ ... enable_pag=True,
... )
>>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
@@ -150,17 +153,14 @@ class StableDiffusionControlNetPAGInpaintPipeline(
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
-
-
- This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting
- ([runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting)) as well as
- default text-to-image Stable Diffusion checkpoints
- ([runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)). Default text-to-image
- Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned on those, such as
+ > [!TIP] > This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting >
+ ([stable-diffusion-v1-5/stable-diffusion-inpainting](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting))
+ as well as > default text-to-image Stable Diffusion checkpoints >
+ ([stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)).
+ Default text-to-image > Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned
+ on those, such as >
[lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
-
-
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
@@ -179,8 +179,8 @@ class StableDiffusionControlNetPAGInpaintPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -520,7 +520,7 @@ def run_safety_checker(self, image, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -604,7 +604,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -612,7 +612,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
# `prompt` needs more sophisticated handling when there are multiple
# conditionings.
@@ -955,7 +955,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -1064,8 +1064,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -1332,7 +1332,7 @@ def __call__(
# 7.1 Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
- # default case for runwayml/stable-diffusion-inpainting
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
@@ -1340,7 +1340,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
@@ -1521,7 +1521,7 @@ def __call__(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
- torch.cuda.empty_cache()
+ empty_device_cache()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py
index 83540885bfb2..381352ccc5d4 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -39,14 +39,11 @@
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
-from ...models.attention_processor import (
- AttnProcessor2_0,
- XFormersAttnProcessor,
-)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -619,7 +616,7 @@ def prepare_ip_adapter_image_embeds(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -936,21 +933,12 @@ def _get_add_time_ids(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
@@ -992,7 +980,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -1111,8 +1099,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -1498,7 +1486,11 @@ def __call__(
for i, t in enumerate(timesteps):
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
- if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+ if (
+ torch.cuda.is_available()
+ and (is_unet_compiled and is_controlnet_compiled)
+ and is_torch_higher_equal_2_1
+ ):
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py
index b84f5d555914..df5b3f5c10a5 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -39,20 +39,17 @@
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
-from ...models.attention_processor import (
- AttnProcessor2_0,
- XFormersAttnProcessor,
-)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import is_compiled_module, randn_tensor
+from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from .pag_utils import PAGMixin
@@ -115,7 +112,7 @@
... with torch.no_grad(), torch.autocast("cuda"):
... depth_map = depth_estimator(image).predicted_depth
- ... depth_map = torch.nn.fuctional.interpolate(
+ ... depth_map = torch.nn.functional.interpolate(
... depth_map.unsqueeze(1),
... size=(1024, 1024),
... mode="bicubic",
@@ -611,7 +608,7 @@ def prepare_ip_adapter_image_embeds(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -926,7 +923,7 @@ def prepare_latents(
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
- torch.cuda.empty_cache()
+ empty_device_cache()
image = image.to(device=device, dtype=dtype)
@@ -1049,21 +1046,12 @@ def _get_add_time_ids(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
@property
def guidance_scale(self):
@@ -1074,7 +1062,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -1176,11 +1164,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -1191,15 +1179,15 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -1648,7 +1636,7 @@ def __call__(
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
- torch.cuda.empty_cache()
+ empty_device_cache()
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
index a6a8deb5883c..d156eac8f3f7 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 HunyuanDiT Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -28,11 +28,7 @@
from ...models.embeddings import get_2d_rotary_pos_embed
from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from ...schedulers import DDPMScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pag_utils import PAGMixin
@@ -131,7 +127,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -443,7 +439,7 @@ def run_safety_checker(self, image, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -566,7 +562,7 @@ def guidance_rescale(self):
return self._guidance_rescale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -638,8 +634,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -675,7 +671,7 @@ def __call__(
inputs will be passed.
guidance_rescale (`float`, *optional*, defaults to 0.0):
Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise
- Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). See Section 3.4
original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
The original size of the image. Used to calculate the time ids.
target_size (`Tuple[int, int]`, *optional*):
@@ -915,7 +911,7 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_kolors.py b/src/diffusers/pipelines/pag/pipeline_pag_kolors.py
index 62f634312ada..1403be03a620 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_kolors.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_kolors.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI, Kwai-Kolors Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Stability AI, Kwai-Kolors Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,9 +21,8 @@
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
-from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..kolors.pipeline_output import KolorsPipelineOutput
from ..kolors.text_encoder import ChatGLMModel
@@ -453,7 +452,7 @@ def prepare_ip_adapter_image_embeds(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -598,22 +597,12 @@ def _get_add_time_ids(
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- FusedAttnProcessor2_0,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
@@ -651,7 +640,7 @@ def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -749,11 +738,11 @@ def __call__(
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
guidance_scale (`float`, *optional*, defaults to 5.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -761,15 +750,15 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
index affda7e18add..9031877b5b8d 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
@@ -1,4 +1,4 @@
-# Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -326,7 +326,7 @@ def encode_prompt(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -488,7 +488,7 @@ def _clean_caption(self, caption):
# &
caption = re.sub(r"&", "", caption)
- # ip adresses:
+ # ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -624,11 +624,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 4.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
@@ -636,15 +636,15 @@ def __call__(
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -729,7 +729,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py
index 030ab6db7391..9e91ccbe8006 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py
@@ -1,4 +1,4 @@
-# Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,13 +29,14 @@
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
BACKENDS_MAPPING,
+ deprecate,
is_bs4_available,
is_ftfy_available,
is_torch_xla_available,
logging,
replace_example_docstring,
)
-from ...utils.torch_utils import randn_tensor
+from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ..pixart_alpha.pipeline_pixart_alpha import (
ASPECT_RATIO_512_BIN,
@@ -190,6 +191,12 @@ def enable_vae_slicing(self):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -197,6 +204,12 @@ def disable_vae_slicing(self):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -205,6 +218,12 @@ def enable_vae_tiling(self):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -212,6 +231,12 @@ def disable_vae_tiling(self):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def encode_prompt(
@@ -363,7 +388,7 @@ def encode_prompt(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -524,7 +549,7 @@ def _clean_caption(self, caption):
# &
caption = re.sub(r"&", "", caption)
- # ip adresses:
+ # ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -683,11 +708,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 4.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
@@ -695,15 +720,15 @@ def __call__(
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -917,9 +942,15 @@ def __call__(
image = latents
else:
latents = latents.to(self.vae.dtype)
+ torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
+ oom_error = (
+ torch.OutOfMemoryError
+ if is_torch_version(">=", "2.5.0")
+ else torch_accelerator_module.OutOfMemoryError
+ )
try:
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
- except torch.cuda.OutOfMemoryError as e:
+ except oom_error as e:
warnings.warn(
f"{e}. \n"
f"Try to use VAE tiling for large images. For example: \n"
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_sd.py
index fc7dc3a83f27..ea64f8be2c50 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -57,7 +57,7 @@
>>> from diffusers import AutoPipelineForText2Image
>>> pipe = AutoPipelineForText2Image.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, enable_pag=True
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, enable_pag=True
... )
>>> pipe = pipe.to("cuda")
@@ -72,7 +72,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -190,8 +190,8 @@ class StableDiffusionPAGPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -272,8 +272,8 @@ def __init__(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -573,7 +573,7 @@ def run_safety_checker(self, image, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -724,7 +724,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -802,8 +802,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -833,7 +833,7 @@ def __call__(
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
@@ -1027,7 +1027,7 @@ def __call__(
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
index fde3e500a573..941b675099b9 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Stability AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -237,7 +237,7 @@ def _get_t5_prompt_embeds(
return torch.zeros(
(
batch_size * num_images_per_prompt,
- self.tokenizer_max_length,
+ max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -326,7 +326,7 @@ def _get_clip_prompt_embeds(
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds
@@ -663,7 +663,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -738,11 +738,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -761,7 +761,7 @@ def __call__(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
index d64582a26f7a..f40dd52fc244 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Stability AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -253,7 +253,7 @@ def _get_t5_prompt_embeds(
return torch.zeros(
(
batch_size * num_images_per_prompt,
- self.tokenizer_max_length,
+ max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -342,7 +342,7 @@ def _get_clip_prompt_embeds(
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds
@@ -714,7 +714,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -799,11 +799,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -822,7 +822,7 @@ def __call__(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py
index d3a015e569c1..de13be9c4d22 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -438,7 +438,7 @@ def decode_latents(self, latents, decode_chunk_size: int = 16):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -520,7 +520,7 @@ def check_inputs(
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
- # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://arxiv.org/abs/2310.15169)
+ # If FreeNoise is enabled, generate latents as described in Equation (7) of [FreeNoise](https://huggingface.co/papers/2310.15169)
if self.free_noise_enabled:
latents = self._prepare_latents_free_noise(
batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents
@@ -558,7 +558,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -624,8 +624,8 @@ def __call__(
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py
index d91c02b607a3..8351112ce409 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -61,7 +61,7 @@
>>> from diffusers.utils import load_image
>>> pipe = AutoPipelineForImage2Image.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5",
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5",
... torch_dtype=torch.float16,
... enable_pag=True,
... )
@@ -185,8 +185,8 @@ class StableDiffusionPAGImg2ImgPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -267,8 +267,8 @@ def __init__(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -568,7 +568,7 @@ def run_safety_checker(self, image, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -761,7 +761,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -847,8 +847,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
index 33abfb0be89f..6b1b294e10f5 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -58,7 +58,7 @@
>>> from diffusers import AutoPipelineForInpainting
>>> pipe = AutoPipelineForInpainting.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, enable_pag=True
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, enable_pag=True
... )
>>> pipe = pipe.to("cuda")
>>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
@@ -99,7 +99,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -217,8 +217,8 @@ class StableDiffusionPAGInpaintPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -299,8 +299,8 @@ def __init__(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -603,7 +603,7 @@ def run_safety_checker(self, image, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -683,7 +683,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -691,7 +691,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
@@ -889,7 +889,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -972,8 +972,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -1003,7 +1003,7 @@ def __call__(
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
@@ -1183,7 +1183,7 @@ def __call__(
# 8. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
- # default case for runwayml/stable-diffusion-inpainting
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
@@ -1191,7 +1191,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
@@ -1294,7 +1294,7 @@ def __call__(
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py
index 856f6a3e789e..a69f06536a55 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -32,15 +32,11 @@
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
-from ...models.attention_processor import (
- AttnProcessor2_0,
- FusedAttnProcessor2_0,
- XFormersAttnProcessor,
-)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_invisible_watermark_available,
is_torch_xla_available,
logging,
@@ -91,7 +87,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -607,7 +603,7 @@ def prepare_ip_adapter_image_embeds(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -762,22 +758,12 @@ def _get_add_time_ids(
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- FusedAttnProcessor2_0,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
@@ -823,7 +809,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -925,11 +911,11 @@ def __call__(
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
guidance_scale (`float`, *optional*, defaults to 5.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -940,15 +926,15 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -981,9 +967,10 @@ def __call__(
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
- [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
- Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -1266,7 +1253,7 @@ def __call__(
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
index 93dcca0ea9d6..416d9e5677b4 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -34,14 +34,11 @@
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
-from ...models.attention_processor import (
- AttnProcessor2_0,
- XFormersAttnProcessor,
-)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_invisible_watermark_available,
is_torch_xla_available,
logging,
@@ -49,7 +46,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import randn_tensor
+from ...utils.torch_utils import empty_device_cache, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from .pag_utils import PAGMixin
@@ -95,7 +92,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -553,7 +550,7 @@ def encode_prompt(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -716,7 +713,7 @@ def prepare_latents(
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
- torch.cuda.empty_cache()
+ empty_device_cache()
image = image.to(device=device, dtype=dtype)
@@ -910,21 +907,12 @@ def _get_add_time_ids(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
@@ -970,7 +958,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -1088,11 +1076,11 @@ def __call__(
forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refine Image
Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -1103,15 +1091,15 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -1144,9 +1132,10 @@ def __call__(
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
- [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
- Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -1461,7 +1450,7 @@ def denoising_value_valid(dnv):
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
index fdf3df2f4d6a..6be341e07b1a 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -34,14 +34,11 @@
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
-from ...models.attention_processor import (
- AttnProcessor2_0,
- XFormersAttnProcessor,
-)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_invisible_watermark_available,
is_torch_xla_available,
logging,
@@ -108,7 +105,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -643,7 +640,7 @@ def encode_prompt(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -737,7 +734,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -745,7 +742,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
@@ -1001,21 +998,12 @@ def _get_add_time_ids(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
@@ -1061,7 +1049,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -1208,11 +1196,11 @@ def __call__(
forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -1243,15 +1231,15 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1501,7 +1489,7 @@ def denoising_value_valid(dnv):
# 8. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
- # default case for runwayml/stable-diffusion-inpainting
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
@@ -1509,7 +1497,7 @@ def denoising_value_valid(dnv):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
@@ -1673,7 +1661,7 @@ def denoising_value_valid(dnv):
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/paint_by_example/image_encoder.py b/src/diffusers/pipelines/paint_by_example/image_encoder.py
index 2fd0338b1f91..74c575ed8653 100644
--- a/src/diffusers/pipelines/paint_by_example/image_encoder.py
+++ b/src/diffusers/pipelines/paint_by_example/image_encoder.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
index 55a9f47145a2..c09992befbcb 100644
--- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
+++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,7 +25,7 @@
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import deprecate, is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from .image_encoder import PaintByExampleImageEncoder
@@ -155,13 +155,10 @@ def prepare_mask_and_masked_image(image, mask):
return mask, masked_image
-class PaintByExamplePipeline(DiffusionPipeline, StableDiffusionMixin):
+class PaintByExamplePipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
+ _last_supported_version = "0.33.1"
r"""
-
-
- 🧪 This is an experimental feature!
-
-
+ > [!WARNING] > 🧪 This is an experimental feature!
Pipeline for image-guided image inpainting using Stable Diffusion.
@@ -182,8 +179,8 @@ class PaintByExamplePipeline(DiffusionPipeline, StableDiffusionMixin):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
@@ -239,7 +236,7 @@ def run_safety_checker(self, image, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -447,8 +444,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -521,7 +518,7 @@ def __call__(
batch_size = image.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
@@ -575,7 +572,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py
index df8499ab900a..dfc6e83fbd7c 100644
--- a/src/diffusers/pipelines/pia/pipeline_pia.py
+++ b/src/diffusers/pipelines/pia/pipeline_pia.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -46,7 +46,7 @@
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
-from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
if is_torch_xla_available():
@@ -132,6 +132,7 @@ class PIAPipelineOutput(BaseOutput):
class PIAPipeline(
+ DeprecatedPipelineMixin,
DiffusionPipeline,
StableDiffusionMixin,
TextualInversionLoaderMixin,
@@ -140,6 +141,7 @@ class PIAPipeline(
FromSingleFileMixin,
FreeInitMixin,
):
+ _last_supported_version = "0.33.1"
r"""
Pipeline for text-to-video generation.
@@ -432,7 +434,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -653,7 +655,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -723,8 +725,8 @@ def __call__(
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
diff --git a/src/diffusers/pipelines/pipeline_flax_utils.py b/src/diffusers/pipelines/pipeline_flax_utils.py
index ec2f82bcf742..2724c764c771 100644
--- a/src/diffusers/pipelines/pipeline_flax_utils.py
+++ b/src/diffusers/pipelines/pipeline_flax_utils.py
@@ -248,9 +248,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
pretrained pipeline hosted on the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
using [`~FlaxDiffusionPipeline.save_pretrained`].
- dtype (`str` or `jnp.dtype`, *optional*):
- Override the default `jnp.dtype` and load the model under this dtype. If `"auto"`, the dtype is
- automatically derived from the model's weights.
+ dtype (`jnp.dtype`, *optional*):
+ Override the default `jnp.dtype` and load the model under this dtype.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@@ -277,12 +276,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
Can be used to overwrite load and saveable variables (the pipeline components) of the specific pipeline
class. The overwritten components are passed directly to the pipelines `__init__` method.
-
-
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
- `huggingface-cli login`.
-
-
+ > [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
+ with `hf > auth login`.
Examples:
@@ -313,6 +308,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
>>> dpm_params["scheduler"] = dpmpp_state
```
"""
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
cache_dir = kwargs.pop("cache_dir", None)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
@@ -469,7 +469,7 @@ def load_module(name, value):
class_obj = import_flax_or_no_model(pipeline_module, class_name)
importable_classes = ALL_IMPORTABLE_CLASSES
- class_candidates = {c: class_obj for c in importable_classes.keys()}
+ class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index f5b430564ca1..8868e942ce3d 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -19,12 +19,12 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
+import httpx
import requests
import torch
from huggingface_hub import DDUFEntry, ModelCard, model_info, snapshot_download
-from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
+from huggingface_hub.utils import HfHubHTTPError, OfflineModeIsEnabled, validate_hf_hub_args
from packaging import version
-from requests.exceptions import HTTPError
from .. import __version__
from ..utils import (
@@ -33,6 +33,7 @@
ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
+ _maybe_remap_transformers_class,
deprecate,
get_class_from_dynamic_module,
is_accelerate_available,
@@ -48,10 +49,12 @@
if is_transformers_available():
import transformers
from transformers import PreTrainedModel, PreTrainedTokenizerBase
- from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
+ if is_transformers_version("<=", "4.56.2"):
+ from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
+
if is_accelerate_available():
import accelerate
from accelerate import dispatch_model
@@ -73,6 +76,7 @@
"SchedulerMixin": ["save_pretrained", "from_pretrained"],
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
+ "BaseGuidance": ["save_pretrained", "from_pretrained"],
},
"transformers": {
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
@@ -92,7 +96,7 @@
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
-def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool:
+def is_safetensors_compatible(filenames, passed_components=None, folder_names=None, variant=None) -> bool:
"""
Checking for safetensors compatibility:
- The model is safetensors compatible only if there is a safetensors file for each model component present in
@@ -103,6 +107,33 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
extension is replaced with ".safetensors"
"""
+ weight_names = [
+ WEIGHTS_NAME,
+ SAFETENSORS_WEIGHTS_NAME,
+ FLAX_WEIGHTS_NAME,
+ ONNX_WEIGHTS_NAME,
+ ONNX_EXTERNAL_WEIGHTS_NAME,
+ ]
+
+ if is_transformers_available():
+ weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
+ if is_transformers_version("<=", "4.56.2"):
+ weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
+
+ # model_pytorch, diffusion_model_pytorch, ...
+ weight_prefixes = [w.split(".")[0] for w in weight_names]
+ # .bin, .safetensors, ...
+ weight_suffixs = [w.split(".")[-1] for w in weight_names]
+ # -00001-of-00002
+ transformers_index_format = r"\d{5}-of-\d{5}"
+ # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
+ variant_file_re = re.compile(
+ rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
+ )
+ non_variant_file_re = re.compile(
+ rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
+ )
+
passed_components = passed_components or []
if folder_names:
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
@@ -121,15 +152,29 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
components[component].append(component_filename)
# If there are no component folders check the main directory for safetensors files
+ filtered_filenames = set()
if not components:
- return any(".safetensors" in filename for filename in filenames)
+ if variant is not None:
+ filtered_filenames = filter_with_regex(filenames, variant_file_re)
+
+ # If no variant filenames exist check if non-variant files are available
+ if not filtered_filenames:
+ filtered_filenames = filter_with_regex(filenames, non_variant_file_re)
+ return any(".safetensors" in filename for filename in filtered_filenames)
# iterate over all files of a component
# check if safetensor files exist for that component
- # if variant is provided check if the variant of the safetensors exists
for component, component_filenames in components.items():
matches = []
- for component_filename in component_filenames:
+ filtered_component_filenames = set()
+ # if variant is provided check if the variant of the safetensors exists
+ if variant is not None:
+ filtered_component_filenames = filter_with_regex(component_filenames, variant_file_re)
+
+ # if variant safetensor files do not exist check for non-variants
+ if not filtered_component_filenames:
+ filtered_component_filenames = filter_with_regex(component_filenames, non_variant_file_re)
+ for component_filename in filtered_component_filenames:
filename, extension = os.path.splitext(component_filename)
match_exists = extension == ".safetensors"
@@ -152,13 +197,19 @@ def filter_model_files(filenames):
]
if is_transformers_available():
- weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
+ weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
+ if is_transformers_version("<=", "4.56.2"):
+ weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)]
+def filter_with_regex(filenames, pattern_re):
+ return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
+
+
def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
weight_names = [
WEIGHTS_NAME,
@@ -169,7 +220,9 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
]
if is_transformers_available():
- weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
+ weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
+ if is_transformers_version("<=", "4.56.2"):
+ weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
# model_pytorch, diffusion_model_pytorch, ...
weight_prefixes = [w.split(".")[0] for w in weight_names]
@@ -207,9 +260,6 @@ def filter_for_compatible_extensions(filenames, ignore_patterns=None):
# interested in the extension name
return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)}
- def filter_with_regex(filenames, pattern_re):
- return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
-
# Group files by component
components = {}
for filename in filenames:
@@ -308,6 +358,11 @@ def maybe_raise_or_warn(
"""Simple helper method to raise or warn in case incorrect module has been passed"""
if not is_pipeline_module:
library = importlib.import_module(library_name)
+
+ # Handle deprecated Transformers classes
+ if library_name == "transformers":
+ class_name = _maybe_remap_transformers_class(class_name) or class_name
+
class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
@@ -323,9 +378,7 @@ def maybe_raise_or_warn(
model_cls = unwrapped_sub_model.__class__
if not issubclass(model_cls, expected_class_obj):
- raise ValueError(
- f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}"
- )
+ raise ValueError(f"{passed_class_obj[name]} is of type: {model_cls}, but should be {expected_class_obj}")
else:
logger.warning(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
@@ -333,27 +386,52 @@ def maybe_raise_or_warn(
)
+# a simpler version of get_class_obj_and_candidates, it won't work with custom code
+def simple_get_class_obj(library_name, class_name):
+ from diffusers import pipelines
+
+ is_pipeline_module = hasattr(pipelines, library_name)
+
+ if is_pipeline_module:
+ pipeline_module = getattr(pipelines, library_name)
+ class_obj = getattr(pipeline_module, class_name)
+ else:
+ library = importlib.import_module(library_name)
+
+ # Handle deprecated Transformers classes
+ if library_name == "transformers":
+ class_name = _maybe_remap_transformers_class(class_name) or class_name
+
+ class_obj = getattr(library, class_name)
+
+ return class_obj
+
+
def get_class_obj_and_candidates(
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
- component_folder = os.path.join(cache_dir, component_name)
+ component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None
if is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
- class_candidates = {c: class_obj for c in importable_classes.keys()}
- elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
+ class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
+ elif component_folder and os.path.isfile(os.path.join(component_folder, library_name + ".py")):
# load custom component
class_obj = get_class_from_dynamic_module(
component_folder, module_file=library_name + ".py", class_name=class_name
)
- class_candidates = {c: class_obj for c in importable_classes.keys()}
+ class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
+ # Handle deprecated Transformers classes
+ if library_name == "transformers":
+ class_name = _maybe_remap_transformers_class(class_name) or class_name
+
class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
@@ -414,7 +492,7 @@ def _get_pipeline_class(
revision=revision,
)
- if class_obj.__name__ != "DiffusionPipeline":
+ if class_obj.__name__ != "DiffusionPipeline" and class_obj.__name__ != "ModularPipeline":
return class_obj
diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
@@ -559,6 +637,9 @@ def _assign_components_to_devices(
def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
+ # TODO: separate out different device_map methods when it gets to it.
+ if device_map != "balanced":
+ return device_map
# To avoid circular import problem.
from diffusers import pipelines
@@ -677,8 +758,10 @@ def load_sub_model(
use_safetensors: bool,
dduf_entries: Optional[Dict[str, DDUFEntry]],
provider_options: Any,
+ quantization_config: Optional[Any] = None,
):
"""Helper method to load the module `name` from `library_name` and `class_name`"""
+ from ..quantizers import PipelineQuantizationConfig
# retrieve class candidates
@@ -771,6 +854,20 @@ def load_sub_model(
else:
loading_kwargs["low_cpu_mem_usage"] = False
+ if is_transformers_model and is_transformers_version(">=", "4.57.0"):
+ loading_kwargs.pop("offload_state_dict")
+
+ if (
+ quantization_config is not None
+ and isinstance(quantization_config, PipelineQuantizationConfig)
+ and issubclass(class_obj, torch.nn.Module)
+ ):
+ model_quant_config = quantization_config._resolve_quant_config(
+ is_diffusers=is_diffusers_model, module_name=name
+ )
+ if model_quant_config is not None:
+ loading_kwargs["quantization_config"] = model_quant_config
+
# check if the module is in a subdirectory
if dduf_entries:
loading_kwargs["dduf_entries"] = dduf_entries
@@ -785,6 +882,9 @@ def load_sub_model(
# remove hooks
remove_hook_from_module(loaded_sub_model, recurse=True)
needs_offloading_to_cpu = device_map[""] == "cpu"
+ skip_keys = None
+ if hasattr(loaded_sub_model, "_skip_keys") and loaded_sub_model._skip_keys is not None:
+ skip_keys = loaded_sub_model._skip_keys
if needs_offloading_to_cpu:
dispatch_model(
@@ -793,9 +893,10 @@ def load_sub_model(
device_map=device_map,
force_hooks=True,
main_device=0,
+ skip_keys=skip_keys,
)
else:
- dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)
+ dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True, skip_keys=skip_keys)
return loaded_sub_model
@@ -841,7 +942,10 @@ def _fetch_class_library_tuple(module):
library = not_compiled_module.__module__
# retrieve class_name
- class_name = not_compiled_module.__class__.__name__
+ if isinstance(not_compiled_module, type):
+ class_name = not_compiled_module.__name__
+ else:
+ class_name = not_compiled_module.__class__.__name__
return (library, class_name)
@@ -986,7 +1090,7 @@ def _get_ignore_patterns(
use_safetensors
and not allow_pickle
and not is_safetensors_compatible(
- model_filenames, passed_components=passed_components, folder_names=model_folder_names
+ model_filenames, passed_components=passed_components, folder_names=model_folder_names, variant=variant
)
):
raise EnvironmentError(
@@ -997,7 +1101,7 @@ def _get_ignore_patterns(
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
elif use_safetensors and is_safetensors_compatible(
- model_filenames, passed_components=passed_components, folder_names=model_folder_names
+ model_filenames, passed_components=passed_components, folder_names=model_folder_names, variant=variant
):
ignore_patterns = ["*.bin", "*.msgpack"]
@@ -1029,7 +1133,7 @@ def _download_dduf_file(
if not local_files_only:
try:
info = model_info(pretrained_model_name, token=token, revision=revision)
- except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e:
+ except (HfHubHTTPError, OfflineModeIsEnabled, requests.ConnectionError, httpx.NetworkError) as e:
logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
local_files_only = True
model_info_call_error = e # save error to reraise it if model is not cached locally
@@ -1080,3 +1184,26 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict):
break
if has_transformers_component and not is_transformers_version(">", "4.47.1"):
raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")
+
+
+def _maybe_warn_for_wrong_component_in_quant_config(pipe_init_dict, quant_config):
+ if quant_config is None:
+ return
+
+ actual_pipe_components = set(pipe_init_dict.keys())
+ missing = ""
+ quant_components = None
+ if getattr(quant_config, "components_to_quantize", None) is not None:
+ quant_components = set(quant_config.components_to_quantize)
+ elif getattr(quant_config, "quant_mapping", None) is not None and isinstance(quant_config.quant_mapping, dict):
+ quant_components = set(quant_config.quant_mapping.keys())
+
+ if quant_components and not quant_components.issubset(actual_pipe_components):
+ missing = quant_components - actual_pipe_components
+
+ if missing:
+ logger.warning(
+ f"The following components in the quantization config {missing} will be ignored "
+ "as they do not belong to the underlying pipeline. Acceptable values for the pipeline "
+ f"components are: {', '.join(actual_pipe_components)}."
+ )
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 66b56740ef13..392d5fb3feb4 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -23,6 +23,7 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
+import httpx
import numpy as np
import PIL.Image
import requests
@@ -36,9 +37,8 @@
read_dduf_file,
snapshot_download,
)
-from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
+from huggingface_hub.utils import HfHubHTTPError, OfflineModeIsEnabled, validate_hf_hub_args
from packaging import version
-from requests.exceptions import HTTPError
from tqdm.auto import tqdm
from typing_extensions import Self
@@ -47,6 +47,7 @@
from ..models import AutoencoderKL
from ..models.attention_processor import FusedAttnProcessor2_0
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
+from ..quantizers import PipelineQuantizationConfig
from ..quantizers.bitsandbytes.utils import _check_bnb_status
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from ..utils import (
@@ -56,8 +57,10 @@
PushToHubMixin,
_get_detailed_type,
_is_valid_type,
+ deprecate,
is_accelerate_available,
is_accelerate_version,
+ is_hpu_available,
is_torch_npu_available,
is_torch_version,
is_transformers_version,
@@ -65,7 +68,7 @@
numpy_to_pil,
)
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
-from ..utils.torch_utils import is_compiled_module
+from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module
if is_torch_npu_available():
@@ -86,6 +89,7 @@
_identify_model_variants,
_maybe_raise_error_for_incorrect_transformers,
_maybe_raise_warning_for_inpainting,
+ _maybe_warn_for_wrong_component_in_quant_config,
_resolve_custom_pipeline_and_cls,
_unwrap_model,
_update_init_kwargs_with_connected_pipeline,
@@ -105,7 +109,7 @@
for library in LOADABLE_CLASSES:
LIBRARIES.append(library)
-SUPPORTED_DEVICE_MAP = ["balanced"]
+SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()]
logger = logging.get_logger(__name__)
@@ -137,6 +141,43 @@ class AudioPipelineOutput(BaseOutput):
audios: np.ndarray
+class DeprecatedPipelineMixin:
+ """
+ A mixin that can be used to mark a pipeline as deprecated.
+
+ Pipelines inheriting from this mixin will raise a warning when instantiated, indicating that they are deprecated
+ and won't receive updates past the specified version. Tests will be skipped for pipelines that inherit from this
+ mixin.
+
+ Example usage:
+ ```python
+ class MyDeprecatedPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
+ _last_supported_version = "0.20.0"
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ ```
+ """
+
+ # Override this in the inheriting class to specify the last version that will support this pipeline
+ _last_supported_version = None
+
+ def __init__(self, *args, **kwargs):
+ # Get the class name for the warning message
+ class_name = self.__class__.__name__
+
+ # Get the last supported version or use the current version if not specified
+ version_info = getattr(self.__class__, "_last_supported_version", __version__)
+
+ # Raise a warning that this pipeline is deprecated
+ logger.warning(
+ f"The {class_name} has been deprecated and will not receive bug fixes or feature updates after Diffusers version {version_info}. "
+ )
+
+ # Call the parent class's __init__ method
+ super().__init__(*args, **kwargs)
+
+
class DiffusionPipeline(ConfigMixin, PushToHubMixin):
r"""
Base class for all pipelines.
@@ -331,12 +372,8 @@ def to(self, *args, **kwargs) -> Self:
Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
arguments of `self.to(*args, **kwargs).`
-
-
- If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise,
- the returned pipeline is a copy of self with the desired torch.dtype and torch.device.
-
-
+ > [!TIP] > If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is.
+ Otherwise, > the returned pipeline is a copy of self with the desired torch.dtype and torch.device.
Here are the ways to call `to`:
@@ -404,6 +441,11 @@ def module_is_sequentially_offloaded(module):
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
return False
+ _, _, is_loaded_in_8bit_bnb = _check_bnb_status(module)
+
+ if is_loaded_in_8bit_bnb:
+ return False
+
return hasattr(module, "_hf_hook") and (
isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
or hasattr(module._hf_hook, "hooks")
@@ -445,6 +487,27 @@ def module_is_offloaded(module):
f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
)
+ # Enable generic support for Intel Gaudi accelerator using GPU/HPU migration
+ if device_type == "hpu" and kwargs.pop("hpu_migration", True) and is_hpu_available():
+ os.environ["PT_HPU_GPU_MIGRATION"] = "1"
+ logger.debug("Environment variable set: PT_HPU_GPU_MIGRATION=1")
+
+ import habana_frameworks.torch # noqa: F401
+
+ # HPU hardware check
+ if not (hasattr(torch, "hpu") and torch.hpu.is_available()):
+ raise ValueError("You are trying to call `.to('hpu')` but HPU device is unavailable.")
+
+ os.environ["PT_HPU_MAX_COMPOUND_OP_SIZE"] = "1"
+ logger.debug("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1")
+
+ if dtype in (torch.bfloat16, None) and kwargs.pop("sdp_on_bf16", True):
+ if hasattr(torch._C, "_set_math_sdp_allow_fp16_bf16_reduction"):
+ torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
+ logger.warning(
+ "Enabled SDP with BF16 precision on HPU. To disable, please use `.to('hpu', sdp_on_bf16=False)`"
+ )
+
module_names, _ = self._get_signature_keys(self)
modules = [getattr(self, n, None) for n in module_names]
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
@@ -552,19 +615,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
saved using
[`~DiffusionPipeline.save_pretrained`].
- A path to a *directory* (for example `./my_pipeline_directory/`) containing a dduf file
- torch_dtype (`str` or `torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*):
- Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
- dtype is automatically derived from the model's weights. To load submodels with different dtype pass a
- `dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`). Set the default dtype for
- unspecified components with `default` (for example `{'transformer': torch.bfloat16, 'default':
- torch.float16}`). If a component is not specified and no default is set, `torch.float32` is used.
+ torch_dtype (`torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*):
+ Override the default `torch.dtype` and load the model with another dtype. To load submodels with
+ different dtype pass a `dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`).
+ Set the default dtype for unspecified components with `default` (for example `{'transformer':
+ torch.bfloat16, 'default': torch.float16}`). If a component is not specified and no default is set,
+ `torch.float32` is used.
custom_pipeline (`str`, *optional*):
-
-
- 🧪 This is an experimental feature and may change in the future.
-
-
+ > [!WARNING] > 🧪 This is an experimental feature and may change in the future.
Can be either:
@@ -611,14 +670,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
information.
- device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
- A map that specifies where each submodule should go. It doesn’t need to be defined for each
- parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
- same device.
-
- Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
- more information about each option see [designing a device
- map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ device_map (`str`, *optional*):
+ Strategy that dictates how the different components of a pipeline should be placed on available
+ devices. Currently, only "balanced" `device_map` is supported. Check out
+ [this](https://huggingface.co/docs/diffusers/main/en/tutorials/inference_with_big_models#device-placement)
+ to know more.
max_memory (`Dict`, *optional*):
A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
each GPU and the available CPU RAM if unset.
@@ -652,12 +708,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
dduf_file(`str`, *optional*):
Load weights from the specified dduf file.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `huggingface-cli login`.
-
-
+ > [!TIP] > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in
+ with `hf > auth login`.
Examples:
@@ -705,6 +757,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
use_safetensors = kwargs.pop("use_safetensors", None)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
+ quantization_config = kwargs.pop("quantization_config", None)
if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
@@ -721,6 +774,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
" install accelerate\n```\n."
)
+ if quantization_config is not None and not isinstance(quantization_config, PipelineQuantizationConfig):
+ raise ValueError("`quantization_config` must be an instance of `PipelineQuantizationConfig`.")
+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
@@ -925,14 +981,18 @@ def load_module(name, value):
# 7. Load each module in the pipeline
current_device_map = None
+ _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
# 7.1 device_map shenanigans
- if final_device_map is not None and len(final_device_map) > 0:
- component_device = final_device_map.get(name, None)
- if component_device is not None:
- current_device_map = {"": component_device}
- else:
- current_device_map = None
+ if final_device_map is not None:
+ if isinstance(final_device_map, dict) and len(final_device_map) > 0:
+ component_device = final_device_map.get(name, None)
+ if component_device is not None:
+ current_device_map = {"": component_device}
+ else:
+ current_device_map = None
+ elif isinstance(final_device_map, str):
+ current_device_map = final_device_map
# 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
@@ -981,6 +1041,7 @@ def load_module(name, value):
use_safetensors=use_safetensors,
dduf_entries=dduf_entries,
provider_options=provider_options,
+ quantization_config=quantization_config,
)
logger.info(
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
@@ -1034,6 +1095,8 @@ def load_module(name, value):
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
if device_map is not None:
setattr(model, "hf_device_map", final_device_map)
+ if quantization_config is not None:
+ setattr(model, "quantization_config", quantization_config)
return model
@property
@@ -1084,19 +1147,20 @@ def remove_all_hooks(self):
accelerate.hooks.remove_hook_from_module(model, recurse=True)
self._all_hooks = []
- def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
- to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
- method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
- `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the accelerator when its
+ `forward` method is called, and the model remains in accelerator until the next model runs. Memory savings are
+ lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution
+ of the `unet`.
Arguments:
gpu_id (`int`, *optional*):
The ID of the accelerator that shall be used in inference. If not specified, it will default to 0.
- device (`torch.Device` or `str`, *optional*, defaults to "cuda"):
+ device (`torch.Device` or `str`, *optional*, defaults to None):
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
- default to "cuda".
+ automatically detect the available accelerator and use.
"""
self._maybe_raise_error_if_group_offload_active(raise_error=True)
@@ -1118,6 +1182,11 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
self.remove_all_hooks()
+ if device is None:
+ device = get_device()
+ if device == "cpu":
+ raise RuntimeError("`enable_model_cpu_offload` requires accelerator, but not found")
+
torch_device = torch.device(device)
device_index = torch_device.index
@@ -1135,9 +1204,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
self._offload_device = device
self.to("cpu", silence_dtype_warnings=True)
- device_mod = getattr(torch, device.type, None)
- if hasattr(device_mod, "empty_cache") and device_mod.is_available():
- device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+ empty_device_cache(device.type)
all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
@@ -1196,20 +1263,20 @@ def maybe_free_model_hooks(self):
# make sure the model is in the same state as before calling it
self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
- def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using 🤗 Accelerate, significantly reducing memory usage. When called, the state
dicts of all `torch.nn.Module` components (except those in `self._exclude_from_cpu_offload`) are saved to CPU
- and then moved to `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward`
- method called. Offloading happens on a submodule basis. Memory savings are higher than with
+ and then moved to `torch.device('meta')` and loaded to accelerator only when their specific submodule has its
+ `forward` method called. Offloading happens on a submodule basis. Memory savings are higher than with
`enable_model_cpu_offload`, but performance is lower.
Arguments:
gpu_id (`int`, *optional*):
The ID of the accelerator that shall be used in inference. If not specified, it will default to 0.
- device (`torch.Device` or `str`, *optional*, defaults to "cuda"):
+ device (`torch.Device` or `str`, *optional*, defaults to None):
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
- default to "cuda".
+ automatically detect the available accelerator and use.
"""
self._maybe_raise_error_if_group_offload_active(raise_error=True)
@@ -1225,6 +1292,11 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
)
+ if device is None:
+ device = get_device()
+ if device == "cpu":
+ raise RuntimeError("`enable_sequential_cpu_offload` requires accelerator, but not found")
+
torch_device = torch.device(device)
device_index = torch_device.index
@@ -1242,10 +1314,9 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
self._offload_device = device
if self.device.type != "cpu":
+ orig_device_type = self.device.type
self.to("cpu", silence_dtype_warnings=True)
- device_mod = getattr(torch, self.device.type, None)
- if hasattr(device_mod, "empty_cache") and device_mod.is_available():
- device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+ empty_device_cache(orig_device_type)
for name, model in self.components.items():
if not isinstance(model, torch.nn.Module):
@@ -1259,6 +1330,133 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
offload_buffers = len(model._parameters) > 0
cpu_offload(model, device, offload_buffers=offload_buffers)
+ def enable_group_offload(
+ self,
+ onload_device: torch.device,
+ offload_device: torch.device = torch.device("cpu"),
+ offload_type: str = "block_level",
+ num_blocks_per_group: Optional[int] = None,
+ non_blocking: bool = False,
+ use_stream: bool = False,
+ record_stream: bool = False,
+ low_cpu_mem_usage=False,
+ offload_to_disk_path: Optional[str] = None,
+ exclude_modules: Optional[Union[str, List[str]]] = None,
+ ) -> None:
+ r"""
+ Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is,
+ and where it is beneficial, we need to first provide some context on how other supported offloading methods
+ work.
+
+ Typically, offloading is done at two levels:
+ - Module-level: In Diffusers, this can be enabled using the `ModelMixin::enable_model_cpu_offload()` method. It
+ works by offloading each component of a pipeline to the CPU for storage, and onloading to the accelerator
+ device when needed for computation. This method is more memory-efficient than keeping all components on the
+ accelerator, but the memory requirements are still quite high. For this method to work, one needs memory
+ equivalent to size of the model in runtime dtype + size of largest intermediate activation tensors to be able
+ to complete the forward pass.
+ - Leaf-level: In Diffusers, this can be enabled using the `ModelMixin::enable_sequential_cpu_offload()` method.
+ It
+ works by offloading the lowest leaf-level parameters of the computation graph to the CPU for storage, and
+ onloading only the leafs to the accelerator device for computation. This uses the lowest amount of accelerator
+ memory, but can be slower due to the excessive number of device synchronizations.
+
+ Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers,
+ (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses lower memory than module-level
+ offloading. It is also faster than leaf-level/sequential offloading, as the number of device synchronizations
+ is reduced.
+
+ Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability
+ to overlap data transfer and computation to reduce the overall execution time compared to sequential
+ offloading. This is enabled using layer prefetching with streams, i.e., the layer that is to be executed next
+ starts onloading to the accelerator device while the current layer is being executed - this increases the
+ memory requirements slightly. Note that this implementation also supports leaf-level offloading but can be made
+ much faster when using streams.
+
+ Args:
+ onload_device (`torch.device`):
+ The device to which the group of modules are onloaded.
+ offload_device (`torch.device`, defaults to `torch.device("cpu")`):
+ The device to which the group of modules are offloaded. This should typically be the CPU. Default is
+ CPU.
+ offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
+ The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
+ "block_level".
+ offload_to_disk_path (`str`, *optional*, defaults to `None`):
+ The path to the directory where parameters will be offloaded. Setting this option can be useful in
+ limited RAM environment settings where a reasonable speed-memory trade-off is desired.
+ num_blocks_per_group (`int`, *optional*):
+ The number of blocks per group when using offload_type="block_level". This is required when using
+ offload_type="block_level".
+ non_blocking (`bool`, defaults to `False`):
+ If True, offloading and onloading is done with non-blocking data transfer.
+ use_stream (`bool`, defaults to `False`):
+ If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
+ overlapping computation and data transfer.
+ record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
+ as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to
+ the [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html)
+ more details.
+ low_cpu_mem_usage (`bool`, defaults to `False`):
+ If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them.
+ This option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be
+ useful when the CPU memory is a bottleneck but may counteract the benefits of using streams.
+ exclude_modules (`Union[str, List[str]]`, defaults to `None`): List of modules to exclude from offloading.
+
+ Example:
+ ```python
+ >>> from diffusers import DiffusionPipeline
+ >>> import torch
+
+ >>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
+
+ >>> pipe.enable_group_offload(
+ ... onload_device=torch.device("cuda"),
+ ... offload_device=torch.device("cpu"),
+ ... offload_type="leaf_level",
+ ... use_stream=True,
+ ... )
+ >>> image = pipe("a beautiful sunset").images[0]
+ ```
+ """
+ from ..hooks import apply_group_offloading
+
+ if isinstance(exclude_modules, str):
+ exclude_modules = [exclude_modules]
+ elif exclude_modules is None:
+ exclude_modules = []
+
+ unknown = set(exclude_modules) - self.components.keys()
+ if unknown:
+ logger.info(
+ f"The following modules are not present in pipeline: {', '.join(unknown)}. Ignore if this is expected."
+ )
+
+ group_offload_kwargs = {
+ "onload_device": onload_device,
+ "offload_device": offload_device,
+ "offload_type": offload_type,
+ "num_blocks_per_group": num_blocks_per_group,
+ "non_blocking": non_blocking,
+ "use_stream": use_stream,
+ "record_stream": record_stream,
+ "low_cpu_mem_usage": low_cpu_mem_usage,
+ "offload_to_disk_path": offload_to_disk_path,
+ }
+ for name, component in self.components.items():
+ if name not in exclude_modules and isinstance(component, torch.nn.Module):
+ if hasattr(component, "enable_group_offload"):
+ component.enable_group_offload(**group_offload_kwargs)
+ else:
+ apply_group_offloading(module=component, **group_offload_kwargs)
+
+ if exclude_modules:
+ for module_name in exclude_modules:
+ module = getattr(self, module_name, None)
+ if module is not None and isinstance(module, torch.nn.Module):
+ module.to(onload_device)
+ logger.debug(f"Placed `{module_name}` on {onload_device} device as it was in `exclude_modules`.")
+
def reset_device_map(self):
r"""
Resets the device maps (if any) to None.
@@ -1298,11 +1496,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
- A path to a *directory* (`./my_pipeline_directory/`) containing a custom pipeline. The directory
must contain a file called `pipeline.py` that defines the custom pipeline.
-
-
- 🧪 This is an experimental feature and may change in the future.
-
-
+ > [!WARNING] > 🧪 This is an experimental feature and may change in the future.
For more information on how to load and create custom pipelines, take a look at [How to contribute a
community pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/contribute_pipeline).
@@ -1356,12 +1550,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
`os.PathLike`:
A path to the downloaded pipeline.
-
-
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
- `huggingface-cli login`.
-
-
+ > [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
+ with `hf > auth login
"""
cache_dir = kwargs.pop("cache_dir", None)
@@ -1406,7 +1596,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
if not local_files_only:
try:
info = model_info(pretrained_model_name, token=token, revision=revision)
- except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e:
+ except (HfHubHTTPError, OfflineModeIsEnabled, requests.ConnectionError, httpx.NetworkError) as e:
logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
local_files_only = True
model_info_call_error = e # save error to reraise it if model is not cached locally
@@ -1628,10 +1818,42 @@ def _get_signature_types(cls):
signature_types[k] = (v.annotation,)
elif get_origin(v.annotation) == Union:
signature_types[k] = get_args(v.annotation)
+ elif get_origin(v.annotation) in [List, Dict, list, dict]:
+ signature_types[k] = (v.annotation,)
else:
logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.")
return signature_types
+ @property
+ def parameters(self) -> Dict[str, Any]:
+ r"""
+ The `self.parameters` property can be useful to run different pipelines with the same weights and
+ configurations without reallocating additional memory.
+
+ Returns (`dict`):
+ A dictionary containing all the optional parameters needed to initialize the pipeline.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import (
+ ... StableDiffusionPipeline,
+ ... StableDiffusionImg2ImgPipeline,
+ ... StableDiffusionInpaintPipeline,
+ ... )
+
+ >>> text2img = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
+ >>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components, **text2img.parameters)
+ >>> inpaint = StableDiffusionInpaintPipeline(**text2img.components, **text2img.parameters)
+ ```
+ """
+ expected_modules, optional_parameters = self._get_signature_keys(self)
+ pipeline_parameters = {
+ k: self.config[k] for k in self.config.keys() if not k.startswith("_") and k in optional_parameters
+ }
+
+ return pipeline_parameters
+
@property
def components(self) -> Dict[str, Any]:
r"""
@@ -1702,12 +1924,8 @@ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Call
option is enabled, you should observe lower GPU memory usage and a potential speed up during inference. Speed
up during training is not guaranteed.
-
-
- ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
- precedent.
-
-
+ > [!WARNING] > ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient
+ attention takes > precedent.
Parameters:
attention_op (`Callable`, *optional*):
@@ -1763,13 +1981,10 @@ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto
in slices to compute attention in several steps. For more than one attention head, the computation is performed
sequentially over each head. This is useful to save some memory in exchange for a small speed decrease.
-
-
- ⚠️ Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch
- 2.0 or xFormers. These attention computations are already very memory efficient so you won't need to enable
- this function. If you enable attention slicing with SDPA or xFormers, it can lead to serious slow downs!
-
-
+ > [!WARNING] > ⚠️ Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA)
+ from PyTorch > 2.0 or xFormers. These attention computations are already very memory efficient so you won't
+ need to enable > this function. If you enable attention slicing with SDPA or xFormers, it can lead to serious
+ slow downs!
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
@@ -1914,11 +2129,13 @@ def from_pipe(cls, pipeline, **kwargs):
f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs
}
+ optional_components = (
+ pipeline._optional_components
+ if hasattr(pipeline, "_optional_components") and pipeline._optional_components
+ else []
+ )
missing_modules = (
- set(expected_modules)
- - set(pipeline._optional_components)
- - set(pipeline_kwargs.keys())
- - set(true_optional_modules)
+ set(expected_modules) - set(optional_components) - set(pipeline_kwargs.keys()) - set(true_optional_modules)
)
if len(missing_modules) > 0:
@@ -1965,6 +2182,12 @@ def enable_vae_slicing(self):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -1972,6 +2195,12 @@ def disable_vae_slicing(self):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -1980,6 +2209,12 @@ def enable_vae_tiling(self):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -1987,10 +2222,16 @@ def disable_vae_tiling(self):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
- r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+ r"""Enables the FreeU mechanism as in https://huggingface.co/papers/2309.11497.
The suffixes after the scaling factors represent the stages where they are being applied.
@@ -2020,11 +2261,7 @@ def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
@@ -2049,11 +2286,7 @@ def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""Disable QKV projection fusion if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
index 988e049dd684..1d718a4852a4 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -1,4 +1,4 @@
-# Copyright 2024 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -256,7 +256,9 @@ class PixArtAlphaPipeline(DiffusionPipeline):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`PixArtTransformer2DModel`]):
- A text conditioned `PixArtTransformer2DModel` to denoise the encoded image latents.
+ A text conditioned `PixArtTransformer2DModel` to denoise the encoded image latents. Initially published as
+ [`Transformer2DModel`](https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS/blob/main/transformer/config.json#L2)
+ in the config, but the mismatch can be ignored.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
"""
@@ -437,7 +439,7 @@ def encode_prompt(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -598,7 +600,7 @@ def _clean_caption(self, caption):
# &
caption = re.sub(r"&", "", caption)
- # ip adresses:
+ # ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -733,11 +735,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 4.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
@@ -745,15 +747,15 @@ def __call__(
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -832,7 +834,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
index 7f10ee89ee04..bb169ac5c443 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
@@ -1,4 +1,4 @@
-# Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -185,6 +185,26 @@ def retrieve_timesteps(
class PixArtSigmaPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using PixArt-Sigma.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. PixArt-Alpha uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+ tokenizer (`T5Tokenizer`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`PixArtTransformer2DModel`]):
+ A text conditioned `PixArtTransformer2DModel` to denoise the encoded image latents. Initially published as
+ [`Transformer2DModel`](https://huggingface.co/PixArt-alpha/PixArt-Sigma-XL-2-1024-MS/blob/main/transformer/config.json#L2)
+ in the config, but the mismatch can be ignored.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
"""
bad_punct_regex = re.compile(
@@ -363,7 +383,7 @@ def encode_prompt(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -525,7 +545,7 @@ def _clean_caption(self, caption):
# &
caption = re.sub(r"&", "", caption)
- # ip adresses:
+ # ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -660,11 +680,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 4.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
@@ -672,15 +692,15 @@ def __call__(
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -758,7 +778,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
@@ -868,7 +888,7 @@ def __call__(
xm.mark_step()
if not output_type == "latent":
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image = self.vae.decode(latents.to(self.vae.dtype) / self.vae.config.scaling_factor, return_dict=False)[0]
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
else:
diff --git a/src/diffusers/pipelines/prx/__init__.py b/src/diffusers/pipelines/prx/__init__.py
new file mode 100644
index 000000000000..87aaefbd1368
--- /dev/null
+++ b/src/diffusers/pipelines/prx/__init__.py
@@ -0,0 +1,63 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_additional_imports = {}
+_import_structure = {"pipeline_output": ["PRXPipelineOutput"]}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_prx"] = ["PRXPipeline"]
+
+# Import T5GemmaEncoder for pipeline loading compatibility
+try:
+ if is_transformers_available():
+ import transformers
+ from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
+
+ _additional_imports["T5GemmaEncoder"] = T5GemmaEncoder
+ # Patch transformers module directly for serialization
+ if not hasattr(transformers, "T5GemmaEncoder"):
+ transformers.T5GemmaEncoder = T5GemmaEncoder
+except ImportError:
+ pass
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ else:
+ from .pipeline_output import PRXPipelineOutput
+ from .pipeline_prx import PRXPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
+ for name, value in _additional_imports.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/prx/pipeline_output.py b/src/diffusers/pipelines/prx/pipeline_output.py
new file mode 100644
index 000000000000..ea1bc9bf418a
--- /dev/null
+++ b/src/diffusers/pipelines/prx/pipeline_output.py
@@ -0,0 +1,35 @@
+# Copyright 2025 The Photoroom and the HuggingFace Teams. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class PRXPipelineOutput(BaseOutput):
+ """
+ Output class for PRX pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/pipelines/prx/pipeline_prx.py b/src/diffusers/pipelines/prx/pipeline_prx.py
new file mode 100644
index 000000000000..873f25316e6d
--- /dev/null
+++ b/src/diffusers/pipelines/prx/pipeline_prx.py
@@ -0,0 +1,802 @@
+# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+import re
+import urllib.parse as ul
+from typing import Callable, Dict, List, Optional, Union
+
+import ftfy
+import torch
+from transformers import (
+ AutoTokenizer,
+ GemmaTokenizerFast,
+ T5TokenizerFast,
+)
+from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
+
+from diffusers.image_processor import PixArtImageProcessor
+from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderDC, AutoencoderKL
+from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.prx.pipeline_output import PRXPipelineOutput
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import (
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+DEFAULT_RESOLUTION = 512
+
+ASPECT_RATIO_256_BIN = {
+ "0.46": [160, 352],
+ "0.6": [192, 320],
+ "0.78": [224, 288],
+ "1.0": [256, 256],
+ "1.29": [288, 224],
+ "1.67": [320, 192],
+ "2.2": [352, 160],
+}
+
+ASPECT_RATIO_512_BIN = {
+ "0.5": [352, 704],
+ "0.57": [384, 672],
+ "0.6": [384, 640],
+ "0.68": [416, 608],
+ "0.78": [448, 576],
+ "0.88": [480, 544],
+ "1.0": [512, 512],
+ "1.13": [544, 480],
+ "1.29": [576, 448],
+ "1.46": [608, 416],
+ "1.67": [640, 384],
+ "1.75": [672, 384],
+ "2.0": [704, 352],
+}
+
+ASPECT_RATIO_1024_BIN = {
+ "0.49": [704, 1440],
+ "0.52": [736, 1408],
+ "0.53": [736, 1376],
+ "0.57": [768, 1344],
+ "0.59": [768, 1312],
+ "0.62": [800, 1280],
+ "0.67": [832, 1248],
+ "0.68": [832, 1216],
+ "0.78": [896, 1152],
+ "0.83": [928, 1120],
+ "0.94": [992, 1056],
+ "1.0": [1024, 1024],
+ "1.06": [1056, 992],
+ "1.13": [1088, 960],
+ "1.21": [1120, 928],
+ "1.29": [1152, 896],
+ "1.37": [1184, 864],
+ "1.46": [1216, 832],
+ "1.5": [1248, 832],
+ "1.71": [1312, 768],
+ "1.75": [1344, 768],
+ "1.87": [1376, 736],
+ "1.91": [1408, 736],
+ "2.05": [1440, 704],
+}
+
+ASPECT_RATIO_BINS = {
+ 256: ASPECT_RATIO_256_BIN,
+ 512: ASPECT_RATIO_512_BIN,
+ 1024: ASPECT_RATIO_1024_BIN,
+}
+
+logger = logging.get_logger(__name__)
+
+
+class TextPreprocessor:
+ """Text preprocessing utility for PRXPipeline."""
+
+ def __init__(self):
+ """Initialize text preprocessor."""
+ self.bad_punct_regex = re.compile(
+ r"["
+ + "#®•©™&@·º½¾¿¡§~"
+ + r"\)"
+ + r"\("
+ + r"\]"
+ + r"\["
+ + r"\}"
+ + r"\{"
+ + r"\|"
+ + r"\\"
+ + r"\/"
+ + r"\*"
+ + r"]{1,}"
+ )
+
+ def clean_text(self, text: str) -> str:
+ """Clean text using comprehensive text processing logic."""
+ # See Deepfloyd https://github.com/deep-floyd/IF/blob/develop/deepfloyd_if/modules/t5.py
+ text = str(text)
+ text = ul.unquote_plus(text)
+ text = text.strip().lower()
+ text = re.sub("", "person", text)
+
+ # Remove all urls:
+ text = re.sub(
+ r"\b((?:https?|www):(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@))",
+ "",
+ text,
+ ) # regex for urls
+
+ # @
+ text = re.sub(r"@[\w\d]+\b", "", text)
+
+ # 31C0—31EF CJK Strokes through 4E00—9FFF CJK Unified Ideographs
+ text = re.sub(r"[\u31c0-\u31ef]+", "", text)
+ text = re.sub(r"[\u31f0-\u31ff]+", "", text)
+ text = re.sub(r"[\u3200-\u32ff]+", "", text)
+ text = re.sub(r"[\u3300-\u33ff]+", "", text)
+ text = re.sub(r"[\u3400-\u4dbf]+", "", text)
+ text = re.sub(r"[\u4dc0-\u4dff]+", "", text)
+ text = re.sub(r"[\u4e00-\u9fff]+", "", text)
+
+ # все виды тире / all types of dash --> "-"
+ text = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",
+ "-",
+ text,
+ )
+
+ # кавычки к одному стандарту
+ text = re.sub(r"[`´«»" "¨]", '"', text)
+ text = re.sub(r"['']", "'", text)
+
+ # " and &
+ text = re.sub(r""?", "", text)
+ text = re.sub(r"&", "", text)
+
+ # ip addresses:
+ text = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", text)
+
+ # article ids:
+ text = re.sub(r"\d:\d\d\s+$", "", text)
+
+ # \n
+ text = re.sub(r"\\n", " ", text)
+
+ # "#123", "#12345..", "123456.."
+ text = re.sub(r"#\d{1,3}\b", "", text)
+ text = re.sub(r"#\d{5,}\b", "", text)
+ text = re.sub(r"\b\d{6,}\b", "", text)
+
+ # filenames:
+ text = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", text)
+
+ # Clean punctuation
+ text = re.sub(r"[\"\']{2,}", r'"', text) # """AUSVERKAUFT"""
+ text = re.sub(r"[\.]{2,}", r" ", text)
+
+ text = re.sub(self.bad_punct_regex, r" ", text) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ text = re.sub(r"\s+\.\s+", r" ", text) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, text)) > 3:
+ text = re.sub(regex2, " ", text)
+
+ # Basic cleaning
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ text = text.strip()
+
+ # Clean alphanumeric patterns
+ text = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", text) # jc6640
+ text = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", text) # jc6640vc
+ text = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", text) # 6640vc231
+
+ # Common spam patterns
+ text = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", text)
+ text = re.sub(r"(free\s)?download(\sfree)?", "", text)
+ text = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", text)
+ text = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", text)
+ text = re.sub(r"\bpage\s+\d+\b", "", text)
+
+ text = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", text) # j2d1a2a...
+ text = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", text)
+
+ # Final cleanup
+ text = re.sub(r"\b\s+\:\s+", r": ", text)
+ text = re.sub(r"(\D[,\./])\b", r"\1 ", text)
+ text = re.sub(r"\s+", " ", text)
+
+ text.strip()
+
+ text = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", text)
+ text = re.sub(r"^[\'\_,\-\:;]", r"", text)
+ text = re.sub(r"[\'\_,\-\:\-\+]$", r"", text)
+ text = re.sub(r"^\.\S+$", "", text)
+
+ return text.strip()
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import PRXPipeline
+
+ >>> # Load pipeline with from_pretrained
+ >>> pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft")
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
+ >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
+ >>> image.save("prx_output.png")
+ ```
+"""
+
+
+class PRXPipeline(
+ DiffusionPipeline,
+ LoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+):
+ r"""
+ Pipeline for text-to-image generation using PRX Transformer.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ transformer ([`PRXTransformer2DModel`]):
+ The PRX transformer model to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ text_encoder ([`T5GemmaEncoder`]):
+ Text encoder model for encoding prompts.
+ tokenizer ([`T5TokenizerFast` or `GemmaTokenizerFast`]):
+ Tokenizer for the text encoder.
+ vae ([`AutoencoderKL`] or [`AutoencoderDC`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ Supports both AutoencoderKL (8x compression) and AutoencoderDC (32x compression).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+ _optional_components = ["vae"]
+
+ def __init__(
+ self,
+ transformer: PRXTransformer2DModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ text_encoder: T5GemmaEncoder,
+ tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer],
+ vae: Optional[Union[AutoencoderKL, AutoencoderDC]] = None,
+ default_sample_size: Optional[int] = DEFAULT_RESOLUTION,
+ ):
+ super().__init__()
+
+ if PRXTransformer2DModel is None:
+ raise ImportError(
+ "PRXTransformer2DModel is not available. Please ensure the transformer_prx module is properly installed."
+ )
+
+ self.text_preprocessor = TextPreprocessor()
+ self.default_sample_size = default_sample_size
+ self._guidance_scale = 1.0
+
+ self.register_modules(
+ transformer=transformer,
+ scheduler=scheduler,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ )
+
+ self.register_to_config(default_sample_size=self.default_sample_size)
+
+ if vae is not None:
+ self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ else:
+ self.image_processor = None
+
+ @property
+ def vae_scale_factor(self):
+ if self.vae is None:
+ return 8
+ if hasattr(self.vae, "spatial_compression_ratio"):
+ return self.vae.spatial_compression_ratio
+ else: # Flux VAE
+ return 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+ @property
+ def do_classifier_free_guidance(self):
+ """Check if classifier-free guidance is enabled based on guidance scale."""
+ return self._guidance_scale > 1.0
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ def get_default_resolution(self):
+ """Determine the default resolution based on the loaded VAE and config.
+
+ Returns:
+ int: The default sample size (height/width) to use for generation.
+ """
+ default_from_config = getattr(self.config, "default_sample_size", None)
+ if default_from_config is not None:
+ return default_from_config
+
+ return DEFAULT_RESOLUTION
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int,
+ height: int,
+ width: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ ):
+ """Prepare initial latents for the diffusion process."""
+ if latents is None:
+ spatial_compression = self.vae_scale_factor
+ latent_height, latent_width = (
+ height // spatial_compression,
+ width // spatial_compression,
+ )
+ shape = (batch_size, num_channels_latents, latent_height, latent_width)
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+ return latents
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: str = "",
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_attention_mask: Optional[torch.BoolTensor] = None,
+ negative_prompt_attention_mask: Optional[torch.BoolTensor] = None,
+ ):
+ """Encode text prompt using standard text encoder and tokenizer, or use precomputed embeddings."""
+ if device is None:
+ device = self._execution_device
+
+ if prompt_embeds is None:
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ # Encode the prompts
+ prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = (
+ self._encode_prompt_standard(prompt, device, do_classifier_free_guidance, negative_prompt)
+ )
+
+ # Duplicate embeddings for each generation per prompt
+ if num_images_per_prompt > 1:
+ # Repeat prompt embeddings
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if prompt_attention_mask is not None:
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+ # Repeat negative embeddings if using CFG
+ if do_classifier_free_guidance and negative_prompt_embeds is not None:
+ bs_embed, seq_len, _ = negative_prompt_embeds.shape
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if negative_prompt_attention_mask is not None:
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+ return (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds if do_classifier_free_guidance else None,
+ negative_prompt_attention_mask if do_classifier_free_guidance else None,
+ )
+
+ def _tokenize_prompts(self, prompts: List[str], device: torch.device):
+ """Tokenize and clean prompts."""
+ cleaned = [self.text_preprocessor.clean_text(text) for text in prompts]
+ tokens = self.tokenizer(
+ cleaned,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ return tokens["input_ids"].to(device), tokens["attention_mask"].bool().to(device)
+
+ def _encode_prompt_standard(
+ self,
+ prompt: List[str],
+ device: torch.device,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: str = "",
+ ):
+ """Encode prompt using standard text encoder and tokenizer with batch processing."""
+ batch_size = len(prompt)
+
+ if do_classifier_free_guidance:
+ if isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt] * batch_size
+
+ prompts_to_encode = negative_prompt + prompt
+ else:
+ prompts_to_encode = prompt
+
+ input_ids, attention_mask = self._tokenize_prompts(prompts_to_encode, device)
+
+ with torch.no_grad():
+ embeddings = self.text_encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_hidden_states=True,
+ )["last_hidden_state"]
+
+ if do_classifier_free_guidance:
+ uncond_text_embeddings, text_embeddings = embeddings.split(batch_size, dim=0)
+ uncond_cross_attn_mask, cross_attn_mask = attention_mask.split(batch_size, dim=0)
+ else:
+ text_embeddings = embeddings
+ cross_attn_mask = attention_mask
+ uncond_text_embeddings = None
+ uncond_cross_attn_mask = None
+
+ return text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask
+
+ def check_inputs(
+ self,
+ prompt: Union[str, List[str]],
+ height: int,
+ width: int,
+ guidance_scale: float,
+ callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ """Check that all inputs are in correct format."""
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+
+ if prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+
+ if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt_embeds is not None and guidance_scale > 1.0 and negative_prompt_embeds is None:
+ raise ValueError(
+ "When `prompt_embeds` is provided and `guidance_scale > 1.0`, "
+ "`negative_prompt_embeds` must also be provided for classifier-free guidance."
+ )
+
+ spatial_compression = self.vae_scale_factor
+ if height % spatial_compression != 0 or width % spatial_compression != 0:
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {spatial_compression} but are {height} and {width}."
+ )
+
+ if guidance_scale < 1.0:
+ raise ValueError(f"guidance_scale has to be >= 1.0 but is {guidance_scale}")
+
+ if callback_on_step_end_tensor_inputs is not None and not isinstance(callback_on_step_end_tensor_inputs, list):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be a list but is {callback_on_step_end_tensor_inputs}"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: str = "",
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 4.0,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_attention_mask: Optional[torch.BoolTensor] = None,
+ negative_prompt_attention_mask: Optional[torch.BoolTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ use_resolution_binning: bool = True,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ negative_prompt (`str`, *optional*, defaults to `""`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 28):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided and `guidance_scale > 1`, negative embeddings will be generated from an
+ empty string.
+ prompt_attention_mask (`torch.BoolTensor`, *optional*):
+ Pre-generated attention mask for `prompt_embeds`. If not provided, attention mask will be generated
+ from `prompt` input argument.
+ negative_prompt_attention_mask (`torch.BoolTensor`, *optional*):
+ Pre-generated attention mask for `negative_prompt_embeds`. If not provided and `guidance_scale > 1`,
+ attention mask will be generated from an empty string.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.prx.PRXPipelineOutput`] instead of a plain tuple.
+ use_resolution_binning (`bool`, *optional*, defaults to `True`):
+ If set to `True`, the requested height and width are first mapped to the closest resolutions using
+ predefined aspect ratio bins. After the produced latents are decoded into images, they are resized back
+ to the requested resolution. Useful for generating non-square images at optimal resolutions.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self, step, timestep, callback_kwargs)`.
+ `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include tensors that are listed
+ in the `._callback_tensor_inputs` attribute.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.prx.PRXPipelineOutput`] or `tuple`: [`~pipelines.prx.PRXPipelineOutput`] if `return_dict` is
+ True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ # 0. Set height and width
+ default_resolution = self.get_default_resolution()
+ height = height or default_resolution
+ width = width or default_resolution
+
+ if use_resolution_binning:
+ if self.image_processor is None:
+ raise ValueError(
+ "Resolution binning requires a VAE with image_processor, but VAE is not available. "
+ "Set use_resolution_binning=False or provide a VAE."
+ )
+ if self.default_sample_size not in ASPECT_RATIO_BINS:
+ raise ValueError(
+ f"Resolution binning is only supported for default_sample_size in {list(ASPECT_RATIO_BINS.keys())}, "
+ f"but got {self.default_sample_size}. Set use_resolution_binning=False to disable aspect ratio binning."
+ )
+ aspect_ratio_bin = ASPECT_RATIO_BINS[self.default_sample_size]
+
+ # Store original dimensions
+ orig_height, orig_width = height, width
+ # Map to closest resolution in the bin
+ height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ guidance_scale,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ if self.vae is None and output_type not in ["latent", "pt"]:
+ raise ValueError(
+ f"VAE is required for output_type='{output_type}' but it is not available. "
+ "Either provide a VAE or set output_type='latent' or 'pt' to get latent outputs."
+ )
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Use execution device (handles offloading scenarios including group offloading)
+ device = self._execution_device
+
+ self._guidance_scale = guidance_scale
+
+ # 2. Encode input prompt
+ text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = self.encode_prompt(
+ prompt,
+ device,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ )
+ # Expose standard names for callbacks parity
+ prompt_embeds = text_embeddings
+ negative_prompt_embeds = uncond_text_embeddings
+
+ # 3. Prepare timesteps
+ if timesteps is not None:
+ self.scheduler.set_timesteps(timesteps=timesteps, device=device)
+ timesteps = self.scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ self.num_timesteps = len(timesteps)
+
+ # 4. Prepare latent variables
+ if self.vae is not None:
+ num_channels_latents = self.vae.config.latent_channels
+ else:
+ # When vae is None, get latent channels from transformer
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ text_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare extra step kwargs
+ extra_step_kwargs = {}
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_eta:
+ extra_step_kwargs["eta"] = 0.0
+
+ # 6. Prepare cross-attention embeddings and masks
+ if self.do_classifier_free_guidance:
+ ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
+ ca_mask = None
+ if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
+ ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)
+ else:
+ ca_embed = text_embeddings
+ ca_mask = cross_attn_mask
+
+ # 7. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # Duplicate latents if using classifier-free guidance
+ if self.do_classifier_free_guidance:
+ latents_in = torch.cat([latents, latents], dim=0)
+ # Normalize timestep for the transformer
+ t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device)
+ else:
+ latents_in = latents
+ # Normalize timestep for the transformer
+ t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).to(device)
+
+ # Forward through transformer
+ noise_pred = self.transformer(
+ hidden_states=latents_in,
+ timestep=t_cont,
+ encoder_hidden_states=ca_embed,
+ attention_mask=ca_mask,
+ return_dict=False,
+ )[0]
+
+ # Apply CFG
+ if self.do_classifier_free_guidance:
+ noise_uncond, noise_text = noise_pred.chunk(2, dim=0)
+ noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
+
+ # Compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_on_step_end(self, i, t, callback_kwargs)
+
+ # Call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ # 8. Post-processing
+ if output_type == "latent" or (output_type == "pt" and self.vae is None):
+ image = latents
+ else:
+ # Unscale latents for VAE (supports both AutoencoderKL and AutoencoderDC)
+ scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
+ shift_factor = getattr(self.vae.config, "shift_factor", 0.0)
+ latents = (latents / scaling_factor) + shift_factor
+ # Decode using VAE (AutoencoderKL or AutoencoderDC)
+ image = self.vae.decode(latents, return_dict=False)[0]
+ # Resize back to original resolution if using binning
+ if use_resolution_binning:
+ image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
+
+ # Use standard image processor for post-processing
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return PRXPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/qwenimage/__init__.py b/src/diffusers/pipelines/qwenimage/__init__.py
new file mode 100644
index 000000000000..2400632ba2bd
--- /dev/null
+++ b/src/diffusers/pipelines/qwenimage/__init__.py
@@ -0,0 +1,63 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_additional_imports = {}
+_import_structure = {"pipeline_output": ["QwenImagePipelineOutput", "QwenImagePriorReduxPipelineOutput"]}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["modeling_qwenimage"] = ["ReduxImageEncoder"]
+ _import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"]
+ _import_structure["pipeline_qwenimage_controlnet"] = ["QwenImageControlNetPipeline"]
+ _import_structure["pipeline_qwenimage_controlnet_inpaint"] = ["QwenImageControlNetInpaintPipeline"]
+ _import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]
+ _import_structure["pipeline_qwenimage_edit_inpaint"] = ["QwenImageEditInpaintPipeline"]
+ _import_structure["pipeline_qwenimage_edit_plus"] = ["QwenImageEditPlusPipeline"]
+ _import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
+ _import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ else:
+ from .pipeline_qwenimage import QwenImagePipeline
+ from .pipeline_qwenimage_controlnet import QwenImageControlNetPipeline
+ from .pipeline_qwenimage_controlnet_inpaint import QwenImageControlNetInpaintPipeline
+ from .pipeline_qwenimage_edit import QwenImageEditPipeline
+ from .pipeline_qwenimage_edit_inpaint import QwenImageEditInpaintPipeline
+ from .pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
+ from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline
+ from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
+ for name, value in _additional_imports.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_output.py b/src/diffusers/pipelines/qwenimage/pipeline_output.py
new file mode 100644
index 000000000000..eef4b60e3770
--- /dev/null
+++ b/src/diffusers/pipelines/qwenimage/pipeline_output.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class QwenImagePipelineOutput(BaseOutput):
+ """
+ Output class for Stable Diffusion pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
new file mode 100644
index 000000000000..33dc2039b986
--- /dev/null
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
@@ -0,0 +1,771 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import QwenImageLoraLoaderMixin
+from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import QwenImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import QwenImagePipeline
+
+ >>> pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> prompt = "A cat holding a sign that says hello world"
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(prompt, num_inference_steps=50).images[0]
+ >>> image.save("qwenimage.png")
+ ```
+"""
+
+
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
+ r"""
+ The QwenImage pipeline for text-to-image generation.
+
+ Args:
+ transformer ([`QwenImageTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
+ tokenizer (`QwenTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLQwenImage,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2Tokenizer,
+ transformer: QwenImageTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.tokenizer_max_length = 1024
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+ self.prompt_template_encode_start_idx = 34
+ self.default_sample_size = 128
+
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
+ bool_mask = mask.bool()
+ valid_lengths = bool_mask.sum(dim=1)
+ selected = hidden_states[bool_mask]
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+
+ return split_result
+
+ def _get_qwen_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ template = self.prompt_template_encode
+ drop_idx = self.prompt_template_encode_start_idx
+ txt = [template.format(e) for e in prompt]
+ txt_tokens = self.tokenizer(
+ txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
+ ).to(device)
+ encoder_hidden_states = self.text_encoder(
+ input_ids=txt_tokens.input_ids,
+ attention_mask=txt_tokens.attention_mask,
+ output_hidden_states=True,
+ )
+ hidden_states = encoder_hidden_states.hidden_states[-1]
+ split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 1024,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
+
+ prompt_embeds = prompt_embeds[:, :max_sequence_length]
+ prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
+
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ return prompt_embeds, prompt_embeds_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_embeds_mask=None,
+ negative_prompt_embeds_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_embeds_mask is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+
+ @staticmethod
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
+
+ return latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_tiling()
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, 1, num_channels_latents, height, width)
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ true_cfg_scale: float = 4.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: Optional[float] = None,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled by
+ setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale encourages to
+ generate images that are closely linked to the text `prompt`, usually at the expense of lower image
+ quality.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to None):
+ A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
+ where the guidance scale is applied during inference through noise prediction rescaling, guidance
+ distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
+ scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
+ that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
+ parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
+ ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
+ please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
+ enable classifier-free guidance computations).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+ )
+
+ if true_cfg_scale > 1 and not has_neg_prompt:
+ logger.warning(
+ f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
+ )
+ elif true_cfg_scale <= 1 and has_neg_prompt:
+ logger.warning(
+ " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
+ )
+
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_true_cfg:
+ negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds and guidance_scale is None:
+ raise ValueError("guidance_scale is required for guidance-distilled model.")
+ elif self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
+ logger.warning(
+ f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
+ )
+ guidance = None
+ elif not self.transformer.config.guidance_embeds and guidance_scale is None:
+ guidance = None
+
+ if self.attention_kwargs is None:
+ self._attention_kwargs = {}
+
+ txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
+ negative_txt_seq_lens = (
+ negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
+ )
+
+ # 6. Denoising loop
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=prompt_embeds_mask,
+ encoder_hidden_states=prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_true_cfg:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=negative_prompt_embeds_mask,
+ encoder_hidden_states=negative_prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=negative_txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+ noise_pred = comb_pred * (cond_norm / noise_norm)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return QwenImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py
new file mode 100644
index 000000000000..5111096d93c1
--- /dev/null
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py
@@ -0,0 +1,998 @@
+# Copyright 2025 Qwen-Image Team, InstantX Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import QwenImageLoraLoaderMixin
+from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
+from ...models.controlnets.controlnet_qwenimage import QwenImageControlNetModel, QwenImageMultiControlNetModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import QwenImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers.utils import load_image
+ >>> from diffusers import QwenImageControlNetModel, QwenImageMultiControlNetModel, QwenImageControlNetPipeline
+
+ >>> # QwenImageControlNetModel
+ >>> controlnet = QwenImageControlNetModel.from_pretrained(
+ ... "InstantX/Qwen-Image-ControlNet-Union", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe = QwenImageControlNetPipeline.from_pretrained(
+ ... "Qwen/Qwen-Image", controlnet=controlnet, torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+ >>> prompt = "Aesthetics art, traditional asian pagoda, elaborate golden accents, sky blue and white color palette, swirling cloud pattern, digital illustration, east asian architecture, ornamental rooftop, intricate detailing on building, cultural representation."
+ >>> negative_prompt = " "
+ >>> control_image = load_image(
+ ... "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Union/resolve/main/conds/canny.png"
+ ... )
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(
+ ... prompt,
+ ... negative_prompt=negative_prompt,
+ ... control_image=control_image,
+ ... controlnet_conditioning_scale=1.0,
+ ... num_inference_steps=30,
+ ... true_cfg_scale=4.0,
+ ... ).images[0]
+ >>> image.save("qwenimage_cn_union.png")
+
+ >>> # QwenImageMultiControlNetModel
+ >>> controlnet = QwenImageControlNetModel.from_pretrained(
+ ... "InstantX/Qwen-Image-ControlNet-Union", torch_dtype=torch.bfloat16
+ ... )
+ >>> controlnet = QwenImageMultiControlNetModel([controlnet])
+ >>> pipe = QwenImageControlNetPipeline.from_pretrained(
+ ... "Qwen/Qwen-Image", controlnet=controlnet, torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+ >>> prompt = "Aesthetics art, traditional asian pagoda, elaborate golden accents, sky blue and white color palette, swirling cloud pattern, digital illustration, east asian architecture, ornamental rooftop, intricate detailing on building, cultural representation."
+ >>> negative_prompt = " "
+ >>> control_image = load_image(
+ ... "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Union/resolve/main/conds/canny.png"
+ ... )
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(
+ ... prompt,
+ ... negative_prompt=negative_prompt,
+ ... control_image=[control_image, control_image],
+ ... controlnet_conditioning_scale=[0.5, 0.5],
+ ... num_inference_steps=30,
+ ... true_cfg_scale=4.0,
+ ... ).images[0]
+ >>> image.save("qwenimage_cn_union_multi.png")
+ ```
+"""
+
+
+# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
+ r"""
+ The QwenImage pipeline for text-to-image generation.
+
+ Args:
+ transformer ([`QwenImageTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
+ tokenizer (`QwenTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLQwenImage,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2Tokenizer,
+ transformer: QwenImageTransformer2DModel,
+ controlnet: Union[QwenImageControlNetModel, QwenImageMultiControlNetModel],
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ controlnet=controlnet,
+ )
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.tokenizer_max_length = 1024
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+ self.prompt_template_encode_start_idx = 34
+ self.default_sample_size = 128
+
+ # Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.extract_masked_hidden
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
+ bool_mask = mask.bool()
+ valid_lengths = bool_mask.sum(dim=1)
+ selected = hidden_states[bool_mask]
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+
+ return split_result
+
+ # Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.get_qwen_prompt_embeds
+ def _get_qwen_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ template = self.prompt_template_encode
+ drop_idx = self.prompt_template_encode_start_idx
+ txt = [template.format(e) for e in prompt]
+ txt_tokens = self.tokenizer(
+ txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
+ ).to(device)
+ encoder_hidden_states = self.text_encoder(
+ input_ids=txt_tokens.input_ids,
+ attention_mask=txt_tokens.attention_mask,
+ output_hidden_states=True,
+ )
+ hidden_states = encoder_hidden_states.hidden_states[-1]
+ split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+ # Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 1024,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
+
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ return prompt_embeds, prompt_embeds_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_embeds_mask=None,
+ negative_prompt_embeds_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_embeds_mask is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
+
+ return latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, 1, num_channels_latents, height, width)
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ return latents
+
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ if isinstance(image, torch.Tensor):
+ pass
+ else:
+ image = self.image_processor.preprocess(image, height=height, width=width)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ true_cfg_scale: float = 4.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: Optional[float] = None,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ control_image: PipelineImageInput = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled by
+ setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale encourages to
+ generate images that are closely linked to the text `prompt`, usually at the expense of lower image
+ quality.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to None):
+ A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
+ where the guidance scale is applied during inference through noise prediction rescaling, guidance
+ distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
+ scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
+ that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
+ parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
+ ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
+ please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
+ enable classifier-free guidance computations).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(control_image) if isinstance(self.controlnet, QwenImageMultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+ )
+
+ if true_cfg_scale > 1 and not has_neg_prompt:
+ logger.warning(
+ f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
+ )
+ elif true_cfg_scale <= 1 and has_neg_prompt:
+ logger.warning(
+ " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
+ )
+
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_true_cfg:
+ negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 3. Prepare control image
+ num_channels_latents = self.transformer.config.in_channels // 4
+ if isinstance(self.controlnet, QwenImageControlNetModel):
+ control_image = self.prepare_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+ height, width = control_image.shape[-2:]
+
+ if control_image.ndim == 4:
+ control_image = control_image.unsqueeze(2)
+
+ # vae encode
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
+ latents_mean = (torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1)).to(
+ device
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ device
+ )
+
+ control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
+ control_image = (control_image - latents_mean) * latents_std
+
+ control_image = control_image.permute(0, 2, 1, 3, 4)
+
+ # pack
+ control_image = self._pack_latents(
+ control_image,
+ batch_size=control_image.shape[0],
+ num_channels_latents=num_channels_latents,
+ height=control_image.shape[3],
+ width=control_image.shape[4],
+ ).to(dtype=prompt_embeds.dtype, device=device)
+
+ else:
+ if isinstance(self.controlnet, QwenImageMultiControlNetModel):
+ control_images = []
+ for control_image_ in control_image:
+ control_image_ = self.prepare_image(
+ image=control_image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+
+ height, width = control_image_.shape[-2:]
+
+ if control_image_.ndim == 4:
+ control_image_ = control_image_.unsqueeze(2)
+
+ # vae encode
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1)
+ ).to(device)
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(
+ 1, self.vae.config.z_dim, 1, 1, 1
+ ).to(device)
+
+ control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
+ control_image_ = (control_image_ - latents_mean) * latents_std
+
+ control_image_ = control_image_.permute(0, 2, 1, 3, 4)
+
+ # pack
+ control_image_ = self._pack_latents(
+ control_image_,
+ batch_size=control_image_.shape[0],
+ num_channels_latents=num_channels_latents,
+ height=control_image_.shape[3],
+ width=control_image_.shape[4],
+ ).to(dtype=prompt_embeds.dtype, device=device)
+
+ control_images.append(control_image_)
+
+ control_image = control_images
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if isinstance(self.controlnet, QwenImageControlNetModel) else keeps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds and guidance_scale is None:
+ raise ValueError("guidance_scale is required for guidance-distilled model.")
+ elif self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
+ logger.warning(
+ f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
+ )
+ guidance = None
+ elif not self.transformer.config.guidance_embeds and guidance_scale is None:
+ guidance = None
+
+ if self.attention_kwargs is None:
+ self._attention_kwargs = {}
+
+ # 6. Denoising loop
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+ # controlnet
+ controlnet_block_samples = self.controlnet(
+ hidden_states=latents,
+ controlnet_cond=control_image,
+ conditioning_scale=cond_scale,
+ timestep=timestep / 1000,
+ encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states_mask=prompt_embeds_mask,
+ img_shapes=img_shapes,
+ txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
+ return_dict=False,
+ )
+
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states_mask=prompt_embeds_mask,
+ img_shapes=img_shapes,
+ txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
+ controlnet_block_samples=controlnet_block_samples,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_true_cfg:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=negative_prompt_embeds_mask,
+ encoder_hidden_states=negative_prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
+ controlnet_block_samples=controlnet_block_samples,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+ noise_pred = comb_pred * (cond_norm / noise_norm)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return QwenImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py
new file mode 100644
index 000000000000..102a813ab582
--- /dev/null
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py
@@ -0,0 +1,941 @@
+# Copyright 2025 Qwen-Image Team, The InstantX Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import QwenImageLoraLoaderMixin
+from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
+from ...models.controlnets.controlnet_qwenimage import QwenImageControlNetModel, QwenImageMultiControlNetModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import QwenImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers.utils import load_image
+ >>> from diffusers import QwenImageControlNetModel, QwenImageControlNetInpaintPipeline
+
+ >>> base_model_path = "Qwen/Qwen-Image"
+ >>> controlnet_model_path = "InstantX/Qwen-Image-ControlNet-Inpainting"
+ >>> controlnet = QwenImageControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=torch.bfloat16)
+ >>> pipe = QwenImageControlNetInpaintPipeline.from_pretrained(
+ ... base_model_path, controlnet=controlnet, torch_dtype=torch.bfloat16
+ ... ).to("cuda")
+ >>> image = load_image(
+ ... "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/images/image1.png"
+ ... )
+ >>> mask_image = load_image(
+ ... "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/masks/mask1.png"
+ ... )
+ >>> prompt = "一辆绿色的出租车行驶在路上"
+ >>> result = pipe(
+ ... prompt=prompt,
+ ... control_image=image,
+ ... control_mask=mask_image,
+ ... controlnet_conditioning_scale=1.0,
+ ... width=mask_image.size[0],
+ ... height=mask_image.size[1],
+ ... true_cfg_scale=4.0,
+ ... ).images[0]
+ >>> image.save("qwenimage_controlnet_inpaint.png")
+ ```
+"""
+
+
+# Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
+ r"""
+ The QwenImage pipeline for text-to-image generation.
+
+ Args:
+ transformer ([`QwenImageTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
+ tokenizer (`QwenTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLQwenImage,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2Tokenizer,
+ transformer: QwenImageTransformer2DModel,
+ controlnet: QwenImageControlNetModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ controlnet=controlnet,
+ )
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2,
+ do_resize=True,
+ do_convert_grayscale=True,
+ do_normalize=False,
+ do_binarize=True,
+ )
+
+ self.tokenizer_max_length = 1024
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+ self.prompt_template_encode_start_idx = 34
+ self.default_sample_size = 128
+
+ # Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.extract_masked_hidden
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
+ bool_mask = mask.bool()
+ valid_lengths = bool_mask.sum(dim=1)
+ selected = hidden_states[bool_mask]
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+
+ return split_result
+
+ # Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.get_qwen_prompt_embeds
+ def _get_qwen_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ template = self.prompt_template_encode
+ drop_idx = self.prompt_template_encode_start_idx
+ txt = [template.format(e) for e in prompt]
+ txt_tokens = self.tokenizer(
+ txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
+ ).to(self.device)
+ encoder_hidden_states = self.text_encoder(
+ input_ids=txt_tokens.input_ids,
+ attention_mask=txt_tokens.attention_mask,
+ output_hidden_states=True,
+ )
+ hidden_states = encoder_hidden_states.hidden_states[-1]
+ split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+ # Coped from diffusers.pipelines.qwenimage.pipeline_qwenimage.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 1024,
+ ):
+ r"""
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
+
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ return prompt_embeds, prompt_embeds_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_embeds_mask=None,
+ negative_prompt_embeds_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_embeds_mask is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
+
+ return latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, 1, num_channels_latents, height, width)
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ return latents
+
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ if isinstance(image, torch.Tensor):
+ pass
+ else:
+ image = self.image_processor.preprocess(image, height=height, width=width)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ def prepare_image_with_mask(
+ self,
+ image,
+ mask,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ if isinstance(image, torch.Tensor):
+ pass
+ else:
+ image = self.image_processor.preprocess(image, height=height, width=width)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+ image = image.to(device=device, dtype=dtype) # (bsz, 3, height_ori, width_ori)
+
+ # Prepare mask
+ if isinstance(mask, torch.Tensor):
+ pass
+ else:
+ mask = self.mask_processor.preprocess(mask, height=height, width=width)
+ mask = mask.repeat_interleave(repeat_by, dim=0)
+ mask = mask.to(device=device, dtype=dtype) # (bsz, 1, height_ori, width_ori)
+
+ if image.ndim == 4:
+ image = image.unsqueeze(2)
+
+ if mask.ndim == 4:
+ mask = mask.unsqueeze(2)
+
+ # Get masked image
+ masked_image = image.clone()
+ masked_image[(mask > 0.5).repeat(1, 3, 1, 1, 1)] = -1 # (bsz, 3, 1, height_ori, width_ori)
+
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
+ latents_mean = (torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1)).to(device)
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ device
+ )
+
+ # Encode to latents
+ image_latents = self.vae.encode(masked_image.to(self.vae.dtype)).latent_dist.sample()
+ image_latents = (image_latents - latents_mean) * latents_std
+ image_latents = image_latents.to(dtype) # torch.Size([1, 16, 1, height_ori//8, width_ori//8])
+
+ mask = torch.nn.functional.interpolate(
+ mask, size=(image_latents.shape[-3], image_latents.shape[-2], image_latents.shape[-1])
+ )
+ mask = 1 - mask # torch.Size([1, 1, 1, height_ori//8, width_ori//8])
+
+ control_image = torch.cat(
+ [image_latents, mask], dim=1
+ ) # torch.Size([1, 16+1, 1, height_ori//8, width_ori//8])
+
+ control_image = control_image.permute(0, 2, 1, 3, 4) # torch.Size([1, 1, 16+1, height_ori//8, width_ori//8])
+
+ # pack
+ control_image = self._pack_latents(
+ control_image,
+ batch_size=control_image.shape[0],
+ num_channels_latents=control_image.shape[2],
+ height=control_image.shape[3],
+ width=control_image.shape[4],
+ )
+
+ if do_classifier_free_guidance and not guess_mode:
+ control_image = torch.cat([control_image] * 2)
+
+ return control_image
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ true_cfg_scale: float = 4.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 1.0,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ control_image: PipelineImageInput = None,
+ control_mask: PipelineImageInput = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 3.5):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(control_image) if isinstance(self.controlnet, QwenImageMultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+ )
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_true_cfg:
+ negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 3. Prepare control image
+ num_channels_latents = self.transformer.config.in_channels // 4
+ if isinstance(self.controlnet, QwenImageControlNetModel):
+ control_image = self.prepare_image_with_mask(
+ image=control_image,
+ mask=control_mask,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if isinstance(self.controlnet, QwenImageControlNetModel) else keeps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ if self.attention_kwargs is None:
+ self._attention_kwargs = {}
+
+ # 6. Denoising loop
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+ # controlnet
+ controlnet_block_samples = self.controlnet(
+ hidden_states=latents,
+ controlnet_cond=control_image.to(dtype=latents.dtype, device=device),
+ conditioning_scale=cond_scale,
+ timestep=timestep / 1000,
+ encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states_mask=prompt_embeds_mask,
+ img_shapes=img_shapes,
+ txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
+ return_dict=False,
+ )
+
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states_mask=prompt_embeds_mask,
+ img_shapes=img_shapes,
+ txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
+ controlnet_block_samples=controlnet_block_samples,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_true_cfg:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=negative_prompt_embeds_mask,
+ encoder_hidden_states=negative_prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
+ controlnet_block_samples=controlnet_block_samples,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+ noise_pred = comb_pred * (cond_norm / noise_norm)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return QwenImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
new file mode 100644
index 000000000000..ed37b238c8c9
--- /dev/null
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
@@ -0,0 +1,899 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import QwenImageLoraLoaderMixin
+from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import QwenImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from PIL import Image
+ >>> from diffusers import QwenImageEditPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
+ ... ).convert("RGB")
+ >>> prompt = (
+ ... "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors"
+ ... )
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(image, prompt, num_inference_steps=50).images[0]
+ >>> image.save("qwenimage_edit.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+def calculate_dimensions(target_area, ratio):
+ width = math.sqrt(target_area * ratio)
+ height = width / ratio
+
+ width = round(width / 32) * 32
+ height = round(height / 32) * 32
+
+ return width, height, None
+
+
+class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
+ r"""
+ The Qwen-Image-Edit pipeline for image editing.
+
+ Args:
+ transformer ([`QwenImageTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
+ tokenizer (`QwenTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLQwenImage,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2Tokenizer,
+ processor: Qwen2VLProcessor,
+ transformer: QwenImageTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ processor=processor,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
+ # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.tokenizer_max_length = 1024
+
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+ self.prompt_template_encode_start_idx = 64
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
+ bool_mask = mask.bool()
+ valid_lengths = bool_mask.sum(dim=1)
+ selected = hidden_states[bool_mask]
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+
+ return split_result
+
+ def _get_qwen_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ template = self.prompt_template_encode
+ drop_idx = self.prompt_template_encode_start_idx
+ txt = [template.format(e) for e in prompt]
+
+ model_inputs = self.processor(
+ text=txt,
+ images=image,
+ padding=True,
+ return_tensors="pt",
+ ).to(device)
+
+ outputs = self.text_encoder(
+ input_ids=model_inputs.input_ids,
+ attention_mask=model_inputs.attention_mask,
+ pixel_values=model_inputs.pixel_values,
+ image_grid_thw=model_inputs.image_grid_thw,
+ output_hidden_states=True,
+ )
+
+ hidden_states = outputs.hidden_states[-1]
+ split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ image: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 1024,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ image (`torch.Tensor`, *optional*):
+ image to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
+
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ return prompt_embeds, prompt_embeds_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_embeds_mask=None,
+ negative_prompt_embeds_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_embeds_mask is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
+
+ return latents
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.latent_channels, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std)
+ .view(1, self.latent_channels, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ image_latents = (image_latents - latents_mean) / latents_std
+
+ return image_latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_tiling()
+
+ def prepare_latents(
+ self,
+ image,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, 1, num_channels_latents, height, width)
+
+ image_latents = None
+ if image is not None:
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ else:
+ image_latents = image
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ image_latent_height, image_latent_width = image_latents.shape[3:]
+ image_latents = self._pack_latents(
+ image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
+ )
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ return latents, image_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: Optional[PipelineImageInput] = None,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ true_cfg_scale: float = 4.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: Optional[float] = None,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
+ latents as `image`, but if passing latents directly it is not encoded again.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free
+ Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of
+ equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is
+ enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale
+ encourages to generate images that are closely linked to the text `prompt`, usually at the expense of
+ lower image quality.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to None):
+ A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
+ where the guidance scale is applied during inference through noise prediction rescaling, guidance
+ distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
+ scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
+ that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
+ parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
+ ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
+ please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
+ enable classifier-free guidance computations).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+ image_size = image[0].size if isinstance(image, list) else image.size
+ calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
+ height = height or calculated_height
+ width = width or calculated_width
+
+ multiple_of = self.vae_scale_factor * 2
+ width = width // multiple_of * multiple_of
+ height = height // multiple_of * multiple_of
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # 3. Preprocess image
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+ image = self.image_processor.resize(image, calculated_height, calculated_width)
+ prompt_image = image
+ image = self.image_processor.preprocess(image, calculated_height, calculated_width)
+ image = image.unsqueeze(2)
+
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+ )
+
+ if true_cfg_scale > 1 and not has_neg_prompt:
+ logger.warning(
+ f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
+ )
+ elif true_cfg_scale <= 1 and has_neg_prompt:
+ logger.warning(
+ " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
+ )
+
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+ image=prompt_image,
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_true_cfg:
+ negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+ image=prompt_image,
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, image_latents = self.prepare_latents(
+ image,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ img_shapes = [
+ [
+ (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
+ (1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2),
+ ]
+ ] * batch_size
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds and guidance_scale is None:
+ raise ValueError("guidance_scale is required for guidance-distilled model.")
+ elif self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
+ logger.warning(
+ f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
+ )
+ guidance = None
+ elif not self.transformer.config.guidance_embeds and guidance_scale is None:
+ guidance = None
+
+ if self.attention_kwargs is None:
+ self._attention_kwargs = {}
+
+ txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
+ negative_txt_seq_lens = (
+ negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
+ )
+
+ # 6. Denoising loop
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ latent_model_input = latents
+ if image_latents is not None:
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=prompt_embeds_mask,
+ encoder_hidden_states=prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred[:, : latents.size(1)]
+
+ if do_true_cfg:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=negative_prompt_embeds_mask,
+ encoder_hidden_states=negative_prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=negative_txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+ comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+ noise_pred = comb_pred * (cond_norm / noise_norm)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return QwenImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py
new file mode 100644
index 000000000000..d54d1881fa4e
--- /dev/null
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py
@@ -0,0 +1,1130 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import QwenImageLoraLoaderMixin
+from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import QwenImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from PIL import Image
+ >>> from diffusers import QwenImageEditInpaintPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = QwenImageEditInpaintPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+ >>> source = load_image(img_url)
+ >>> mask = load_image(mask_url)
+ >>> image = pipe(
+ ... prompt=prompt, negative_prompt=" ", image=source, mask_image=mask, strength=1.0, num_inference_steps=50
+ ... ).images[0]
+ >>> image.save("qwenimage_inpainting.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.calculate_dimensions
+def calculate_dimensions(target_area, ratio):
+ width = math.sqrt(target_area * ratio)
+ height = width / ratio
+
+ width = round(width / 32) * 32
+ height = round(height / 32) * 32
+
+ return width, height, None
+
+
+class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
+ r"""
+ The Qwen-Image-Edit pipeline for image editing.
+
+ Args:
+ transformer ([`QwenImageTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
+ tokenizer (`QwenTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLQwenImage,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2Tokenizer,
+ processor: Qwen2VLProcessor,
+ transformer: QwenImageTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ processor=processor,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
+ # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2,
+ vae_latent_channels=self.latent_channels,
+ do_normalize=False,
+ do_binarize=True,
+ do_convert_grayscale=True,
+ )
+ self.vl_processor = processor
+ self.tokenizer_max_length = 1024
+
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+ self.prompt_template_encode_start_idx = 64
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
+ bool_mask = mask.bool()
+ valid_lengths = bool_mask.sum(dim=1)
+ selected = hidden_states[bool_mask]
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+
+ return split_result
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._get_qwen_prompt_embeds
+ def _get_qwen_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ template = self.prompt_template_encode
+ drop_idx = self.prompt_template_encode_start_idx
+ txt = [template.format(e) for e in prompt]
+
+ model_inputs = self.processor(
+ text=txt,
+ images=image,
+ padding=True,
+ return_tensors="pt",
+ ).to(device)
+
+ outputs = self.text_encoder(
+ input_ids=model_inputs.input_ids,
+ attention_mask=model_inputs.attention_mask,
+ pixel_values=model_inputs.pixel_values,
+ image_grid_thw=model_inputs.image_grid_thw,
+ output_hidden_states=True,
+ )
+
+ hidden_states = outputs.hidden_states[-1]
+ split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ image: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 1024,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ image (`torch.Tensor`, *optional*):
+ image to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
+
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ return prompt_embeds, prompt_embeds_mask
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ mask_image,
+ strength,
+ height,
+ width,
+ output_type,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_embeds_mask=None,
+ negative_prompt_embeds_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ padding_mask_crop=None,
+ max_sequence_length=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_embeds_mask is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+ if padding_mask_crop is not None:
+ if not isinstance(image, PIL.Image.Image):
+ raise ValueError(
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+ )
+ if not isinstance(mask_image, PIL.Image.Image):
+ raise ValueError(
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+ f" {type(mask_image)}."
+ )
+ if output_type != "pil":
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
+
+ return latents
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_img2img.QwenImageImg2ImgPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ image_latents.device, image_latents.dtype
+ )
+
+ image_latents = (image_latents - latents_mean) * latents_std
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ image,
+ timestep,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, 1, num_channels_latents, height, width)
+
+ # If image is [B,C,H,W] -> add T=1. If it's already [B,C,T,H,W], leave it.
+ if image.dim() == 4:
+ image = image.unsqueeze(2)
+ elif image.dim() != 5:
+ raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.")
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator) # [B,z,1,H',W']
+ else:
+ image_latents = image
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ image_latents = image_latents.transpose(1, 2) # [B,1,z,H',W']
+
+ if latents is None:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+ else:
+ noise = latents.to(device)
+ latents = noise
+
+ noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
+ image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ return latents, noise, image_latents
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.prepare_mask_latents
+ def prepare_mask_latents(
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(mask, size=(height, width))
+ mask = mask.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if masked_image.dim() == 4:
+ masked_image = masked_image.unsqueeze(2)
+ elif masked_image.dim() != 5:
+ raise ValueError(f"Expected image dims 4 or 5, got {masked_image.dim()}.")
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ if masked_image.shape[1] == self.latent_channels:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = self._encode_vae_image(image=masked_image, generator=generator)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1, 1)
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ masked_image_latents = self._pack_latents(
+ masked_image_latents,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+ mask = self._pack_latents(
+ mask.repeat(1, num_channels_latents, 1, 1),
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+
+ return mask, masked_image_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: Optional[PipelineImageInput] = None,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ mask_image: PipelineImageInput = None,
+ masked_image_latents: PipelineImageInput = None,
+ true_cfg_scale: float = 4.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ padding_mask_crop: Optional[int] = None,
+ strength: float = 0.6,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: Optional[float] = None,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
+ latents as `image`, but if passing latents directly it is not encoded again.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free
+ Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of
+ equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is
+ enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale
+ encourages to generate images that are closely linked to the text `prompt`, usually at the expense of
+ lower image quality.
+ mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+ color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
+ H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
+ 1)`, or `(H, W)`.
+ mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
+ `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
+ latents tensor will ge generated by `mask_image`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
+ The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
+ image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
+ with the same aspect ration of the image and contains all masked area, and then expand that area based
+ on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
+ resizing to the original image size for inpainting. This is useful when the masked area is small while
+ the image is large and contain information irrelevant for inpainting, such as background.
+ strength (`float`, *optional*, defaults to 1.0):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to None):
+ A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
+ where the guidance scale is applied during inference through noise prediction rescaling, guidance
+ distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
+ scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
+ that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
+ parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
+ ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
+ please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
+ enable classifier-free guidance computations).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+ image_size = image[0].size if isinstance(image, list) else image.size
+ calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
+
+ # height and width are the same as the calculated height and width
+ height = calculated_height
+ width = calculated_width
+
+ multiple_of = self.vae_scale_factor * 2
+ width = width // multiple_of * multiple_of
+ height = height // multiple_of * multiple_of
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ image,
+ mask_image,
+ strength,
+ height,
+ width,
+ output_type=output_type,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ padding_mask_crop=padding_mask_crop,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # 3. Preprocess image
+ if padding_mask_crop is not None:
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+ resize_mode = "fill"
+ else:
+ crops_coords = None
+ resize_mode = "default"
+
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+ image = self.image_processor.resize(image, calculated_height, calculated_width)
+ original_image = image
+ prompt_image = image
+ image = self.image_processor.preprocess(
+ image,
+ height=calculated_height,
+ width=calculated_width,
+ crops_coords=crops_coords,
+ resize_mode=resize_mode,
+ )
+ image = image.to(dtype=torch.float32)
+
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+ )
+
+ if true_cfg_scale > 1 and not has_neg_prompt:
+ logger.warning(
+ f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
+ )
+ elif true_cfg_scale <= 1 and has_neg_prompt:
+ logger.warning(
+ " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
+ )
+
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+ image=prompt_image,
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_true_cfg:
+ negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+ image=prompt_image,
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 4. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, noise, image_latents = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ mask_condition = self.mask_processor.preprocess(
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+ )
+
+ if masked_image_latents is None:
+ masked_image = image * (mask_condition < 0.5)
+ else:
+ masked_image = masked_image_latents
+
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask_condition,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ img_shapes = [
+ [
+ (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
+ (1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2),
+ ]
+ ] * batch_size
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds and guidance_scale is None:
+ raise ValueError("guidance_scale is required for guidance-distilled model.")
+ elif self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
+ logger.warning(
+ f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
+ )
+ guidance = None
+ elif not self.transformer.config.guidance_embeds and guidance_scale is None:
+ guidance = None
+
+ if self.attention_kwargs is None:
+ self._attention_kwargs = {}
+
+ txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
+ negative_txt_seq_lens = (
+ negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
+ )
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ latent_model_input = latents
+ if image_latents is not None:
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=prompt_embeds_mask,
+ encoder_hidden_states=prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred[:, : latents.size(1)]
+
+ if do_true_cfg:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=negative_prompt_embeds_mask,
+ encoder_hidden_states=negative_prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=negative_txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+ comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+ noise_pred = comb_pred * (cond_norm / noise_norm)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # for 64 channel transformer only.
+ init_latents_proper = image_latents
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.scale_noise(
+ init_latents_proper, torch.tensor([noise_timestep]), noise
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ if padding_mask_crop is not None:
+ image = [
+ self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image
+ ]
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return QwenImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
new file mode 100644
index 000000000000..ec203edf166c
--- /dev/null
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
@@ -0,0 +1,883 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import QwenImageLoraLoaderMixin
+from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import QwenImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from PIL import Image
+ >>> from diffusers import QwenImageEditPlusPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = QwenImageEditPlusPipeline.from_pretrained("Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
+ ... ).convert("RGB")
+ >>> prompt = (
+ ... "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors"
+ ... )
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(image, prompt, num_inference_steps=50).images[0]
+ >>> image.save("qwenimage_edit_plus.png")
+ ```
+"""
+
+CONDITION_IMAGE_SIZE = 384 * 384
+VAE_IMAGE_SIZE = 1024 * 1024
+
+
+# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+def calculate_dimensions(target_area, ratio):
+ width = math.sqrt(target_area * ratio)
+ height = width / ratio
+
+ width = round(width / 32) * 32
+ height = round(height / 32) * 32
+
+ return width, height
+
+
+class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
+ r"""
+ The Qwen-Image-Edit pipeline for image editing.
+
+ Args:
+ transformer ([`QwenImageTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
+ tokenizer (`QwenTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLQwenImage,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2Tokenizer,
+ processor: Qwen2VLProcessor,
+ transformer: QwenImageTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ processor=processor,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
+ # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.tokenizer_max_length = 1024
+
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+ self.prompt_template_encode_start_idx = 64
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
+ bool_mask = mask.bool()
+ valid_lengths = bool_mask.sum(dim=1)
+ selected = hidden_states[bool_mask]
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+
+ return split_result
+
+ def _get_qwen_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
+ if isinstance(image, list):
+ base_img_prompt = ""
+ for i, img in enumerate(image):
+ base_img_prompt += img_prompt_template.format(i + 1)
+ elif image is not None:
+ base_img_prompt = img_prompt_template.format(1)
+ else:
+ base_img_prompt = ""
+
+ template = self.prompt_template_encode
+
+ drop_idx = self.prompt_template_encode_start_idx
+ txt = [template.format(base_img_prompt + e) for e in prompt]
+
+ model_inputs = self.processor(
+ text=txt,
+ images=image,
+ padding=True,
+ return_tensors="pt",
+ ).to(device)
+
+ outputs = self.text_encoder(
+ input_ids=model_inputs.input_ids,
+ attention_mask=model_inputs.attention_mask,
+ pixel_values=model_inputs.pixel_values,
+ image_grid_thw=model_inputs.image_grid_thw,
+ output_hidden_states=True,
+ )
+
+ hidden_states = outputs.hidden_states[-1]
+ split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ image: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 1024,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ image (`torch.Tensor`, *optional*):
+ image to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
+
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ return prompt_embeds, prompt_embeds_mask
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_embeds_mask=None,
+ negative_prompt_embeds_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_embeds_mask is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
+
+ return latents
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.latent_channels, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std)
+ .view(1, self.latent_channels, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ image_latents = (image_latents - latents_mean) / latents_std
+
+ return image_latents
+
+ def prepare_latents(
+ self,
+ images,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, 1, num_channels_latents, height, width)
+
+ image_latents = None
+ if images is not None:
+ if not isinstance(images, list):
+ images = [images]
+ all_image_latents = []
+ for image in images:
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ else:
+ image_latents = image
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ image_latent_height, image_latent_width = image_latents.shape[3:]
+ image_latents = self._pack_latents(
+ image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
+ )
+ all_image_latents.append(image_latents)
+ image_latents = torch.cat(all_image_latents, dim=1)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ return latents, image_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: Optional[PipelineImageInput] = None,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ true_cfg_scale: float = 4.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: Optional[float] = None,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
+ latents as `image`, but if passing latents directly it is not encoded again.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free
+ Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of
+ equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is
+ enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale
+ encourages to generate images that are closely linked to the text `prompt`, usually at the expense of
+ lower image quality.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to None):
+ A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
+ where the guidance scale is applied during inference through noise prediction rescaling, guidance
+ distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
+ scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
+ that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
+ parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
+ ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
+ please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
+ enable classifier-free guidance computations).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+ image_size = image[-1].size if isinstance(image, list) else image.size
+ calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
+ height = height or calculated_height
+ width = width or calculated_width
+
+ multiple_of = self.vae_scale_factor * 2
+ width = width // multiple_of * multiple_of
+ height = height // multiple_of * multiple_of
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # 3. Preprocess image
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+ if not isinstance(image, list):
+ image = [image]
+ condition_image_sizes = []
+ condition_images = []
+ vae_image_sizes = []
+ vae_images = []
+ for img in image:
+ image_width, image_height = img.size
+ condition_width, condition_height = calculate_dimensions(
+ CONDITION_IMAGE_SIZE, image_width / image_height
+ )
+ vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height)
+ condition_image_sizes.append((condition_width, condition_height))
+ vae_image_sizes.append((vae_width, vae_height))
+ condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
+ vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))
+
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+ )
+
+ if true_cfg_scale > 1 and not has_neg_prompt:
+ logger.warning(
+ f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
+ )
+ elif true_cfg_scale <= 1 and has_neg_prompt:
+ logger.warning(
+ " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
+ )
+
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+ image=condition_images,
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_true_cfg:
+ negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+ image=condition_images,
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, image_latents = self.prepare_latents(
+ vae_images,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ img_shapes = [
+ [
+ (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
+ *[
+ (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
+ for vae_width, vae_height in vae_image_sizes
+ ],
+ ]
+ ] * batch_size
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds and guidance_scale is None:
+ raise ValueError("guidance_scale is required for guidance-distilled model.")
+ elif self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
+ logger.warning(
+ f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
+ )
+ guidance = None
+ elif not self.transformer.config.guidance_embeds and guidance_scale is None:
+ guidance = None
+
+ if self.attention_kwargs is None:
+ self._attention_kwargs = {}
+
+ txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
+ negative_txt_seq_lens = (
+ negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
+ )
+
+ # 6. Denoising loop
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ latent_model_input = latents
+ if image_latents is not None:
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=prompt_embeds_mask,
+ encoder_hidden_states=prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred[:, : latents.size(1)]
+
+ if do_true_cfg:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=negative_prompt_embeds_mask,
+ encoder_hidden_states=negative_prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=negative_txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+ comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+ noise_pred = comb_pred * (cond_norm / noise_norm)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return QwenImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py
new file mode 100644
index 000000000000..cb4c5d8016bb
--- /dev/null
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py
@@ -0,0 +1,874 @@
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import QwenImageLoraLoaderMixin
+from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import QwenImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import QwenImageImg2ImgPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = QwenImageImg2ImgPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
+ >>> pipe = pipe.to("cuda")
+ >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+ >>> init_image = load_image(url).resize((1024, 1024))
+ >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney"
+ >>> images = pipe(prompt=prompt, negative_prompt=" ", image=init_image, strength=0.95).images[0]
+ >>> images.save("qwenimage_img2img.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
+ r"""
+ The QwenImage pipeline for text-to-image generation.
+
+ Args:
+ transformer ([`QwenImageTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
+ tokenizer (`QwenTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLQwenImage,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2Tokenizer,
+ transformer: QwenImageTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
+ self.image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
+ )
+ self.tokenizer_max_length = 1024
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+ self.prompt_template_encode_start_idx = 34
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
+ bool_mask = mask.bool()
+ valid_lengths = bool_mask.sum(dim=1)
+ selected = hidden_states[bool_mask]
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+
+ return split_result
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._get_qwen_prompt_embeds
+ def _get_qwen_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ template = self.prompt_template_encode
+ drop_idx = self.prompt_template_encode_start_idx
+ txt = [template.format(e) for e in prompt]
+ txt_tokens = self.tokenizer(
+ txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
+ ).to(device)
+ encoder_hidden_states = self.text_encoder(
+ input_ids=txt_tokens.input_ids,
+ attention_mask=txt_tokens.attention_mask,
+ output_hidden_states=True,
+ )
+ hidden_states = encoder_hidden_states.hidden_states[-1]
+ split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ image_latents.device, image_latents.dtype
+ )
+
+ image_latents = (image_latents - latents_mean) * latents_std
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied fromCopied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 1024,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
+
+ prompt_embeds = prompt_embeds[:, :max_sequence_length]
+ prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
+
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ return prompt_embeds, prompt_embeds_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ strength,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_embeds_mask=None,
+ negative_prompt_embeds_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_embeds_mask is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
+
+ return latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_tiling()
+
+ def prepare_latents(
+ self,
+ image,
+ timestep,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, 1, num_channels_latents, height, width)
+
+ # If image is [B,C,H,W] -> add T=1. If it's already [B,C,T,H,W], leave it.
+ if image.dim() == 4:
+ image = image.unsqueeze(2)
+ elif image.dim() != 5:
+ raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.")
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator) # [B,z,1,H',W']
+ else:
+ image_latents = image
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ image_latents = image_latents.transpose(1, 2) # [B,1,z,H',W']
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ true_cfg_scale: float = 4.0,
+ image: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.6,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: Optional[float] = None,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
+ latents as `image`, but if passing latents directly it is not encoded again.
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled by
+ setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale encourages to
+ generate images that are closely linked to the text `prompt`, usually at the expense of lower image
+ quality.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ strength (`float`, *optional*, defaults to 1.0):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to None):
+ A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
+ where the guidance scale is applied during inference through noise prediction rescaling, guidance
+ distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
+ scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
+ that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
+ parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
+ ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
+ please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
+ enable classifier-free guidance computations).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ strength,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Preprocess image
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
+ init_image = init_image.to(dtype=torch.float32)
+
+ # 3. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+ )
+
+ if true_cfg_scale > 1 and not has_neg_prompt:
+ logger.warning(
+ f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
+ )
+ elif true_cfg_scale <= 1 and has_neg_prompt:
+ logger.warning(
+ " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
+ )
+
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_true_cfg:
+ negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 4. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents = self.prepare_latents(
+ init_image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds and guidance_scale is None:
+ raise ValueError("guidance_scale is required for guidance-distilled model.")
+ elif self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
+ logger.warning(
+ f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
+ )
+ guidance = None
+ elif not self.transformer.config.guidance_embeds and guidance_scale is None:
+ guidance = None
+
+ if self.attention_kwargs is None:
+ self._attention_kwargs = {}
+
+ txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
+ negative_txt_seq_lens = (
+ negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
+ )
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=prompt_embeds_mask,
+ encoder_hidden_states=prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_true_cfg:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=negative_prompt_embeds_mask,
+ encoder_hidden_states=negative_prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=negative_txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+ noise_pred = comb_pred * (cond_norm / noise_norm)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+
+ latents = latents / latents_std + latents_mean
+ image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return QwenImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py
new file mode 100644
index 000000000000..1915c27eb2bb
--- /dev/null
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py
@@ -0,0 +1,1060 @@
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import QwenImageLoraLoaderMixin
+from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import QwenImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import QwenImageInpaintPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = QwenImageInpaintPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+ >>> source = load_image(img_url)
+ >>> mask = load_image(mask_url)
+ >>> image = pipe(prompt=prompt, negative_prompt=" ", image=source, mask_image=mask, strength=0.85).images[0]
+ >>> image.save("qwenimage_inpainting.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
+ r"""
+ The QwenImage pipeline for text-to-image generation.
+
+ Args:
+ transformer ([`QwenImageTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
+ tokenizer (`QwenTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLQwenImage,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2Tokenizer,
+ transformer: QwenImageTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
+ self.image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
+ )
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2,
+ vae_latent_channels=self.latent_channels,
+ do_normalize=False,
+ do_binarize=True,
+ do_convert_grayscale=True,
+ )
+ self.tokenizer_max_length = 1024
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+ self.prompt_template_encode_start_idx = 34
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
+ bool_mask = mask.bool()
+ valid_lengths = bool_mask.sum(dim=1)
+ selected = hidden_states[bool_mask]
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+
+ return split_result
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._get_qwen_prompt_embeds
+ def _get_qwen_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ template = self.prompt_template_encode
+ drop_idx = self.prompt_template_encode_start_idx
+ txt = [template.format(e) for e in prompt]
+ txt_tokens = self.tokenizer(
+ txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
+ ).to(device)
+ encoder_hidden_states = self.text_encoder(
+ input_ids=txt_tokens.input_ids,
+ attention_mask=txt_tokens.attention_mask,
+ output_hidden_states=True,
+ )
+ hidden_states = encoder_hidden_states.hidden_states[-1]
+ split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_img2img.QwenImageImg2ImgPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ image_latents.device, image_latents.dtype
+ )
+
+ image_latents = (image_latents - latents_mean) * latents_std
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied fromCopied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 1024,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
+
+ prompt_embeds = prompt_embeds[:, :max_sequence_length]
+ prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
+
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ return prompt_embeds, prompt_embeds_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ mask_image,
+ strength,
+ height,
+ width,
+ output_type,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_embeds_mask=None,
+ negative_prompt_embeds_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ padding_mask_crop=None,
+ max_sequence_length=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_embeds_mask is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+ if padding_mask_crop is not None:
+ if not isinstance(image, PIL.Image.Image):
+ raise ValueError(
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+ )
+ if not isinstance(mask_image, PIL.Image.Image):
+ raise ValueError(
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+ f" {type(mask_image)}."
+ )
+ if output_type != "pil":
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
+
+ return latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_tiling()
+
+ def prepare_latents(
+ self,
+ image,
+ timestep,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, 1, num_channels_latents, height, width)
+
+ # If image is [B,C,H,W] -> add T=1. If it's already [B,C,T,H,W], leave it.
+ if image.dim() == 4:
+ image = image.unsqueeze(2)
+ elif image.dim() != 5:
+ raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.")
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator) # [B,z,1,H',W']
+ else:
+ image_latents = image
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ image_latents = image_latents.transpose(1, 2) # [B,1,z,H',W']
+
+ if latents is None:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+ else:
+ noise = latents.to(device)
+ latents = noise
+
+ noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
+ image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ return latents, noise, image_latents
+
+ def prepare_mask_latents(
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(mask, size=(height, width))
+ mask = mask.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if masked_image.dim() == 4:
+ masked_image = masked_image.unsqueeze(2)
+ elif masked_image.dim() != 5:
+ raise ValueError(f"Expected image dims 4 or 5, got {masked_image.dim()}.")
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ if masked_image.shape[1] == self.latent_channels:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = self._encode_vae_image(image=masked_image, generator=generator)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1, 1)
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ masked_image_latents = self._pack_latents(
+ masked_image_latents,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+ mask = self._pack_latents(
+ mask.repeat(1, num_channels_latents, 1, 1),
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+
+ return mask, masked_image_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ true_cfg_scale: float = 4.0,
+ image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
+ masked_image_latents: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ padding_mask_crop: Optional[int] = None,
+ strength: float = 0.6,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: Optional[float] = None,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
+ latents as `image`, but if passing latents directly it is not encoded again.
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled by
+ setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale encourages to
+ generate images that are closely linked to the text `prompt`, usually at the expense of lower image
+ quality.
+ mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+ color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
+ H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
+ 1)`, or `(H, W)`.
+ mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
+ `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
+ latents tensor will be generated by `mask_image`.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
+ The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
+ image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
+ with the same aspect ration of the image and contains all masked area, and then expand that area based
+ on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
+ resizing to the original image size for inpainting. This is useful when the masked area is small while
+ the image is large and contain information irrelevant for inpainting, such as background.
+ strength (`float`, *optional*, defaults to 1.0):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to None):
+ A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
+ where the guidance scale is applied during inference through noise prediction rescaling, guidance
+ distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
+ scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
+ that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
+ parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
+ ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
+ please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
+ enable classifier-free guidance computations).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ image,
+ mask_image,
+ strength,
+ height,
+ width,
+ output_type=output_type,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ padding_mask_crop=padding_mask_crop,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Preprocess image
+ if padding_mask_crop is not None:
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+ resize_mode = "fill"
+ else:
+ crops_coords = None
+ resize_mode = "default"
+
+ original_image = image
+ init_image = self.image_processor.preprocess(
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+ )
+ init_image = init_image.to(dtype=torch.float32)
+
+ # 3. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+ )
+
+ if true_cfg_scale > 1 and not has_neg_prompt:
+ logger.warning(
+ f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
+ )
+ elif true_cfg_scale <= 1 and has_neg_prompt:
+ logger.warning(
+ " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
+ )
+
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_true_cfg:
+ negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 4. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+
+ latents, noise, image_latents = self.prepare_latents(
+ init_image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ mask_condition = self.mask_processor.preprocess(
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+ )
+
+ if masked_image_latents is None:
+ masked_image = init_image * (mask_condition < 0.5)
+ else:
+ masked_image = masked_image_latents
+
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask_condition,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds and guidance_scale is None:
+ raise ValueError("guidance_scale is required for guidance-distilled model.")
+ elif self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
+ logger.warning(
+ f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
+ )
+ guidance = None
+ elif not self.transformer.config.guidance_embeds and guidance_scale is None:
+ guidance = None
+
+ if self.attention_kwargs is None:
+ self._attention_kwargs = {}
+
+ txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
+ negative_txt_seq_lens = (
+ negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
+ )
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=prompt_embeds_mask,
+ encoder_hidden_states=prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_true_cfg:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=negative_prompt_embeds_mask,
+ encoder_hidden_states=negative_prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=negative_txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+ noise_pred = comb_pred * (cond_norm / noise_norm)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # for 64 channel transformer only.
+ init_latents_proper = image_latents
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.scale_noise(
+ init_latents_proper, torch.tensor([noise_timestep]), noise
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+
+ latents = latents / latents_std + latents_mean
+ image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ if padding_mask_crop is not None:
+ image = [
+ self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image
+ ]
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return QwenImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/sana/__init__.py b/src/diffusers/pipelines/sana/__init__.py
index c5814b2eb4da..3672cb58def4 100644
--- a/src/diffusers/pipelines/sana/__init__.py
+++ b/src/diffusers/pipelines/sana/__init__.py
@@ -22,8 +22,9 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_sana"] = ["SanaPipeline"]
- _import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"]
_import_structure["pipeline_sana_controlnet"] = ["SanaControlNetPipeline"]
+ _import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"]
+ _import_structure["pipeline_sana_sprint_img2img"] = ["SanaSprintImg2ImgPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -36,6 +37,7 @@
from .pipeline_sana import SanaPipeline
from .pipeline_sana_controlnet import SanaControlNetPipeline
from .pipeline_sana_sprint import SanaSprintPipeline
+ from .pipeline_sana_sprint_img2img import SanaSprintImg2ImgPipeline
else:
import sys
diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py
index 6093fd836aad..2beff802c6e0 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana.py
@@ -1,4 +1,4 @@
-# Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 SANA Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -30,6 +30,7 @@
from ...utils import (
BACKENDS_MAPPING,
USE_PEFT_BACKEND,
+ deprecate,
is_bs4_available,
is_ftfy_available,
is_torch_xla_available,
@@ -38,7 +39,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import randn_tensor
+from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..pixart_alpha.pipeline_pixart_alpha import (
ASPECT_RATIO_512_BIN,
@@ -224,6 +225,12 @@ def enable_vae_slicing(self):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -231,6 +238,12 @@ def disable_vae_slicing(self):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -239,6 +252,12 @@ def enable_vae_tiling(self):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -246,6 +265,12 @@ def disable_vae_tiling(self):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def _get_gemma_prompt_embeds(
@@ -354,9 +379,7 @@ def encode_prompt(
if device is None:
device = self._execution_device
- if self.transformer is not None:
- dtype = self.transformer.dtype
- elif self.text_encoder is not None:
+ if self.text_encoder is not None:
dtype = self.text_encoder.dtype
else:
dtype = None
@@ -442,7 +465,7 @@ def encode_prompt(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -602,7 +625,7 @@ def _clean_caption(self, caption):
# &
caption = re.sub(r"&", "", caption)
- # ip adresses:
+ # ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -763,11 +786,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 4.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
@@ -775,15 +798,15 @@ def __call__(
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -928,22 +951,22 @@ def __call__(
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
+ transformer_dtype = self.transformer.dtype
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
- latent_model_input = latent_model_input.to(prompt_embeds.dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
- timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+ timestep = t.expand(latent_model_input.shape[0])
timestep = timestep * self.transformer.config.timestep_scale
# predict noise model_output
noise_pred = self.transformer(
- latent_model_input,
- encoder_hidden_states=prompt_embeds,
+ latent_model_input.to(dtype=transformer_dtype),
+ encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
encoder_attention_mask=prompt_attention_mask,
timestep=timestep,
return_dict=False,
@@ -959,8 +982,6 @@ def __call__(
# learned sigma
if self.transformer.config.out_channels // 2 == latent_channels:
noise_pred = noise_pred.chunk(2, dim=1)[0]
- else:
- noise_pred = noise_pred
# compute previous image: x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
@@ -986,9 +1007,15 @@ def __call__(
image = latents
else:
latents = latents.to(self.vae.dtype)
+ torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
+ oom_error = (
+ torch.OutOfMemoryError
+ if is_torch_version(">=", "2.5.0")
+ else torch_accelerator_module.OutOfMemoryError
+ )
try:
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
- except torch.cuda.OutOfMemoryError as e:
+ except oom_error as e:
warnings.warn(
f"{e}. \n"
f"Try to use VAE tiling for large images. For example: \n"
diff --git a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py
index 8a23486d6f80..55ed7b84ebdf 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py
@@ -48,6 +48,7 @@
from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
from .pipeline_output import SanaPipelineOutput
+
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
@@ -166,13 +167,9 @@ def retrieve_timesteps(
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
- raise ValueError(
- "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
- )
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
- accepts_timesteps = "timesteps" in set(
- inspect.signature(scheduler.set_timesteps).parameters.keys()
- )
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
@@ -182,9 +179,7 @@ def retrieve_timesteps(
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
- accept_sigmas = "sigmas" in set(
- inspect.signature(scheduler.set_timesteps).parameters.keys()
- )
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
@@ -209,12 +204,7 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
# fmt: on
model_cpu_offload_seq = "text_encoder->controlnet->transformer->vae"
- _callback_tensor_inputs = [
- "latents",
- "control_image",
- "prompt_embeds",
- "negative_prompt_embeds",
- ]
+ _callback_tensor_inputs = ["latents", "control_image", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
@@ -241,9 +231,7 @@ def __init__(
if hasattr(self, "vae") and self.vae is not None
else 32
)
- self.image_processor = PixArtImageProcessor(
- vae_scale_factor=self.vae_scale_factor
- )
+ self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
def enable_vae_slicing(self):
r"""
@@ -352,9 +340,7 @@ def _get_gemma_prompt_embeds(
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.to(device)
- prompt_embeds = self.text_encoder(
- text_input_ids.to(device), attention_mask=prompt_attention_mask
- )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
return prompt_embeds, prompt_attention_mask
@@ -452,51 +438,33 @@ def encode_prompt(
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
- prompt_embeds = prompt_embeds.view(
- bs_embed * num_images_per_prompt, seq_len, -1
- )
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
- negative_prompt = (
- [negative_prompt] * batch_size
- if isinstance(negative_prompt, str)
- else negative_prompt
- )
- negative_prompt_embeds, negative_prompt_attention_mask = (
- self._get_gemma_prompt_embeds(
- prompt=negative_prompt,
- device=device,
- dtype=dtype,
- clean_caption=clean_caption,
- max_sequence_length=max_sequence_length,
- complex_human_instruction=False,
- )
+ negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=negative_prompt,
+ device=device,
+ dtype=dtype,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=False,
)
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
- negative_prompt_embeds = negative_prompt_embeds.to(
- dtype=dtype, device=device
- )
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
- negative_prompt_embeds = negative_prompt_embeds.repeat(
- 1, num_images_per_prompt, 1
- )
- negative_prompt_embeds = negative_prompt_embeds.view(
- batch_size * num_images_per_prompt, seq_len, -1
- )
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
- negative_prompt_attention_mask = negative_prompt_attention_mask.view(
- bs_embed, -1
- )
- negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(
- num_images_per_prompt, 1
- )
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
else:
negative_prompt_embeds = None
negative_prompt_attention_mask = None
@@ -506,12 +474,7 @@ def encode_prompt(
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
- return (
- prompt_embeds,
- prompt_attention_mask,
- negative_prompt_embeds,
- negative_prompt_attention_mask,
- )
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
@@ -520,17 +483,13 @@ def prepare_extra_step_kwargs(self, generator, eta):
# eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
- accepts_eta = "eta" in set(
- inspect.signature(self.scheduler.step).parameters.keys()
- )
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
- accepts_generator = "generator" in set(
- inspect.signature(self.scheduler.step).parameters.keys()
- )
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
@@ -548,13 +507,10 @@ def check_inputs(
negative_prompt_attention_mask=None,
):
if height % 32 != 0 or width % 32 != 0:
- raise ValueError(
- f"`height` and `width` have to be divisible by 32 but are {height} and {width}."
- )
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
- k in self._callback_tensor_inputs
- for k in callback_on_step_end_tensor_inputs
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
@@ -569,12 +525,8 @@ def check_inputs(
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
- elif prompt is not None and (
- not isinstance(prompt, str) and not isinstance(prompt, list)
- ):
- raise ValueError(
- f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
- )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
@@ -589,17 +541,10 @@ def check_inputs(
)
if prompt_embeds is not None and prompt_attention_mask is None:
- raise ValueError(
- "Must provide `prompt_attention_mask` when specifying `prompt_embeds`."
- )
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
- if (
- negative_prompt_embeds is not None
- and negative_prompt_attention_mask is None
- ):
- raise ValueError(
- "Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`."
- )
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
@@ -618,16 +563,12 @@ def check_inputs(
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
- logger.warning(
- BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")
- )
+ logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if clean_caption and not is_ftfy_available():
- logger.warning(
- BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")
- )
+ logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
@@ -715,17 +656,13 @@ def _clean_caption(self, caption):
# "123456.."
caption = re.sub(r"\b\d{6,}\b", "", caption)
# filenames:
- caption = re.sub(
- r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption
- )
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
#
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
- caption = re.sub(
- self.bad_punct_regex, r" ", caption
- ) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
# this-is-my-cute-cat / this_is_my_cute_cat
@@ -743,14 +680,10 @@ def _clean_caption(self, caption):
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
- caption = re.sub(
- r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption
- )
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
- caption = re.sub(
- r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption
- ) # j2d1a2a...
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
@@ -801,17 +734,7 @@ def prepare_image(
return image
- def prepare_latents(
- self,
- batch_size,
- num_channels_latents,
- height,
- width,
- dtype,
- device,
- generator,
- latents=None,
- ):
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
if latents is not None:
return latents.to(device=device, dtype=dtype)
@@ -1012,9 +935,7 @@ def __call__(
else:
raise ValueError("Invalid sample size")
orig_height, orig_width = height, width
- height, width = self.image_processor.classify_height_width_bin(
- height, width, ratios=aspect_ratio_bin
- )
+ height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
self.check_inputs(
prompt,
@@ -1041,11 +962,7 @@ def __call__(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
- lora_scale = (
- self.attention_kwargs.get("scale", None)
- if self.attention_kwargs is not None
- else None
- )
+ lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
# 3. Encode input prompt
(
@@ -1070,9 +987,7 @@ def __call__(
)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
- prompt_attention_mask = torch.cat(
- [negative_prompt_attention_mask, prompt_attention_mask], dim=0
- )
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare control image
if isinstance(self.controlnet, SanaControlNetModel):
@@ -1116,9 +1031,7 @@ def __call__(
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 8. Denoising loop
- num_warmup_steps = max(
- len(timesteps) - num_inference_steps * self.scheduler.order, 0
- )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
controlnet_dtype = self.controlnet.dtype
@@ -1128,11 +1041,7 @@ def __call__(
if self.interrupt:
continue
- latent_model_input = (
- torch.cat([latents] * 2)
- if self.do_classifier_free_guidance
- else latents
- )
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
@@ -1157,27 +1066,21 @@ def __call__(
timestep=timestep,
return_dict=False,
attention_kwargs=self.attention_kwargs,
- controlnet_block_samples=tuple(
- t.to(dtype=transformer_dtype) for t in controlnet_block_samples
- ),
+ controlnet_block_samples=tuple(t.to(dtype=transformer_dtype) for t in controlnet_block_samples),
)[0]
noise_pred = noise_pred.float()
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + guidance_scale * (
- noise_pred_text - noise_pred_uncond
- )
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# learned sigma
if self.transformer.config.out_channels // 2 == latent_channels:
noise_pred = noise_pred.chunk(2, dim=1)[0]
# compute previous image: x_t -> x_t-1
- latents = self.scheduler.step(
- noise_pred, t, latents, **extra_step_kwargs, return_dict=False
- )[0]
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
@@ -1187,14 +1090,10 @@ def __call__(
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop(
- "negative_prompt_embeds", negative_prompt_embeds
- )
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
- if i == len(timesteps) - 1 or (
- (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
- ):
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
@@ -1211,9 +1110,7 @@ def __call__(
else torch_accelerator_module.OutOfMemoryError
)
try:
- image = self.vae.decode(
- latents / self.vae.config.scaling_factor, return_dict=False
- )[0]
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except oom_error as e:
warnings.warn(
f"{e}. \n"
@@ -1221,9 +1118,7 @@ def __call__(
f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
)
if use_resolution_binning:
- image = self.image_processor.resize_and_crop_tensor(
- image, orig_width, orig_height
- )
+ image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
image = self.image_processor.postprocess(image, output_type=output_type)
diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
index 9b3acbb1cb22..04f45f817efb 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
@@ -1,4 +1,4 @@
-# Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 SANA-Sprint Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -30,6 +30,7 @@
from ...utils import (
BACKENDS_MAPPING,
USE_PEFT_BACKEND,
+ deprecate,
is_bs4_available,
is_ftfy_available,
is_torch_xla_available,
@@ -38,7 +39,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import randn_tensor
+from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN
from .pipeline_output import SanaPipelineOutput
@@ -175,6 +176,12 @@ def enable_vae_slicing(self):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -182,6 +189,12 @@ def disable_vae_slicing(self):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -190,6 +203,12 @@ def enable_vae_tiling(self):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -197,6 +216,12 @@ def disable_vae_tiling(self):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
@@ -295,9 +320,7 @@ def encode_prompt(
if device is None:
device = self._execution_device
- if self.transformer is not None:
- dtype = self.transformer.dtype
- elif self.text_encoder is not None:
+ if self.text_encoder is not None:
dtype = self.text_encoder.dtype
else:
dtype = None
@@ -349,7 +372,7 @@ def encode_prompt(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -493,7 +516,7 @@ def _clean_caption(self, caption):
# &
caption = re.sub(r"&", "", caption)
- # ip adresses:
+ # ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -645,11 +668,11 @@ def __call__(
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 4.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
+ a model to generate images more aligned with `prompt` at the expense of lower image quality.
+
+ Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
@@ -657,15 +680,15 @@ def __call__(
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -806,13 +829,14 @@ def __call__(
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
+ transformer_dtype = self.transformer.dtype
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
- timestep = t.expand(latents.shape[0]).to(prompt_embeds.dtype)
+ timestep = t.expand(latents.shape[0])
latents_model_input = latents / self.scheduler.config.sigma_data
scm_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep))
@@ -821,12 +845,11 @@ def __call__(
latent_model_input = latents_model_input * torch.sqrt(
scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2
)
- latent_model_input = latent_model_input.to(prompt_embeds.dtype)
# predict noise model_output
noise_pred = self.transformer(
- latent_model_input,
- encoder_hidden_states=prompt_embeds,
+ latent_model_input.to(dtype=transformer_dtype),
+ encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
encoder_attention_mask=prompt_attention_mask,
guidance=guidance,
timestep=scm_timestep,
@@ -866,9 +889,15 @@ def __call__(
image = latents
else:
latents = latents.to(self.vae.dtype)
+ torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
+ oom_error = (
+ torch.OutOfMemoryError
+ if is_torch_version(">=", "2.5.0")
+ else torch_accelerator_module.OutOfMemoryError
+ )
try:
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
- except torch.cuda.OutOfMemoryError as e:
+ except oom_error as e:
warnings.warn(
f"{e}. \n"
f"Try to use VAE tiling for large images. For example: \n"
diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py
new file mode 100644
index 000000000000..8899ed84c4e5
--- /dev/null
+++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py
@@ -0,0 +1,1006 @@
+# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+import re
+import urllib.parse as ul
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput, PixArtImageProcessor
+from ...loaders import SanaLoraLoaderMixin
+from ...models import AutoencoderDC, SanaTransformer2DModel
+from ...schedulers import DPMSolverMultistepScheduler
+from ...utils import (
+ BACKENDS_MAPPING,
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_bs4_available,
+ is_ftfy_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN
+from .pipeline_output import SanaPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import SanaSprintImg2ImgPipeline
+ >>> from diffusers.utils.loading_utils import load_image
+
+ >>> pipe = SanaSprintImg2ImgPipeline.from_pretrained(
+ ... "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
+ ... )
+
+
+ >>> image = pipe(prompt="a cute pink bear", image=image, strength=0.5, height=832, width=480).images[0]
+ >>> image[0].save("output.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class SanaSprintImg2ImgPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
+ r"""
+ Pipeline for text-to-image generation using [SANA-Sprint](https://huggingface.co/papers/2503.09641).
+ """
+
+ # fmt: off
+ bad_punct_regex = re.compile(
+ r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}")
+ # fmt: on
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+ text_encoder: Gemma2PreTrainedModel,
+ vae: AutoencoderDC,
+ transformer: SanaTransformer2DModel,
+ scheduler: DPMSolverMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.encoder_block_out_channels) - 1)
+ if hasattr(self, "vae") and self.vae is not None
+ else 32
+ )
+ self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
+ def _get_gemma_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ device: torch.device,
+ dtype: torch.dtype,
+ clean_caption: bool = False,
+ max_sequence_length: int = 300,
+ complex_human_instruction: Optional[List[str]] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+ complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+ If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
+ the prompt.
+ """
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if getattr(self, "tokenizer", None) is not None:
+ self.tokenizer.padding_side = "right"
+
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+
+ # prepare complex human instruction
+ if not complex_human_instruction:
+ max_length_all = max_sequence_length
+ else:
+ chi_prompt = "\n".join(complex_human_instruction)
+ prompt = [chi_prompt + p for p in prompt]
+ num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
+ max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length_all,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.to(device)
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
+ prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
+
+ return prompt_embeds, prompt_attention_mask
+
+ # Copied from diffusers.pipelines.sana.pipeline_sana_sprint.SanaSprintPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ clean_caption: bool = False,
+ max_sequence_length: int = 300,
+ complex_human_instruction: Optional[List[str]] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+ complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+ If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
+ the prompt.
+ """
+
+ if device is None:
+ device = self._execution_device
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ else:
+ dtype = None
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if getattr(self, "tokenizer", None) is not None:
+ self.tokenizer.padding_side = "right"
+
+ # See Section 3.1. of the paper.
+ max_length = max_sequence_length
+ select_index = [0] + list(range(-max_length + 1, 0))
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ dtype=dtype,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=complex_human_instruction,
+ )
+
+ prompt_embeds = prompt_embeds[:, select_index]
+ prompt_attention_mask = prompt_attention_mask[:, select_index]
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+ if self.text_encoder is not None:
+ if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, prompt_attention_mask
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ def check_inputs(
+ self,
+ prompt,
+ strength,
+ height,
+ width,
+ num_inference_steps,
+ timesteps,
+ max_timesteps,
+ intermediate_timesteps,
+ callback_on_step_end_tensor_inputs=None,
+ prompt_embeds=None,
+ prompt_attention_mask=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if timesteps is not None and len(timesteps) != num_inference_steps + 1:
+ raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.")
+
+ if timesteps is not None and max_timesteps is not None:
+ raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.")
+
+ if timesteps is None and max_timesteps is None:
+ raise ValueError("Should provide either `timesteps` or `max_timesteps`.")
+
+ if intermediate_timesteps is not None and num_inference_steps != 2:
+ raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.")
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip addresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ def prepare_image(
+ self,
+ image: PipelineImageInput,
+ width: int,
+ height: int,
+ device: torch.device,
+ dtype: torch.dtype,
+ ):
+ if isinstance(image, torch.Tensor):
+ if image.ndim == 3:
+ image = image.unsqueeze(0)
+ # Resize if current dimensions do not match target dimensions.
+ if image.shape[2] != height or image.shape[3] != width:
+ image = F.interpolate(image, size=(height, width), mode="bilinear", align_corners=False)
+
+ image = self.image_processor.preprocess(image, height=height, width=width)
+
+ else:
+ image = self.image_processor.preprocess(image, height=height, width=width)
+
+ image = image.to(device=device, dtype=dtype)
+
+ return image
+
+ def prepare_latents(
+ self, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None
+ ):
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+
+ if image.shape[1] != num_channels_latents:
+ image = self.vae.encode(image).latent
+ image_latents = image * self.vae.config.scaling_factor * self.scheduler.config.sigma_data
+ else:
+ image_latents = image
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ # adapt from https://github.com/huggingface/diffusers/blob/c36f8487df35895421c15f351c7d360bd680[…]/examples/research_projects/sana/train_sana_sprint_diffusers.py
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) * self.scheduler.config.sigma_data
+ latents = torch.cos(timestep) * image_latents + torch.sin(timestep) * noise
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_inference_steps: int = 2,
+ timesteps: List[int] = None,
+ max_timesteps: float = 1.57080,
+ intermediate_timesteps: float = 1.3,
+ guidance_scale: float = 4.5,
+ image: PipelineImageInput = None,
+ strength: float = 0.6,
+ num_images_per_prompt: Optional[int] = 1,
+ height: int = 1024,
+ width: int = 1024,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ clean_caption: bool = False,
+ use_resolution_binning: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 300,
+ complex_human_instruction: List[str] = [
+ "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
+ "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
+ "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
+ "Here are examples of how to transform or refine prompts:",
+ "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
+ "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
+ "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
+ "User Prompt: ",
+ ],
+ ) -> Union[SanaPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ num_inference_steps (`int`, *optional*, defaults to 20):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ max_timesteps (`float`, *optional*, defaults to 1.57080):
+ The maximum timestep value used in the SCM scheduler.
+ intermediate_timesteps (`float`, *optional*, defaults to 1.3):
+ The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2).
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 4.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The width in pixels of the generated image.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
+ attention_kwargs:
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ use_resolution_binning (`bool` defaults to `True`):
+ If set to `True`, the requested height and width are first mapped to the closest resolutions using
+ `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
+ the requested resolution. Useful for generating non-square images.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to `300`):
+ Maximum sequence length to use with the `prompt`.
+ complex_human_instruction (`List[str]`, *optional*):
+ Instructions for complex human attention:
+ https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ if use_resolution_binning:
+ if self.transformer.config.sample_size == 32:
+ aspect_ratio_bin = ASPECT_RATIO_1024_BIN
+ else:
+ raise ValueError("Invalid sample size")
+ orig_height, orig_width = height, width
+ height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+
+ self.check_inputs(
+ prompt=prompt,
+ strength=strength,
+ height=height,
+ width=width,
+ num_inference_steps=num_inference_steps,
+ timesteps=timesteps,
+ max_timesteps=max_timesteps,
+ intermediate_timesteps=intermediate_timesteps,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ prompt_embeds=prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ # 2. Default height and width to transformer
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
+
+ # 2. Preprocess image
+ init_image = self.prepare_image(image, width, height, device, self.vae.dtype)
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=complex_human_instruction,
+ lora_scale=lora_scale,
+ )
+
+ # 5. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas=None,
+ max_timesteps=max_timesteps,
+ intermediate_timesteps=intermediate_timesteps,
+ )
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(0)
+
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ latent_timestep = timesteps[:1]
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ init_image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ latent_channels,
+ height,
+ width,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # I think this is redundant given the scaling in prepare_latents
+ # latents = latents * self.scheduler.config.sigma_data
+
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype)
+ guidance = guidance * self.transformer.config.guidance_embeds_scale
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ timesteps = timesteps[:-1]
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ transformer_dtype = self.transformer.dtype
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0])
+ latents_model_input = latents / self.scheduler.config.sigma_data
+
+ scm_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep))
+
+ scm_timestep_expanded = scm_timestep.view(-1, 1, 1, 1)
+ latent_model_input = latents_model_input * torch.sqrt(
+ scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2
+ )
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ latent_model_input.to(dtype=transformer_dtype),
+ encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
+ encoder_attention_mask=prompt_attention_mask,
+ guidance=guidance,
+ timestep=scm_timestep,
+ return_dict=False,
+ attention_kwargs=self.attention_kwargs,
+ )[0]
+
+ noise_pred = (
+ (1 - 2 * scm_timestep_expanded) * latent_model_input
+ + (1 - 2 * scm_timestep_expanded + 2 * scm_timestep_expanded**2) * noise_pred
+ ) / torch.sqrt(scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2)
+ noise_pred = noise_pred.float() * self.scheduler.config.sigma_data
+
+ # compute previous image: x_t -> x_t-1
+ latents, denoised = self.scheduler.step(
+ noise_pred, timestep, latents, **extra_step_kwargs, return_dict=False
+ )
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ latents = denoised / self.scheduler.config.sigma_data
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = latents.to(self.vae.dtype)
+ torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
+ oom_error = (
+ torch.OutOfMemoryError
+ if is_torch_version(">=", "2.5.0")
+ else torch_accelerator_module.OutOfMemoryError
+ )
+ try:
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ except oom_error as e:
+ warnings.warn(
+ f"{e}. \n"
+ f"Try to use VAE tiling for large images. For example: \n"
+ f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
+ )
+ if use_resolution_binning:
+ image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
+
+ if not output_type == "latent":
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return SanaPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/sana_video/__init__.py b/src/diffusers/pipelines/sana_video/__init__.py
new file mode 100644
index 000000000000..73e224bf749d
--- /dev/null
+++ b/src/diffusers/pipelines/sana_video/__init__.py
@@ -0,0 +1,49 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_sana_video"] = ["SanaVideoPipeline"]
+ _import_structure["pipeline_sana_video_i2v"] = ["SanaImageToVideoPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_sana_video import SanaVideoPipeline
+ from .pipeline_sana_video_i2v import SanaImageToVideoPipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/sana_video/pipeline_output.py b/src/diffusers/pipelines/sana_video/pipeline_output.py
new file mode 100644
index 000000000000..4d37923889eb
--- /dev/null
+++ b/src/diffusers/pipelines/sana_video/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class SanaVideoPipelineOutput(BaseOutput):
+ r"""
+ Output class for Sana-Video pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/src/diffusers/pipelines/sana_video/pipeline_sana_video.py b/src/diffusers/pipelines/sana_video/pipeline_sana_video.py
new file mode 100644
index 000000000000..a786275e45a9
--- /dev/null
+++ b/src/diffusers/pipelines/sana_video/pipeline_sana_video.py
@@ -0,0 +1,1017 @@
+# Copyright 2025 SANA-Video Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+import re
+import urllib.parse as ul
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import SanaLoraLoaderMixin
+from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
+from ...schedulers import DPMSolverMultistepScheduler
+from ...utils import (
+ BACKENDS_MAPPING,
+ USE_PEFT_BACKEND,
+ is_bs4_available,
+ is_ftfy_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import SanaVideoPipelineOutput
+
+
+ASPECT_RATIO_480_BIN = {
+ "0.5": [448.0, 896.0],
+ "0.57": [480.0, 832.0],
+ "0.68": [528.0, 768.0],
+ "0.78": [560.0, 720.0],
+ "1.0": [624.0, 624.0],
+ "1.13": [672.0, 592.0],
+ "1.29": [720.0, 560.0],
+ "1.46": [768.0, 528.0],
+ "1.67": [816.0, 496.0],
+ "1.75": [832.0, 480.0],
+ "2.0": [896.0, 448.0],
+}
+
+
+ASPECT_RATIO_720_BIN = {
+ "0.5": [672.0, 1344.0],
+ "0.57": [704.0, 1280.0],
+ "0.68": [800.0, 1152.0],
+ "0.78": [832.0, 1088.0],
+ "1.0": [960.0, 960.0],
+ "1.13": [1024.0, 896.0],
+ "1.29": [1088.0, 832.0],
+ "1.46": [1152.0, 800.0],
+ "1.67": [1248.0, 736.0],
+ "1.75": [1280.0, 704.0],
+ "2.0": [1344.0, 672.0],
+}
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import SanaVideoPipeline
+ >>> from diffusers.utils import export_to_video
+
+ >>> pipe = SanaVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers")
+ >>> pipe.transformer.to(torch.bfloat16)
+ >>> pipe.text_encoder.to(torch.bfloat16)
+ >>> pipe.vae.to(torch.float32)
+ >>> pipe.to("cuda")
+ >>> motion_score = 30
+
+ >>> prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional."
+ >>> negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
+ >>> motion_prompt = f" motion score: {motion_score}."
+ >>> prompt = prompt + motion_prompt
+
+ >>> output = pipe(
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... height=480,
+ ... width=832,
+ ... frames=81,
+ ... guidance_scale=6,
+ ... num_inference_steps=50,
+ ... generator=torch.Generator(device="cuda").manual_seed(42),
+ ... ).frames[0]
+
+ >>> export_to_video(output, "sana-video-output.mp4", fps=16)
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
+ r"""
+ Pipeline for text-to-video generation using [Sana](https://huggingface.co/papers/2509.24695). This model inherits
+ from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all
+ pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`GemmaTokenizer`] or [`GemmaTokenizerFast`]):
+ The tokenizer used to tokenize the prompt.
+ text_encoder ([`Gemma2PreTrainedModel`]):
+ Text encoder model to encode the input prompts.
+ vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ transformer ([`SanaVideoTransformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`DPMSolverMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+ """
+
+ # fmt: off
+ bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}")
+ # fmt: on
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+ text_encoder: Gemma2PreTrainedModel,
+ vae: Union[AutoencoderDC, AutoencoderKLWan],
+ transformer: SanaVideoTransformer3DModel,
+ scheduler: DPMSolverMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+
+ self.vae_scale_factor = self.vae_scale_factor_spatial
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
+ def _get_gemma_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ device: torch.device,
+ dtype: torch.dtype,
+ clean_caption: bool = False,
+ max_sequence_length: int = 300,
+ complex_human_instruction: Optional[List[str]] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+ complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+ If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
+ the prompt.
+ """
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if getattr(self, "tokenizer", None) is not None:
+ self.tokenizer.padding_side = "right"
+
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+
+ # prepare complex human instruction
+ if not complex_human_instruction:
+ max_length_all = max_sequence_length
+ else:
+ chi_prompt = "\n".join(complex_human_instruction)
+ prompt = [chi_prompt + p for p in prompt]
+ num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
+ max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length_all,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.to(device)
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
+ prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
+
+ return prompt_embeds, prompt_attention_mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: str = "",
+ num_videos_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ clean_caption: bool = False,
+ max_sequence_length: int = 300,
+ complex_human_instruction: Optional[List[str]] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+ PixArt-Alpha, this should be "".
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ number of videos that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string.
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+ complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+ If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
+ the prompt.
+ """
+
+ if device is None:
+ device = self._execution_device
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ else:
+ dtype = None
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if getattr(self, "tokenizer", None) is not None:
+ self.tokenizer.padding_side = "right"
+
+ # See Section 3.1. of the paper.
+ max_length = max_sequence_length
+ select_index = [0] + list(range(-max_length + 1, 0))
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ dtype=dtype,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=complex_human_instruction,
+ )
+
+ prompt_embeds = prompt_embeds[:, select_index]
+ prompt_attention_mask = prompt_attention_mask[:, select_index]
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=negative_prompt,
+ device=device,
+ dtype=dtype,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=False,
+ )
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+ else:
+ negative_prompt_embeds = None
+ negative_prompt_attention_mask = None
+
+ if self.text_encoder is not None:
+ if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs=None,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip addresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: str = "",
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ sigmas: List[float] = None,
+ guidance_scale: float = 6.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ height: int = 480,
+ width: int = 832,
+ frames: int = 81,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ clean_caption: bool = False,
+ use_resolution_binning: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 300,
+ complex_human_instruction: List[str] = [
+ "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for video generation. Evaluate the level of detail in the user prompt:",
+ "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, motion, and temporal relationships to create vivid and dynamic scenes.",
+ "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
+ "Here are examples of how to transform or refine prompts:",
+ "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat slowly settling into a curled position, peacefully falling asleep on a warm sunny windowsill, with gentle sunlight filtering through surrounding pots of blooming red flowers.",
+ "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps gradually lighting up, a diverse crowd of people in colorful clothing walking past, and a double-decker bus smoothly passing by towering glass skyscrapers.",
+ "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
+ "User Prompt: ",
+ ],
+ ) -> Union[SanaVideoPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 4.5):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to
+ the text `prompt`, usually at the expense of lower video quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ height (`int`, *optional*, defaults to 480):
+ The height in pixels of the generated video.
+ width (`int`, *optional*, defaults to 832):
+ The width in pixels of the generated video.
+ frames (`int`, *optional*, defaults to 81):
+ The number of frames in the generated video.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated video. Choose between mp4 or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`SanaVideoPipelineOutput`] instead of a plain tuple.
+ attention_kwargs:
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ use_resolution_binning (`bool` defaults to `True`):
+ If set to `True`, the requested height and width are first mapped to the closest resolutions using
+ `ASPECT_RATIO_480_BIN` or `ASPECT_RATIO_720_BIN`. After the produced latents are decoded into videos,
+ they are resized back to the requested resolution. Useful for generating non-square videos.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to `300`):
+ Maximum sequence length to use with the `prompt`.
+ complex_human_instruction (`List[str]`, *optional*):
+ Instructions for complex human attention:
+ https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] is
+ returned, otherwise a `tuple` is returned where the first element is a list with the generated videos
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ if use_resolution_binning:
+ if self.transformer.config.sample_size == 30:
+ aspect_ratio_bin = ASPECT_RATIO_480_BIN
+ elif self.transformer.config.sample_size == 22:
+ aspect_ratio_bin = ASPECT_RATIO_720_BIN
+ else:
+ raise ValueError("Invalid sample size")
+ orig_height, orig_width = height, width
+ height, width = self.video_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ # 2. Default height and width to transformer
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=complex_human_instruction,
+ lora_scale=lora_scale,
+ )
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
+ )
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ latent_channels,
+ height,
+ width,
+ frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ transformer_dtype = self.transformer.dtype
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ latent_model_input.to(dtype=transformer_dtype),
+ encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
+ encoder_attention_mask=prompt_attention_mask,
+ timestep=timestep,
+ return_dict=False,
+ attention_kwargs=self.attention_kwargs,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # learned sigma
+ if self.transformer.config.out_channels // 2 == latent_channels:
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
+
+ # compute previous image: x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ video = latents
+ else:
+ latents = latents.to(self.vae.dtype)
+ torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
+ oom_error = (
+ torch.OutOfMemoryError
+ if is_torch_version(">=", "2.5.0")
+ else torch_accelerator_module.OutOfMemoryError
+ )
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ try:
+ video = self.vae.decode(latents, return_dict=False)[0]
+ except oom_error as e:
+ warnings.warn(
+ f"{e}. \n"
+ f"Try to use VAE tiling for large images. For example: \n"
+ f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
+ )
+
+ if use_resolution_binning:
+ video = self.video_processor.resize_and_crop_tensor(video, orig_width, orig_height)
+
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return SanaVideoPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py b/src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py
new file mode 100644
index 000000000000..e87880b64cee
--- /dev/null
+++ b/src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py
@@ -0,0 +1,1066 @@
+# Copyright 2025 SANA-Video Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+import re
+import urllib.parse as ul
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import PIL
+import torch
+from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import SanaLoraLoaderMixin
+from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ BACKENDS_MAPPING,
+ USE_PEFT_BACKEND,
+ is_bs4_available,
+ is_ftfy_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import SanaVideoPipelineOutput
+from .pipeline_sana_video import ASPECT_RATIO_480_BIN, ASPECT_RATIO_720_BIN
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import SanaImageToVideoPipeline
+ >>> from diffusers.utils import export_to_video, load_image
+
+ >>> pipe = SanaImageToVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers")
+ >>> pipe.transformer.to(torch.bfloat16)
+ >>> pipe.text_encoder.to(torch.bfloat16)
+ >>> pipe.vae.to(torch.float32)
+ >>> pipe.to("cuda")
+ >>> motion_score = 30
+
+ >>> prompt = "A woman stands against a stunning sunset backdrop, her long, wavy brown hair gently blowing in the breeze. She wears a sleeveless, light-colored blouse with a deep V-neckline, which accentuates her graceful posture. The warm hues of the setting sun cast a golden glow across her face and hair, creating a serene and ethereal atmosphere. The background features a blurred landscape with soft, rolling hills and scattered clouds, adding depth to the scene. The camera remains steady, capturing the tranquil moment from a medium close-up angle."
+ >>> negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
+ >>> motion_prompt = f" motion score: {motion_score}."
+ >>> prompt = prompt + motion_prompt
+ >>> image = load_image("https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/samples/i2v-1.png")
+
+ >>> output = pipe(
+ ... image=image,
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... height=480,
+ ... width=832,
+ ... frames=81,
+ ... guidance_scale=6,
+ ... num_inference_steps=50,
+ ... generator=torch.Generator(device="cuda").manual_seed(42),
+ ... ).frames[0]
+
+ >>> export_to_video(output, "sana-ti2v-output.mp4", fps=16)
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
+ r"""
+ Pipeline for image/text-to-video generation using [Sana](https://huggingface.co/papers/2509.24695). This model
+ inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all
+ pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`GemmaTokenizer`] or [`GemmaTokenizerFast`]):
+ The tokenizer used to tokenize the prompt.
+ text_encoder ([`Gemma2PreTrainedModel`]):
+ Text encoder model to encode the input prompts.
+ vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ transformer ([`SanaVideoTransformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+ """
+
+ # fmt: off
+ bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}")
+ # fmt: on
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+ text_encoder: Gemma2PreTrainedModel,
+ vae: Union[AutoencoderDC, AutoencoderKLWan],
+ transformer: SanaVideoTransformer3DModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+
+ self.vae_scale_factor = self.vae_scale_factor_spatial
+
+ self.transformer_spatial_patch_size = (
+ self.transformer.config.patch_size[1] if getattr(self, "transformer", None) is not None else 1
+ )
+ self.transformer_temporal_patch_size = (
+ self.transformer.config.patch_size[0] if getattr(self, "transformer") is not None else 1
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
+ def _get_gemma_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ device: torch.device,
+ dtype: torch.dtype,
+ clean_caption: bool = False,
+ max_sequence_length: int = 300,
+ complex_human_instruction: Optional[List[str]] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+ complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+ If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
+ the prompt.
+ """
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if getattr(self, "tokenizer", None) is not None:
+ self.tokenizer.padding_side = "right"
+
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+
+ # prepare complex human instruction
+ if not complex_human_instruction:
+ max_length_all = max_sequence_length
+ else:
+ chi_prompt = "\n".join(complex_human_instruction)
+ prompt = [chi_prompt + p for p in prompt]
+ num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
+ max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length_all,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.to(device)
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
+ prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
+
+ return prompt_embeds, prompt_attention_mask
+
+ # Copied from diffusers.pipelines.sana_video.pipeline_sana_video.SanaVideoPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: str = "",
+ num_videos_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ clean_caption: bool = False,
+ max_sequence_length: int = 300,
+ complex_human_instruction: Optional[List[str]] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+ PixArt-Alpha, this should be "".
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ number of videos that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string.
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+ complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+ If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
+ the prompt.
+ """
+
+ if device is None:
+ device = self._execution_device
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ else:
+ dtype = None
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if getattr(self, "tokenizer", None) is not None:
+ self.tokenizer.padding_side = "right"
+
+ # See Section 3.1. of the paper.
+ max_length = max_sequence_length
+ select_index = [0] + list(range(-max_length + 1, 0))
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ dtype=dtype,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=complex_human_instruction,
+ )
+
+ prompt_embeds = prompt_embeds[:, select_index]
+ prompt_attention_mask = prompt_attention_mask[:, select_index]
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=negative_prompt,
+ device=device,
+ dtype=dtype,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=False,
+ )
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+ else:
+ negative_prompt_embeds = None
+ negative_prompt_attention_mask = None
+
+ if self.text_encoder is not None:
+ if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs=None,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+
+ if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
+ raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip addresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ def prepare_latents(
+ self,
+ image: PipelineImageInput,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ image = image.unsqueeze(2) # [B, C, 1, H, W]
+ image = image.to(device=device, dtype=self.vae.dtype)
+
+ if isinstance(generator, list):
+ image_latents = [retrieve_latents(self.vae.encode(image), sample_mode="argmax") for _ in generator]
+ image_latents = torch.cat(image_latents)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
+ image_latents = image_latents.repeat(batch_size, 1, 1, 1, 1)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, -1, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(
+ image_latents.device, image_latents.dtype
+ )
+ image_latents = (image_latents - latents_mean) * latents_std
+
+ latents[:, :, 0:1] = image_latents.to(dtype)
+
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: str = "",
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ sigmas: List[float] = None,
+ guidance_scale: float = 6.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ height: int = 480,
+ width: int = 832,
+ frames: int = 81,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ clean_caption: bool = False,
+ use_resolution_binning: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 300,
+ complex_human_instruction: List[str] = [
+ "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for video generation. Evaluate the level of detail in the user prompt:",
+ "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, motion, and temporal relationships to create vivid and dynamic scenes.",
+ "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
+ "Here are examples of how to transform or refine prompts:",
+ "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat slowly settling into a curled position, peacefully falling asleep on a warm sunny windowsill, with gentle sunlight filtering through surrounding pots of blooming red flowers.",
+ "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps gradually lighting up, a diverse crowd of people in colorful clothing walking past, and a double-decker bus smoothly passing by towering glass skyscrapers.",
+ "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
+ "User Prompt: ",
+ ],
+ ) -> Union[SanaVideoPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to condition the video generation on. The first frame of the generated video will be
+ conditioned on this image.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 4.5):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to
+ the text `prompt`, usually at the expense of lower video quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ height (`int`, *optional*, defaults to 480):
+ The height in pixels of the generated video.
+ width (`int`, *optional*, defaults to 832):
+ The width in pixels of the generated video.
+ frames (`int`, *optional*, defaults to 81):
+ The number of frames in the generated video.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated video. Choose between mp4 or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`SanaVideoPipelineOutput`] instead of a plain tuple.
+ attention_kwargs:
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ use_resolution_binning (`bool` defaults to `True`):
+ If set to `True`, the requested height and width are first mapped to the closest resolutions using
+ `ASPECT_RATIO_480_BIN` or `ASPECT_RATIO_720_BIN`. After the produced latents are decoded into videos,
+ they are resized back to the requested resolution. Useful for generating non-square videos.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to `300`):
+ Maximum sequence length to use with the `prompt`.
+ complex_human_instruction (`List[str]`, *optional*):
+ Instructions for complex human attention:
+ https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] is
+ returned, otherwise a `tuple` is returned where the first element is a list with the generated videos
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ if use_resolution_binning:
+ if self.transformer.config.sample_size == 30:
+ aspect_ratio_bin = ASPECT_RATIO_480_BIN
+ elif self.transformer.config.sample_size == 22:
+ aspect_ratio_bin = ASPECT_RATIO_720_BIN
+ else:
+ raise ValueError("Invalid sample size")
+ orig_height, orig_width = height, width
+ height, width = self.video_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+
+ self.check_inputs(
+ prompt,
+ image,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ # 2. Default height and width to transformer
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=complex_human_instruction,
+ lora_scale=lora_scale,
+ )
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
+ )
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
+
+ latents = self.prepare_latents(
+ image,
+ batch_size * num_videos_per_prompt,
+ latent_channels,
+ height,
+ width,
+ frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ conditioning_mask = latents.new_zeros(
+ batch_size,
+ 1,
+ latents.shape[2] // self.transformer_temporal_patch_size,
+ latents.shape[3] // self.transformer_spatial_patch_size,
+ latents.shape[4] // self.transformer_spatial_patch_size,
+ )
+ conditioning_mask[:, :, 0] = 1.0
+ if self.do_classifier_free_guidance:
+ conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ transformer_dtype = self.transformer.dtype
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(conditioning_mask.shape)
+ timestep = timestep * (1 - conditioning_mask)
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ latent_model_input.to(dtype=transformer_dtype),
+ encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
+ encoder_attention_mask=prompt_attention_mask,
+ timestep=timestep,
+ return_dict=False,
+ attention_kwargs=self.attention_kwargs,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+ timestep, _ = timestep.chunk(2)
+
+ # learned sigma
+ if self.transformer.config.out_channels // 2 == latent_channels:
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
+
+ noise_pred = noise_pred[:, :, 1:]
+ noise_latents = latents[:, :, 1:]
+ pred_latents = self.scheduler.step(
+ noise_pred, t, noise_latents, **extra_step_kwargs, return_dict=False
+ )[0]
+
+ latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ video = latents
+ else:
+ latents = latents.to(self.vae.dtype)
+ torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
+ oom_error = (
+ torch.OutOfMemoryError
+ if is_torch_version(">=", "2.5.0")
+ else torch_accelerator_module.OutOfMemoryError
+ )
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ try:
+ video = self.vae.decode(latents, return_dict=False)[0]
+ except oom_error as e:
+ warnings.warn(
+ f"{e}. \n"
+ f"Try to use VAE tiling for large images. For example: \n"
+ f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
+ )
+
+ if use_resolution_binning:
+ video = self.video_processor.resize_and_crop_tensor(video, orig_width, orig_height)
+
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return SanaVideoPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
index a8c374259349..49b09e205cc5 100644
--- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
+++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
@@ -11,7 +11,7 @@
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import SemanticStableDiffusionPipelineOutput
@@ -25,7 +25,8 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
+class SemanticStableDiffusionPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
+ _last_supported_version = "0.33.1"
r"""
Pipeline for text-to-image generation using Stable Diffusion with latent editing.
@@ -47,8 +48,8 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`Q16SafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -129,7 +130,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -270,8 +271,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -331,7 +332,7 @@ def __call__(
>>> from diffusers import SemanticStableDiffusionPipeline
>>> pipe = SemanticStableDiffusionPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
@@ -451,7 +452,7 @@ def __call__(
edit_concepts = edit_concepts.view(bs_embed_edit * num_images_per_prompt, seq_len_edit, -1)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# get unconditional embeddings for classifier free guidance
diff --git a/src/diffusers/pipelines/shap_e/camera.py b/src/diffusers/pipelines/shap_e/camera.py
index d4b94c3000d8..31e1759d6154 100644
--- a/src/diffusers/pipelines/shap_e/camera.py
+++ b/src/diffusers/pipelines/shap_e/camera.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Open AI and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Open AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py
index ef8a95daefa4..49ddfd1196bf 100644
--- a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py
+++ b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Open AI and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Open AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py b/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py
index c0d1e38e0994..55d8b85822c4 100644
--- a/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py
+++ b/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Open AI and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Open AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/shap_e/renderer.py b/src/diffusers/pipelines/shap_e/renderer.py
index 9d9f9d9b2ab1..d1d05c894595 100644
--- a/src/diffusers/pipelines/shap_e/renderer.py
+++ b/src/diffusers/pipelines/shap_e/renderer.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Open AI and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Open AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -54,7 +54,7 @@ def posenc_nerf(x: torch.Tensor, min_deg: int = 0, max_deg: int = 15) -> torch.T
"""
Concatenate x and its positional encodings, following NeRF.
- Reference: https://arxiv.org/pdf/2210.04628.pdf
+ Reference: https://huggingface.co/papers/2210.04628
"""
if min_deg == max_deg:
return x
@@ -742,7 +742,7 @@ class ShapEParamsProjModel(ModelMixin, ConfigMixin):
def __init__(
self,
*,
- param_names: Tuple[str] = (
+ param_names: Tuple[str, ...] = (
"nerstf.mlp.0.weight",
"nerstf.mlp.1.weight",
"nerstf.mlp.2.weight",
@@ -786,13 +786,13 @@ class ShapERenderer(ModelMixin, ConfigMixin):
def __init__(
self,
*,
- param_names: Tuple[str] = (
+ param_names: Tuple[str, ...] = (
"nerstf.mlp.0.weight",
"nerstf.mlp.1.weight",
"nerstf.mlp.2.weight",
"nerstf.mlp.3.weight",
),
- param_shapes: Tuple[Tuple[int]] = (
+ param_shapes: Tuple[Tuple[int, int], ...] = (
(256, 93),
(256, 256),
(256, 256),
@@ -804,7 +804,7 @@ def __init__(
n_hidden_layers: int = 6,
act_fn: str = "swish",
insert_direction_at: int = 4,
- background: Tuple[float] = (
+ background: Tuple[float, ...] = (
255.0,
255.0,
255.0,
@@ -983,9 +983,9 @@ def decode_to_mesh(
fields = torch.cat(fields, dim=1)
fields = fields.float()
- assert (
- len(fields.shape) == 3 and fields.shape[-1] == 1
- ), f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"
+ assert len(fields.shape) == 3 and fields.shape[-1] == 1, (
+ f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"
+ )
fields = fields.reshape(1, *([grid_size] * 3))
@@ -1038,10 +1038,10 @@ def decode_to_mesh(
textures = _convert_srgb_to_linear(textures)
textures = textures.float()
- # 3.3 augument the mesh with texture data
- assert len(textures.shape) == 3 and textures.shape[-1] == len(
- texture_channels
- ), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
+ # 3.3 augment the mesh with texture data
+ assert len(textures.shape) == 3 and textures.shape[-1] == len(texture_channels), (
+ f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
+ )
for m, texture in zip(raw_meshes, textures):
texture = texture[: len(m.verts)]
diff --git a/src/diffusers/pipelines/skyreels_v2/__init__.py b/src/diffusers/pipelines/skyreels_v2/__init__.py
new file mode 100644
index 000000000000..84d2a2dd3500
--- /dev/null
+++ b/src/diffusers/pipelines/skyreels_v2/__init__.py
@@ -0,0 +1,59 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_skyreels_v2"] = ["SkyReelsV2Pipeline"]
+ _import_structure["pipeline_skyreels_v2_diffusion_forcing"] = ["SkyReelsV2DiffusionForcingPipeline"]
+ _import_structure["pipeline_skyreels_v2_diffusion_forcing_i2v"] = [
+ "SkyReelsV2DiffusionForcingImageToVideoPipeline"
+ ]
+ _import_structure["pipeline_skyreels_v2_diffusion_forcing_v2v"] = [
+ "SkyReelsV2DiffusionForcingVideoToVideoPipeline"
+ ]
+ _import_structure["pipeline_skyreels_v2_i2v"] = ["SkyReelsV2ImageToVideoPipeline"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_skyreels_v2 import SkyReelsV2Pipeline
+ from .pipeline_skyreels_v2_diffusion_forcing import SkyReelsV2DiffusionForcingPipeline
+ from .pipeline_skyreels_v2_diffusion_forcing_i2v import SkyReelsV2DiffusionForcingImageToVideoPipeline
+ from .pipeline_skyreels_v2_diffusion_forcing_v2v import SkyReelsV2DiffusionForcingVideoToVideoPipeline
+ from .pipeline_skyreels_v2_i2v import SkyReelsV2ImageToVideoPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_output.py b/src/diffusers/pipelines/skyreels_v2/pipeline_output.py
new file mode 100644
index 000000000000..7a170d24c39a
--- /dev/null
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class SkyReelsV2PipelineOutput(BaseOutput):
+ r"""
+ Output class for SkyReelsV2 pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py
new file mode 100644
index 000000000000..d6cd7d7feceb
--- /dev/null
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py
@@ -0,0 +1,610 @@
+# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import regex as re
+import torch
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import SkyReelsV2LoraLoaderMixin
+from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel
+from ...schedulers import UniPCMultistepScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import SkyReelsV2PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """\
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import (
+ ... SkyReelsV2Pipeline,
+ ... UniPCMultistepScheduler,
+ ... AutoencoderKLWan,
+ ... )
+ >>> from diffusers.utils import export_to_video
+
+ >>> # Load the pipeline
+ >>> # Available models:
+ >>> # - Skywork/SkyReels-V2-T2V-14B-540P-Diffusers
+ >>> # - Skywork/SkyReels-V2-T2V-14B-720P-Diffusers
+ >>> vae = AutoencoderKLWan.from_pretrained(
+ ... "Skywork/SkyReels-V2-T2V-14B-720P-Diffusers",
+ ... subfolder="vae",
+ ... torch_dtype=torch.float32,
+ ... )
+ >>> pipe = SkyReelsV2Pipeline.from_pretrained(
+ ... "Skywork/SkyReels-V2-T2V-14B-720P-Diffusers",
+ ... vae=vae,
+ ... torch_dtype=torch.bfloat16,
+ ... )
+ >>> flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+
+ >>> output = pipe(
+ ... prompt=prompt,
+ ... num_inference_steps=50,
+ ... height=544,
+ ... width=960,
+ ... guidance_scale=6.0, # 6.0 for T2V, 5.0 for I2V
+ ... num_frames=97,
+ ... ).frames[0]
+ >>> export_to_video(output, "video.mp4", fps=24, quality=8)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+class SkyReelsV2Pipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin):
+ r"""
+ Pipeline for Text-to-Video (t2v) generation using SkyReels-V2.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ transformer ([`SkyReelsV2Transformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ transformer: SkyReelsV2Transformer3DModel,
+ vae: AutoencoderKLWan,
+ scheduler: UniPCMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 544,
+ width: int = 960,
+ num_frames: int = 97,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 6.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ height (`int`, defaults to `544`):
+ The height in pixels of the generated image.
+ width (`int`, defaults to `960`):
+ The width in pixels of the generated image.
+ num_frames (`int`, defaults to `97`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `6.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`SkyReelsV2PipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to `512`):
+ The maximum sequence length for the text encoder.
+
+ Examples:
+
+ Returns:
+ [`~SkyReelsV2PipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ latent_model_input = latents.to(transformer_dtype)
+ timestep = t.expand(latents.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return SkyReelsV2PipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py
new file mode 100644
index 000000000000..089f92632d38
--- /dev/null
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py
@@ -0,0 +1,978 @@
+# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import math
+import re
+from copy import deepcopy
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import ftfy
+import torch
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import SkyReelsV2LoraLoaderMixin
+from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel
+from ...schedulers import UniPCMultistepScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import SkyReelsV2PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """\
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import (
+ ... SkyReelsV2DiffusionForcingPipeline,
+ ... UniPCMultistepScheduler,
+ ... AutoencoderKLWan,
+ ... )
+ >>> from diffusers.utils import export_to_video
+
+ >>> # Load the pipeline
+ >>> # Available models:
+ >>> # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers
+ >>> # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers
+ >>> # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers
+ >>> vae = AutoencoderKLWan.from_pretrained(
+ ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers",
+ ... subfolder="vae",
+ ... torch_dtype=torch.float32,
+ ... )
+ >>> pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained(
+ ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers",
+ ... vae=vae,
+ ... torch_dtype=torch.bfloat16,
+ ... )
+ >>> flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+
+ >>> output = pipe(
+ ... prompt=prompt,
+ ... num_inference_steps=30,
+ ... height=544,
+ ... width=960,
+ ... guidance_scale=6.0, # 6.0 for T2V, 5.0 for I2V
+ ... num_frames=97,
+ ... ar_step=5, # Controls asynchronous inference (0 for synchronous mode)
+ ... causal_block_size=5, # Number of frames processed together in a causal block
+ ... overlap_history=None, # Number of frames to overlap for smooth transitions in long videos
+ ... addnoise_condition=20, # Improves consistency in long video generation
+ ... ).frames[0]
+ >>> export_to_video(output, "video.mp4", fps=24, quality=8)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin):
+ """
+ Pipeline for Text-to-Video (t2v) generation using SkyReels-V2 with diffusion forcing.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a specific device, etc.).
+
+ Args:
+ tokenizer ([`AutoTokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`UMT5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ transformer ([`SkyReelsV2Transformer3DModel`]):
+ Conditional Transformer to denoise the encoded image latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ transformer: SkyReelsV2Transformer3DModel,
+ vae: AutoencoderKLWan,
+ scheduler: UniPCMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ overlap_history=None,
+ num_frames=None,
+ base_num_frames=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ if num_frames > base_num_frames and overlap_history is None:
+ raise ValueError(
+ "`overlap_history` is required when `num_frames` exceeds `base_num_frames` to ensure smooth transitions in long video generation. "
+ "Please specify a value for `overlap_history`. Recommended values are 17 or 37."
+ )
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 97,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ base_latent_num_frames: Optional[int] = None,
+ video_latents: Optional[torch.Tensor] = None,
+ causal_block_size: Optional[int] = None,
+ overlap_history_latent_frames: Optional[int] = None,
+ long_video_iter: Optional[int] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+
+ prefix_video_latents = None
+ prefix_video_latents_frames = 0
+
+ if video_latents is not None: # long video generation at the iterations other than the first one
+ prefix_video_latents = video_latents[:, :, -overlap_history_latent_frames:]
+
+ if prefix_video_latents.shape[2] % causal_block_size != 0:
+ truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size
+ logger.warning(
+ f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment. "
+ f"This truncation ensures compatibility with the causal block size, which is required for proper processing. "
+ f"However, it may slightly affect the continuity of the generated video at the truncation boundary."
+ )
+ prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents]
+ prefix_video_latents_frames = prefix_video_latents.shape[2]
+
+ finished_frame_num = (
+ long_video_iter * (base_latent_num_frames - overlap_history_latent_frames)
+ + overlap_history_latent_frames
+ )
+ left_frame_num = num_latent_frames - finished_frame_num
+ num_latent_frames = min(left_frame_num + overlap_history_latent_frames, base_latent_num_frames)
+ elif base_latent_num_frames is not None: # long video generation at the first iteration
+ num_latent_frames = base_latent_num_frames
+ else: # short video generation
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ latent_height,
+ latent_width,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ return latents, num_latent_frames, prefix_video_latents, prefix_video_latents_frames
+
+ def generate_timestep_matrix(
+ self,
+ num_latent_frames: int,
+ step_template: torch.Tensor,
+ base_num_latent_frames: int,
+ ar_step: int = 5,
+ num_pre_ready: int = 0,
+ causal_block_size: int = 1,
+ shrink_interval_with_mask: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
+ """
+ This function implements the core diffusion forcing algorithm that creates a coordinated denoising schedule
+ across temporal frames. It supports both synchronous and asynchronous generation modes:
+
+ **Synchronous Mode** (ar_step=0, causal_block_size=1):
+ - All frames are denoised simultaneously at each timestep
+ - Each frame follows the same denoising trajectory: [1000, 800, 600, ..., 0]
+ - Simpler but may have less temporal consistency for long videos
+
+ **Asynchronous Mode** (ar_step>0, causal_block_size>1):
+ - Frames are grouped into causal blocks and processed block/chunk-wise
+ - Each block is denoised in a staggered pattern creating a "denoising wave"
+ - Earlier blocks are more denoised, later blocks lag behind by ar_step timesteps
+ - Creates stronger temporal dependencies and better consistency
+
+ Args:
+ num_latent_frames (int): Total number of latent frames to generate
+ step_template (torch.Tensor): Base timestep schedule (e.g., [1000, 800, 600, ..., 0])
+ base_num_latent_frames (int): Maximum frames the model can process in one forward pass
+ ar_step (int, optional): Autoregressive step size for temporal lag.
+ 0 = synchronous, >0 = asynchronous. Defaults to 5.
+ num_pre_ready (int, optional):
+ Number of frames already denoised (e.g., from prefix in a video2video task).
+ Defaults to 0.
+ causal_block_size (int, optional): Number of frames processed as a causal block.
+ Defaults to 1.
+ shrink_interval_with_mask (bool, optional): Whether to optimize processing intervals.
+ Defaults to False.
+
+ Returns:
+ tuple containing:
+ - step_matrix (torch.Tensor): Matrix of timesteps for each frame at each iteration Shape:
+ [num_iterations, num_latent_frames]
+ - step_index (torch.Tensor): Index matrix for timestep lookup Shape: [num_iterations,
+ num_latent_frames]
+ - step_update_mask (torch.Tensor): Boolean mask indicating which frames to update Shape:
+ [num_iterations, num_latent_frames]
+ - valid_interval (list[tuple]): List of (start, end) intervals for each iteration
+
+ Raises:
+ ValueError: If ar_step is too small for the given configuration
+ """
+ # Initialize lists to store the scheduling matrices and metadata
+ step_matrix, step_index = [], [] # Will store timestep values and indices for each iteration
+ update_mask, valid_interval = [], [] # Will store update masks and processing intervals
+
+ # Calculate total number of denoising iterations (add 1 for initial noise state)
+ num_iterations = len(step_template) + 1
+
+ # Convert frame counts to block counts for causal processing
+ # Each block contains causal_block_size frames that are processed together
+ # E.g.: 25 frames ÷ 5 = 5 blocks total
+ num_blocks = num_latent_frames // causal_block_size
+ base_num_blocks = base_num_latent_frames // causal_block_size
+
+ # Validate ar_step is sufficient for the given configuration
+ # In asynchronous mode, we need enough timesteps to create the staggered pattern
+ if base_num_blocks < num_blocks:
+ min_ar_step = len(step_template) / base_num_blocks
+ if ar_step < min_ar_step:
+ raise ValueError(f"`ar_step` should be at least {math.ceil(min_ar_step)} in your setting")
+
+ # Extend step_template with boundary values for easier indexing
+ # 999: dummy value for counter starting from 1
+ # 0: final timestep (completely denoised)
+ step_template = torch.cat(
+ [
+ torch.tensor([999], dtype=torch.int64, device=step_template.device),
+ step_template.long(),
+ torch.tensor([0], dtype=torch.int64, device=step_template.device),
+ ]
+ )
+
+ # Initialize the previous row state (tracks denoising progress for each block)
+ # 0 means not started, num_iterations means fully denoised
+ pre_row = torch.zeros(num_blocks, dtype=torch.long)
+
+ # Mark pre-ready frames (e.g., from prefix video for a video2video task) as already at final denoising state
+ if num_pre_ready > 0:
+ pre_row[: num_pre_ready // causal_block_size] = num_iterations
+
+ # Main loop: Generate denoising schedule until all frames are fully denoised
+ while not torch.all(pre_row >= (num_iterations - 1)):
+ # Create new row representing the next denoising step
+ new_row = torch.zeros(num_blocks, dtype=torch.long)
+
+ # Apply diffusion forcing logic for each block
+ for i in range(num_blocks):
+ if i == 0 or pre_row[i - 1] >= (
+ num_iterations - 1
+ ): # the first frame or the last frame is completely denoised
+ new_row[i] = pre_row[i] + 1
+ else:
+ # Asynchronous mode: lag behind previous block by ar_step timesteps
+ # This creates the "diffusion forcing" staggered pattern
+ new_row[i] = new_row[i - 1] - ar_step
+
+ # Clamp values to valid range [0, num_iterations]
+ new_row = new_row.clamp(0, num_iterations)
+
+ # Create update mask: True for blocks that need denoising update at this iteration
+ # Exclude blocks that haven't started (new_row != pre_row) or are finished (new_row != num_iterations)
+ # Final state example: [False, ..., False, True, True, True, True, True]
+ # where first 20 frames are done (False) and last 5 frames still need updates (True)
+ update_mask.append((new_row != pre_row) & (new_row != num_iterations))
+
+ # Store the iteration state
+ step_index.append(new_row) # Index into step_template
+ step_matrix.append(step_template[new_row]) # Actual timestep values
+ pre_row = new_row # Update for next iteration
+
+ # For videos longer than model capacity, we process in sliding windows
+ terminal_flag = base_num_blocks
+
+ # Optional optimization: shrink interval based on first update mask
+ if shrink_interval_with_mask:
+ idx_sequence = torch.arange(num_blocks, dtype=torch.int64)
+ update_mask = update_mask[0]
+ update_mask_idx = idx_sequence[update_mask]
+ last_update_idx = update_mask_idx[-1].item()
+ terminal_flag = last_update_idx + 1
+
+ # Each interval defines which frames to process in the current forward pass
+ for curr_mask in update_mask:
+ # Extend terminal flag if current mask has updates beyond current terminal
+ if terminal_flag < num_blocks and curr_mask[terminal_flag]:
+ terminal_flag += 1
+ # Create interval: [start, end) where start ensures we don't exceed model capacity
+ valid_interval.append((max(terminal_flag - base_num_blocks, 0), terminal_flag))
+
+ # Convert lists to tensors for efficient processing
+ step_update_mask = torch.stack(update_mask, dim=0)
+ step_index = torch.stack(step_index, dim=0)
+ step_matrix = torch.stack(step_matrix, dim=0)
+
+ # Each block's schedule is replicated to all frames within that block
+ if causal_block_size > 1:
+ # Expand each block to causal_block_size frames
+ step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
+ step_index = step_index.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
+ step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
+ # Scale intervals from block-level to frame-level
+ valid_interval = [(s * causal_block_size, e * causal_block_size) for s, e in valid_interval]
+
+ return step_matrix, step_index, step_update_mask, valid_interval
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 544,
+ width: int = 960,
+ num_frames: int = 97,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 6.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ overlap_history: Optional[int] = None,
+ addnoise_condition: float = 0,
+ base_num_frames: int = 97,
+ ar_step: int = 0,
+ causal_block_size: Optional[int] = None,
+ fps: int = 24,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, defaults to `544`):
+ The height of the generated video.
+ width (`int`, defaults to `960`):
+ The width of the generated video.
+ num_frames (`int`, defaults to `97`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `6.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality. (**6.0 for T2V**, **5.0 for I2V**)
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`SkyReelsV2PipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to `512`):
+ The maximum sequence length of the prompt.
+ overlap_history (`int`, *optional*, defaults to `None`):
+ Number of frames to overlap for smooth transitions in long videos. If `None`, the pipeline assumes
+ short video generation mode, and no overlap is applied. 17 and 37 are recommended to set.
+ addnoise_condition (`float`, *optional*, defaults to `0`):
+ This is used to help smooth the long video generation by adding some noise to the clean condition. Too
+ large noise can cause the inconsistency as well. 20 is a recommended value, and you may try larger
+ ones, but it is recommended to not exceed 50.
+ base_num_frames (`int`, *optional*, defaults to `97`):
+ 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**)
+ ar_step (`int`, *optional*, defaults to `0`):
+ Controls asynchronous inference (0 for synchronous mode) You can set `ar_step=5` to enable asynchronous
+ inference. When asynchronous inference, `causal_block_size=5` is recommended while it is not supposed
+ to be set for synchronous generation. Asynchronous inference will take more steps to diffuse the whole
+ sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous
+ inference may improve the instruction following and visual consistent performance.
+ causal_block_size (`int`, *optional*, defaults to `None`):
+ The number of frames in each block/chunk. Recommended when using asynchronous inference (when ar_step >
+ 0)
+ fps (`int`, *optional*, defaults to `24`):
+ Frame rate of the generated video
+
+ Examples:
+
+ Returns:
+ [`~SkyReelsV2PipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ overlap_history,
+ num_frames,
+ base_num_frames,
+ )
+
+ if addnoise_condition > 60:
+ logger.warning(
+ f"The value of 'addnoise_condition' is too large ({addnoise_condition}) and may cause inconsistencies in long video generation. A value of 20 is recommended."
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ if causal_block_size is None:
+ causal_block_size = self.transformer.config.num_frame_per_block
+ else:
+ self.transformer._set_ar_attention(causal_block_size)
+
+ fps_embeds = [fps] * prompt_embeds.shape[0]
+ fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
+
+ # Determine if we're doing long video generation
+ is_long_video = overlap_history is not None and base_num_frames is not None and num_frames > base_num_frames
+ # Initialize accumulated_latents to store all latents in one tensor
+ accumulated_latents = None
+ if is_long_video:
+ # Long video generation setup
+ overlap_history_latent_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ base_latent_num_frames = (
+ (base_num_frames - 1) // self.vae_scale_factor_temporal + 1
+ if base_num_frames is not None
+ else num_latent_frames
+ )
+ n_iter = (
+ 1
+ + (num_latent_frames - base_latent_num_frames - 1)
+ // (base_latent_num_frames - overlap_history_latent_frames)
+ + 1
+ )
+ else:
+ # Short video generation setup
+ n_iter = 1
+ base_latent_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+ # Loop through iterations (multiple iterations only for long videos)
+ for iter_idx in range(n_iter):
+ if is_long_video:
+ logger.debug(f"Processing iteration {iter_idx + 1}/{n_iter} for long video generation...")
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents, current_num_latent_frames, prefix_video_latents, prefix_video_latents_frames = (
+ self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents if iter_idx == 0 else None,
+ video_latents=accumulated_latents, # Pass latents directly instead of decoded video
+ base_latent_num_frames=base_latent_num_frames if is_long_video else None,
+ causal_block_size=causal_block_size,
+ overlap_history_latent_frames=overlap_history_latent_frames if is_long_video else None,
+ long_video_iter=iter_idx if is_long_video else None,
+ )
+ )
+
+ if prefix_video_latents_frames > 0:
+ latents[:, :, :prefix_video_latents_frames, :, :] = prefix_video_latents.to(transformer_dtype)
+
+ # 6. Prepare sample schedulers and timestep matrix
+ sample_schedulers = []
+ for _ in range(current_num_latent_frames):
+ sample_scheduler = deepcopy(self.scheduler)
+ sample_scheduler.set_timesteps(num_inference_steps, device=device)
+ sample_schedulers.append(sample_scheduler)
+
+ # Different matrix generation for short vs long video
+ step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
+ current_num_latent_frames,
+ timesteps,
+ current_num_latent_frames if is_long_video else base_latent_num_frames,
+ ar_step,
+ prefix_video_latents_frames,
+ causal_block_size,
+ )
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(step_matrix)
+
+ with self.progress_bar(total=len(step_matrix)) as progress_bar:
+ for i, t in enumerate(step_matrix):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ valid_interval_start, valid_interval_end = valid_interval[i]
+ latent_model_input = (
+ latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone()
+ )
+ timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone()
+
+ if addnoise_condition > 0 and valid_interval_start < prefix_video_latents_frames:
+ noise_factor = 0.001 * addnoise_condition
+ latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :] = (
+ latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :]
+ * (1.0 - noise_factor)
+ + torch.randn_like(
+ latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :]
+ )
+ * noise_factor
+ )
+ timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ enable_diffusion_forcing=True,
+ fps=fps_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ if self.do_classifier_free_guidance:
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ enable_diffusion_forcing=True,
+ fps=fps_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ update_mask_i = step_update_mask[i]
+ for idx in range(valid_interval_start, valid_interval_end):
+ if update_mask_i[idx].item():
+ latents[:, :, idx, :, :] = sample_schedulers[idx].step(
+ noise_pred[:, :, idx - valid_interval_start, :, :],
+ t[idx],
+ latents[:, :, idx, :, :],
+ return_dict=False,
+ )[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(step_matrix) - 1 or (
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+ ):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ # Handle latent accumulation for long videos or use the current latents for short videos
+ if is_long_video:
+ if accumulated_latents is None:
+ accumulated_latents = latents
+ else:
+ # Keep overlap frames for conditioning but don't include them in final output
+ accumulated_latents = torch.cat(
+ [accumulated_latents, latents[:, :, overlap_history_latent_frames:]], dim=2
+ )
+
+ if is_long_video:
+ latents = accumulated_latents
+
+ self._current_timestep = None
+
+ # Final decoding step - convert latents to pixels
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return SkyReelsV2PipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py
new file mode 100644
index 000000000000..2951a9447386
--- /dev/null
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py
@@ -0,0 +1,1059 @@
+# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import math
+import re
+from copy import deepcopy
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import ftfy
+import PIL
+import torch
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from diffusers.image_processor import PipelineImageInput
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import SkyReelsV2LoraLoaderMixin
+from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel
+from ...schedulers import UniPCMultistepScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import SkyReelsV2PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """\
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import (
+ ... SkyReelsV2DiffusionForcingImageToVideoPipeline,
+ ... UniPCMultistepScheduler,
+ ... AutoencoderKLWan,
+ ... )
+ >>> from diffusers.utils import export_to_video
+ >>> from PIL import Image
+
+ >>> # Load the pipeline
+ >>> # Available models:
+ >>> # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers
+ >>> # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers
+ >>> # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers
+ >>> vae = AutoencoderKLWan.from_pretrained(
+ ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers",
+ ... subfolder="vae",
+ ... torch_dtype=torch.float32,
+ ... )
+ >>> pipe = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained(
+ ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers",
+ ... vae=vae,
+ ... torch_dtype=torch.bfloat16,
+ ... )
+ >>> flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+ >>> image = Image.open("path/to/image.png")
+
+ >>> output = pipe(
+ ... image=image,
+ ... prompt=prompt,
+ ... num_inference_steps=50,
+ ... height=544,
+ ... width=960,
+ ... guidance_scale=5.0, # 6.0 for T2V, 5.0 for I2V
+ ... num_frames=97,
+ ... ar_step=0, # Controls asynchronous inference (0 for synchronous mode)
+ ... overlap_history=None, # Number of frames to overlap for smooth transitions in long videos
+ ... addnoise_condition=20, # Improves consistency in long video generation
+ ... ).frames[0]
+ >>> export_to_video(output, "video.mp4", fps=24, quality=8)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class SkyReelsV2DiffusionForcingImageToVideoPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin):
+ """
+ Pipeline for Image-to-Video (i2v) generation using SkyReels-V2 with diffusion forcing.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a specific device, etc.).
+
+ Args:
+ tokenizer ([`AutoTokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`UMT5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ transformer ([`SkyReelsV2Transformer3DModel`]):
+ Conditional Transformer to denoise the encoded image latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ transformer: SkyReelsV2Transformer3DModel,
+ vae: AutoencoderKLWan,
+ scheduler: UniPCMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ image,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ image_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ overlap_history=None,
+ num_frames=None,
+ base_num_frames=None,
+ ):
+ if image is not None and image_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ if image is None and image_embeds is None:
+ raise ValueError(
+ "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined."
+ )
+ if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
+ raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ if num_frames > base_num_frames and overlap_history is None:
+ raise ValueError(
+ "`overlap_history` is required when `num_frames` exceeds `base_num_frames` to ensure smooth transitions in long video generation. "
+ "Please specify a value for `overlap_history`. Recommended values are 17 or 37."
+ )
+
+ def prepare_latents(
+ self,
+ image: Optional[PipelineImageInput],
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 97,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ last_image: Optional[torch.Tensor] = None,
+ video_latents: Optional[torch.Tensor] = None,
+ base_latent_num_frames: Optional[int] = None,
+ causal_block_size: Optional[int] = None,
+ overlap_history_latent_frames: Optional[int] = None,
+ long_video_iter: Optional[int] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+
+ prefix_video_latents_frames = 0
+
+ if video_latents is not None: # long video generation at the iterations other than the first one
+ condition = video_latents[:, :, -overlap_history_latent_frames:]
+
+ if condition.shape[2] % causal_block_size != 0:
+ truncate_len_latents = condition.shape[2] % causal_block_size
+ logger.warning(
+ f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment. "
+ f"This truncation ensures compatibility with the causal block size, which is required for proper processing. "
+ f"However, it may slightly affect the continuity of the generated video at the truncation boundary."
+ )
+ condition = condition[:, :, :-truncate_len_latents]
+ prefix_video_latents_frames = condition.shape[2]
+
+ finished_frame_num = (
+ long_video_iter * (base_latent_num_frames - overlap_history_latent_frames)
+ + overlap_history_latent_frames
+ )
+ left_frame_num = num_latent_frames - finished_frame_num
+ num_latent_frames = min(left_frame_num + overlap_history_latent_frames, base_latent_num_frames)
+ elif base_latent_num_frames is not None: # long video generation at the first iteration
+ num_latent_frames = base_latent_num_frames
+ else: # short video generation
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+ shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ if image is not None:
+ image = image.unsqueeze(2)
+ if last_image is not None:
+ last_image = last_image.unsqueeze(2)
+ video_condition = torch.cat([image, last_image], dim=0)
+ else:
+ video_condition = image
+
+ video_condition = video_condition.to(device=device, dtype=self.vae.dtype)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+
+ if isinstance(generator, list):
+ latent_condition = [
+ retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
+ ]
+ latent_condition = torch.cat(latent_condition)
+ else:
+ latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
+ latent_condition = latent_condition.repeat_interleave(batch_size, dim=0)
+
+ latent_condition = latent_condition.to(dtype)
+ condition = (latent_condition - latents_mean) * latents_std
+ prefix_video_latents_frames = condition.shape[2]
+
+ return latents, num_latent_frames, condition, prefix_video_latents_frames
+
+ # Copied from diffusers.pipelines.skyreels_v2.pipeline_skyreels_v2_diffusion_forcing.SkyReelsV2DiffusionForcingPipeline.generate_timestep_matrix
+ def generate_timestep_matrix(
+ self,
+ num_latent_frames: int,
+ step_template: torch.Tensor,
+ base_num_latent_frames: int,
+ ar_step: int = 5,
+ num_pre_ready: int = 0,
+ causal_block_size: int = 1,
+ shrink_interval_with_mask: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
+ """
+ This function implements the core diffusion forcing algorithm that creates a coordinated denoising schedule
+ across temporal frames. It supports both synchronous and asynchronous generation modes:
+
+ **Synchronous Mode** (ar_step=0, causal_block_size=1):
+ - All frames are denoised simultaneously at each timestep
+ - Each frame follows the same denoising trajectory: [1000, 800, 600, ..., 0]
+ - Simpler but may have less temporal consistency for long videos
+
+ **Asynchronous Mode** (ar_step>0, causal_block_size>1):
+ - Frames are grouped into causal blocks and processed block/chunk-wise
+ - Each block is denoised in a staggered pattern creating a "denoising wave"
+ - Earlier blocks are more denoised, later blocks lag behind by ar_step timesteps
+ - Creates stronger temporal dependencies and better consistency
+
+ Args:
+ num_latent_frames (int): Total number of latent frames to generate
+ step_template (torch.Tensor): Base timestep schedule (e.g., [1000, 800, 600, ..., 0])
+ base_num_latent_frames (int): Maximum frames the model can process in one forward pass
+ ar_step (int, optional): Autoregressive step size for temporal lag.
+ 0 = synchronous, >0 = asynchronous. Defaults to 5.
+ num_pre_ready (int, optional):
+ Number of frames already denoised (e.g., from prefix in a video2video task).
+ Defaults to 0.
+ causal_block_size (int, optional): Number of frames processed as a causal block.
+ Defaults to 1.
+ shrink_interval_with_mask (bool, optional): Whether to optimize processing intervals.
+ Defaults to False.
+
+ Returns:
+ tuple containing:
+ - step_matrix (torch.Tensor): Matrix of timesteps for each frame at each iteration Shape:
+ [num_iterations, num_latent_frames]
+ - step_index (torch.Tensor): Index matrix for timestep lookup Shape: [num_iterations,
+ num_latent_frames]
+ - step_update_mask (torch.Tensor): Boolean mask indicating which frames to update Shape:
+ [num_iterations, num_latent_frames]
+ - valid_interval (list[tuple]): List of (start, end) intervals for each iteration
+
+ Raises:
+ ValueError: If ar_step is too small for the given configuration
+ """
+ # Initialize lists to store the scheduling matrices and metadata
+ step_matrix, step_index = [], [] # Will store timestep values and indices for each iteration
+ update_mask, valid_interval = [], [] # Will store update masks and processing intervals
+
+ # Calculate total number of denoising iterations (add 1 for initial noise state)
+ num_iterations = len(step_template) + 1
+
+ # Convert frame counts to block counts for causal processing
+ # Each block contains causal_block_size frames that are processed together
+ # E.g.: 25 frames ÷ 5 = 5 blocks total
+ num_blocks = num_latent_frames // causal_block_size
+ base_num_blocks = base_num_latent_frames // causal_block_size
+
+ # Validate ar_step is sufficient for the given configuration
+ # In asynchronous mode, we need enough timesteps to create the staggered pattern
+ if base_num_blocks < num_blocks:
+ min_ar_step = len(step_template) / base_num_blocks
+ if ar_step < min_ar_step:
+ raise ValueError(f"`ar_step` should be at least {math.ceil(min_ar_step)} in your setting")
+
+ # Extend step_template with boundary values for easier indexing
+ # 999: dummy value for counter starting from 1
+ # 0: final timestep (completely denoised)
+ step_template = torch.cat(
+ [
+ torch.tensor([999], dtype=torch.int64, device=step_template.device),
+ step_template.long(),
+ torch.tensor([0], dtype=torch.int64, device=step_template.device),
+ ]
+ )
+
+ # Initialize the previous row state (tracks denoising progress for each block)
+ # 0 means not started, num_iterations means fully denoised
+ pre_row = torch.zeros(num_blocks, dtype=torch.long)
+
+ # Mark pre-ready frames (e.g., from prefix video for a video2video task) as already at final denoising state
+ if num_pre_ready > 0:
+ pre_row[: num_pre_ready // causal_block_size] = num_iterations
+
+ # Main loop: Generate denoising schedule until all frames are fully denoised
+ while not torch.all(pre_row >= (num_iterations - 1)):
+ # Create new row representing the next denoising step
+ new_row = torch.zeros(num_blocks, dtype=torch.long)
+
+ # Apply diffusion forcing logic for each block
+ for i in range(num_blocks):
+ if i == 0 or pre_row[i - 1] >= (
+ num_iterations - 1
+ ): # the first frame or the last frame is completely denoised
+ new_row[i] = pre_row[i] + 1
+ else:
+ # Asynchronous mode: lag behind previous block by ar_step timesteps
+ # This creates the "diffusion forcing" staggered pattern
+ new_row[i] = new_row[i - 1] - ar_step
+
+ # Clamp values to valid range [0, num_iterations]
+ new_row = new_row.clamp(0, num_iterations)
+
+ # Create update mask: True for blocks that need denoising update at this iteration
+ # Exclude blocks that haven't started (new_row != pre_row) or are finished (new_row != num_iterations)
+ # Final state example: [False, ..., False, True, True, True, True, True]
+ # where first 20 frames are done (False) and last 5 frames still need updates (True)
+ update_mask.append((new_row != pre_row) & (new_row != num_iterations))
+
+ # Store the iteration state
+ step_index.append(new_row) # Index into step_template
+ step_matrix.append(step_template[new_row]) # Actual timestep values
+ pre_row = new_row # Update for next iteration
+
+ # For videos longer than model capacity, we process in sliding windows
+ terminal_flag = base_num_blocks
+
+ # Optional optimization: shrink interval based on first update mask
+ if shrink_interval_with_mask:
+ idx_sequence = torch.arange(num_blocks, dtype=torch.int64)
+ update_mask = update_mask[0]
+ update_mask_idx = idx_sequence[update_mask]
+ last_update_idx = update_mask_idx[-1].item()
+ terminal_flag = last_update_idx + 1
+
+ # Each interval defines which frames to process in the current forward pass
+ for curr_mask in update_mask:
+ # Extend terminal flag if current mask has updates beyond current terminal
+ if terminal_flag < num_blocks and curr_mask[terminal_flag]:
+ terminal_flag += 1
+ # Create interval: [start, end) where start ensures we don't exceed model capacity
+ valid_interval.append((max(terminal_flag - base_num_blocks, 0), terminal_flag))
+
+ # Convert lists to tensors for efficient processing
+ step_update_mask = torch.stack(update_mask, dim=0)
+ step_index = torch.stack(step_index, dim=0)
+ step_matrix = torch.stack(step_matrix, dim=0)
+
+ # Each block's schedule is replicated to all frames within that block
+ if causal_block_size > 1:
+ # Expand each block to causal_block_size frames
+ step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
+ step_index = step_index.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
+ step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
+ # Scale intervals from block-level to frame-level
+ valid_interval = [(s * causal_block_size, e * causal_block_size) for s, e in valid_interval]
+
+ return step_matrix, step_index, step_update_mask, valid_interval
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 544,
+ width: int = 960,
+ num_frames: int = 97,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ image_embeds: Optional[torch.Tensor] = None,
+ last_image: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ overlap_history: Optional[int] = None,
+ addnoise_condition: float = 0,
+ base_num_frames: int = 97,
+ ar_step: int = 0,
+ causal_block_size: Optional[int] = None,
+ fps: int = 24,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, defaults to `544`):
+ The height of the generated video.
+ width (`int`, defaults to `960`):
+ The width of the generated video.
+ num_frames (`int`, defaults to `97`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality. (**6.0 for T2V**, **5.0 for I2V**)
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `negative_prompt` input argument.
+ image_embeds (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
+ image embeddings are generated from the `image` input argument.
+ last_image (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
+ image embeddings are generated from the `image` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`SkyReelsV2PipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to `512`):
+ The maximum sequence length of the prompt.
+ overlap_history (`int`, *optional*, defaults to `None`):
+ Number of frames to overlap for smooth transitions in long videos. If `None`, the pipeline assumes
+ short video generation mode, and no overlap is applied. 17 and 37 are recommended to set.
+ addnoise_condition (`float`, *optional*, defaults to `0`):
+ This is used to help smooth the long video generation by adding some noise to the clean condition. Too
+ large noise can cause the inconsistency as well. 20 is a recommended value, and you may try larger
+ ones, but it is recommended to not exceed 50.
+ base_num_frames (`int`, *optional*, defaults to `97`):
+ 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**)
+ ar_step (`int`, *optional*, defaults to `0`):
+ Controls asynchronous inference (0 for synchronous mode) You can set `ar_step=5` to enable asynchronous
+ inference. When asynchronous inference, `causal_block_size=5` is recommended while it is not supposed
+ to be set for synchronous generation. Asynchronous inference will take more steps to diffuse the whole
+ sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous
+ inference may improve the instruction following and visual consistent performance.
+ causal_block_size (`int`, *optional*, defaults to `None`):
+ The number of frames in each block/chunk. Recommended when using asynchronous inference (when ar_step >
+ 0)
+ fps (`int`, *optional*, defaults to `24`):
+ Frame rate of the generated video
+
+ Examples:
+
+ Returns:
+ [`~SkyReelsV2PipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ image,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ image_embeds,
+ callback_on_step_end_tensor_inputs,
+ overlap_history,
+ num_frames,
+ base_num_frames,
+ )
+
+ if addnoise_condition > 60:
+ logger.warning(
+ f"The value of 'addnoise_condition' is too large ({addnoise_condition}) and may cause inconsistencies in long video generation. A value of 20 is recommended."
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ if causal_block_size is None:
+ causal_block_size = self.transformer.config.num_frame_per_block
+ else:
+ self.transformer._set_ar_attention(causal_block_size)
+
+ fps_embeds = [fps] * prompt_embeds.shape[0]
+ fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
+
+ # Determine if we're doing long video generation
+ is_long_video = overlap_history is not None and base_num_frames is not None and num_frames > base_num_frames
+ # Initialize accumulated_latents to store all latents in one tensor
+ accumulated_latents = None
+ if is_long_video:
+ # Long video generation setup
+ overlap_history_latent_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ base_latent_num_frames = (
+ (base_num_frames - 1) // self.vae_scale_factor_temporal + 1
+ if base_num_frames is not None
+ else num_latent_frames
+ )
+ n_iter = (
+ 1
+ + (num_latent_frames - base_latent_num_frames - 1)
+ // (base_latent_num_frames - overlap_history_latent_frames)
+ + 1
+ )
+ else:
+ # Short video generation setup
+ n_iter = 1
+ base_latent_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+ image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
+
+ if last_image is not None:
+ last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
+ device, dtype=torch.float32
+ )
+
+ # Loop through iterations (multiple iterations only for long videos)
+ for iter_idx in range(n_iter):
+ if is_long_video:
+ logger.debug(f"Processing iteration {iter_idx + 1}/{n_iter} for long video generation...")
+
+ num_channels_latents = self.vae.config.z_dim
+ latents, current_num_latent_frames, condition, prefix_video_latents_frames = self.prepare_latents(
+ image if iter_idx == 0 else None,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents if iter_idx == 0 else None,
+ last_image,
+ video_latents=accumulated_latents, # Pass latents directly instead of decoded video
+ base_latent_num_frames=base_latent_num_frames if is_long_video else None,
+ causal_block_size=causal_block_size,
+ overlap_history_latent_frames=overlap_history_latent_frames if is_long_video else None,
+ long_video_iter=iter_idx if is_long_video else None,
+ )
+
+ if iter_idx == 0:
+ latents[:, :, :prefix_video_latents_frames, :, :] = condition[: (condition.shape[0] + 1) // 2].to(
+ transformer_dtype
+ )
+ else:
+ latents[:, :, :prefix_video_latents_frames, :, :] = condition.to(transformer_dtype)
+
+ if iter_idx == 0 and last_image is not None:
+ end_video_latents = condition[condition.shape[0] // 2 :].to(transformer_dtype)
+
+ if last_image is not None and iter_idx + 1 == n_iter:
+ latents = torch.cat([latents, end_video_latents], dim=2)
+ base_latent_num_frames += prefix_video_latents_frames
+ current_num_latent_frames += prefix_video_latents_frames
+
+ # 4. Prepare sample schedulers and timestep matrix
+ sample_schedulers = []
+ for _ in range(current_num_latent_frames):
+ sample_scheduler = deepcopy(self.scheduler)
+ sample_scheduler.set_timesteps(num_inference_steps, device=device)
+ sample_schedulers.append(sample_scheduler)
+ step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
+ current_num_latent_frames,
+ timesteps,
+ base_latent_num_frames,
+ ar_step,
+ prefix_video_latents_frames,
+ causal_block_size,
+ )
+
+ if last_image is not None and iter_idx + 1 == n_iter:
+ step_matrix[:, -prefix_video_latents_frames:] = 0
+ step_update_mask[:, -prefix_video_latents_frames:] = False
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(step_matrix)
+
+ with self.progress_bar(total=len(step_matrix)) as progress_bar:
+ for i, t in enumerate(step_matrix):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ valid_interval_start, valid_interval_end = valid_interval[i]
+ latent_model_input = (
+ latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone()
+ )
+ timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone()
+
+ if addnoise_condition > 0 and valid_interval_start < prefix_video_latents_frames:
+ noise_factor = 0.001 * addnoise_condition
+ latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :] = (
+ latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :]
+ * (1.0 - noise_factor)
+ + torch.randn_like(
+ latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :]
+ )
+ * noise_factor
+ )
+ timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ enable_diffusion_forcing=True,
+ fps=fps_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ if self.do_classifier_free_guidance:
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ enable_diffusion_forcing=True,
+ fps=fps_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ update_mask_i = step_update_mask[i]
+ for idx in range(valid_interval_start, valid_interval_end):
+ if update_mask_i[idx].item():
+ latents[:, :, idx, :, :] = sample_schedulers[idx].step(
+ noise_pred[:, :, idx - valid_interval_start, :, :],
+ t[idx],
+ latents[:, :, idx, :, :],
+ return_dict=False,
+ )[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(step_matrix) - 1 or (
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+ ):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ # Handle latent accumulation for long videos or use the current latents for short videos
+ if is_long_video:
+ if accumulated_latents is None:
+ accumulated_latents = latents
+ else:
+ # Keep overlap frames for conditioning but don't include them in final output
+ accumulated_latents = torch.cat(
+ [accumulated_latents, latents[:, :, overlap_history_latent_frames:]],
+ dim=2,
+ )
+
+ if is_long_video:
+ latents = accumulated_latents
+
+ self._current_timestep = None
+
+ # Final decoding step - convert latents to pixels
+ if not output_type == "latent":
+ if last_image is not None:
+ latents = latents[:, :, :-prefix_video_latents_frames, :, :].to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return SkyReelsV2PipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py
new file mode 100644
index 000000000000..6fedfc795a40
--- /dev/null
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py
@@ -0,0 +1,1063 @@
+# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+import math
+import re
+from copy import deepcopy
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import ftfy
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import SkyReelsV2LoraLoaderMixin
+from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel
+from ...schedulers import UniPCMultistepScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import SkyReelsV2PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """\
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import (
+ ... SkyReelsV2DiffusionForcingVideoToVideoPipeline,
+ ... UniPCMultistepScheduler,
+ ... AutoencoderKLWan,
+ ... )
+ >>> from diffusers.utils import export_to_video
+
+ >>> # Load the pipeline
+ >>> # Available models:
+ >>> # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers
+ >>> # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers
+ >>> # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers
+ >>> vae = AutoencoderKLWan.from_pretrained(
+ ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers",
+ ... subfolder="vae",
+ ... torch_dtype=torch.float32,
+ ... )
+ >>> pipe = SkyReelsV2DiffusionForcingVideoToVideoPipeline.from_pretrained(
+ ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers",
+ ... vae=vae,
+ ... torch_dtype=torch.bfloat16,
+ ... )
+ >>> flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+
+ >>> output = pipe(
+ ... prompt=prompt,
+ ... num_inference_steps=50,
+ ... height=544,
+ ... width=960,
+ ... guidance_scale=6.0, # 6.0 for T2V, 5.0 for I2V
+ ... num_frames=97,
+ ... ar_step=0, # Controls asynchronous inference (0 for synchronous mode)
+ ... overlap_history=None, # Number of frames to overlap for smooth transitions in long videos
+ ... addnoise_condition=20, # Improves consistency in long video generation
+ ... ).frames[0]
+ >>> export_to_video(output, "video.mp4", fps=24, quality=8)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class SkyReelsV2DiffusionForcingVideoToVideoPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin):
+ """
+ Pipeline for Video-to-Video (v2v) generation using SkyReels-V2 with diffusion forcing.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a specific device, etc.).
+
+ Args:
+ tokenizer ([`AutoTokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`UMT5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ transformer ([`SkyReelsV2Transformer3DModel`]):
+ Conditional Transformer to denoise the encoded image latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ transformer: SkyReelsV2Transformer3DModel,
+ vae: AutoencoderKLWan,
+ scheduler: UniPCMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ video=None,
+ latents=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ overlap_history=None,
+ num_frames=None,
+ base_num_frames=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ if video is not None and latents is not None:
+ raise ValueError("Only one of `video` or `latents` should be provided")
+
+ if num_frames > base_num_frames and overlap_history is None:
+ raise ValueError(
+ "`overlap_history` is required when `num_frames` exceeds `base_num_frames` to ensure smooth transitions in long video generation. "
+ "Please specify a value for `overlap_history`. Recommended values are 17 or 37."
+ )
+
+ def prepare_latents(
+ self,
+ video: torch.Tensor,
+ batch_size: int = 1,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 97,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ video_latents: Optional[torch.Tensor] = None,
+ base_latent_num_frames: Optional[int] = None,
+ overlap_history: Optional[int] = None,
+ causal_block_size: Optional[int] = None,
+ overlap_history_latent_frames: Optional[int] = None,
+ long_video_iter: Optional[int] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ num_latent_frames = (
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.shape[2]
+ )
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+
+ if long_video_iter == 0:
+ prefix_video_latents = [
+ retrieve_latents(
+ self.vae.encode(
+ vid.unsqueeze(0)[:, :, -overlap_history:] if vid.dim() == 4 else vid[:, :, -overlap_history:]
+ ),
+ sample_mode="argmax",
+ )
+ for vid in video
+ ]
+ prefix_video_latents = torch.cat(prefix_video_latents, dim=0).to(dtype)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(device, self.vae.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ device, self.vae.dtype
+ )
+ prefix_video_latents = (prefix_video_latents - latents_mean) * latents_std
+ else:
+ prefix_video_latents = video_latents[:, :, -overlap_history_latent_frames:]
+
+ if prefix_video_latents.shape[2] % causal_block_size != 0:
+ truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size
+ logger.warning(
+ f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment. "
+ f"This truncation ensures compatibility with the causal block size, which is required for proper processing. "
+ f"However, it may slightly affect the continuity of the generated video at the truncation boundary."
+ )
+ prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents]
+ prefix_video_latents_frames = prefix_video_latents.shape[2]
+
+ finished_frame_num = (
+ long_video_iter * (base_latent_num_frames - overlap_history_latent_frames) + overlap_history_latent_frames
+ )
+ left_frame_num = num_latent_frames - finished_frame_num
+ num_latent_frames = min(left_frame_num + overlap_history_latent_frames, base_latent_num_frames)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ latent_height,
+ latent_width,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ return latents, num_latent_frames, prefix_video_latents, prefix_video_latents_frames
+
+ # Copied from diffusers.pipelines.skyreels_v2.pipeline_skyreels_v2_diffusion_forcing.SkyReelsV2DiffusionForcingPipeline.generate_timestep_matrix
+ def generate_timestep_matrix(
+ self,
+ num_latent_frames: int,
+ step_template: torch.Tensor,
+ base_num_latent_frames: int,
+ ar_step: int = 5,
+ num_pre_ready: int = 0,
+ causal_block_size: int = 1,
+ shrink_interval_with_mask: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
+ """
+ This function implements the core diffusion forcing algorithm that creates a coordinated denoising schedule
+ across temporal frames. It supports both synchronous and asynchronous generation modes:
+
+ **Synchronous Mode** (ar_step=0, causal_block_size=1):
+ - All frames are denoised simultaneously at each timestep
+ - Each frame follows the same denoising trajectory: [1000, 800, 600, ..., 0]
+ - Simpler but may have less temporal consistency for long videos
+
+ **Asynchronous Mode** (ar_step>0, causal_block_size>1):
+ - Frames are grouped into causal blocks and processed block/chunk-wise
+ - Each block is denoised in a staggered pattern creating a "denoising wave"
+ - Earlier blocks are more denoised, later blocks lag behind by ar_step timesteps
+ - Creates stronger temporal dependencies and better consistency
+
+ Args:
+ num_latent_frames (int): Total number of latent frames to generate
+ step_template (torch.Tensor): Base timestep schedule (e.g., [1000, 800, 600, ..., 0])
+ base_num_latent_frames (int): Maximum frames the model can process in one forward pass
+ ar_step (int, optional): Autoregressive step size for temporal lag.
+ 0 = synchronous, >0 = asynchronous. Defaults to 5.
+ num_pre_ready (int, optional):
+ Number of frames already denoised (e.g., from prefix in a video2video task).
+ Defaults to 0.
+ causal_block_size (int, optional): Number of frames processed as a causal block.
+ Defaults to 1.
+ shrink_interval_with_mask (bool, optional): Whether to optimize processing intervals.
+ Defaults to False.
+
+ Returns:
+ tuple containing:
+ - step_matrix (torch.Tensor): Matrix of timesteps for each frame at each iteration Shape:
+ [num_iterations, num_latent_frames]
+ - step_index (torch.Tensor): Index matrix for timestep lookup Shape: [num_iterations,
+ num_latent_frames]
+ - step_update_mask (torch.Tensor): Boolean mask indicating which frames to update Shape:
+ [num_iterations, num_latent_frames]
+ - valid_interval (list[tuple]): List of (start, end) intervals for each iteration
+
+ Raises:
+ ValueError: If ar_step is too small for the given configuration
+ """
+ # Initialize lists to store the scheduling matrices and metadata
+ step_matrix, step_index = [], [] # Will store timestep values and indices for each iteration
+ update_mask, valid_interval = [], [] # Will store update masks and processing intervals
+
+ # Calculate total number of denoising iterations (add 1 for initial noise state)
+ num_iterations = len(step_template) + 1
+
+ # Convert frame counts to block counts for causal processing
+ # Each block contains causal_block_size frames that are processed together
+ # E.g.: 25 frames ÷ 5 = 5 blocks total
+ num_blocks = num_latent_frames // causal_block_size
+ base_num_blocks = base_num_latent_frames // causal_block_size
+
+ # Validate ar_step is sufficient for the given configuration
+ # In asynchronous mode, we need enough timesteps to create the staggered pattern
+ if base_num_blocks < num_blocks:
+ min_ar_step = len(step_template) / base_num_blocks
+ if ar_step < min_ar_step:
+ raise ValueError(f"`ar_step` should be at least {math.ceil(min_ar_step)} in your setting")
+
+ # Extend step_template with boundary values for easier indexing
+ # 999: dummy value for counter starting from 1
+ # 0: final timestep (completely denoised)
+ step_template = torch.cat(
+ [
+ torch.tensor([999], dtype=torch.int64, device=step_template.device),
+ step_template.long(),
+ torch.tensor([0], dtype=torch.int64, device=step_template.device),
+ ]
+ )
+
+ # Initialize the previous row state (tracks denoising progress for each block)
+ # 0 means not started, num_iterations means fully denoised
+ pre_row = torch.zeros(num_blocks, dtype=torch.long)
+
+ # Mark pre-ready frames (e.g., from prefix video for a video2video task) as already at final denoising state
+ if num_pre_ready > 0:
+ pre_row[: num_pre_ready // causal_block_size] = num_iterations
+
+ # Main loop: Generate denoising schedule until all frames are fully denoised
+ while not torch.all(pre_row >= (num_iterations - 1)):
+ # Create new row representing the next denoising step
+ new_row = torch.zeros(num_blocks, dtype=torch.long)
+
+ # Apply diffusion forcing logic for each block
+ for i in range(num_blocks):
+ if i == 0 or pre_row[i - 1] >= (
+ num_iterations - 1
+ ): # the first frame or the last frame is completely denoised
+ new_row[i] = pre_row[i] + 1
+ else:
+ # Asynchronous mode: lag behind previous block by ar_step timesteps
+ # This creates the "diffusion forcing" staggered pattern
+ new_row[i] = new_row[i - 1] - ar_step
+
+ # Clamp values to valid range [0, num_iterations]
+ new_row = new_row.clamp(0, num_iterations)
+
+ # Create update mask: True for blocks that need denoising update at this iteration
+ # Exclude blocks that haven't started (new_row != pre_row) or are finished (new_row != num_iterations)
+ # Final state example: [False, ..., False, True, True, True, True, True]
+ # where first 20 frames are done (False) and last 5 frames still need updates (True)
+ update_mask.append((new_row != pre_row) & (new_row != num_iterations))
+
+ # Store the iteration state
+ step_index.append(new_row) # Index into step_template
+ step_matrix.append(step_template[new_row]) # Actual timestep values
+ pre_row = new_row # Update for next iteration
+
+ # For videos longer than model capacity, we process in sliding windows
+ terminal_flag = base_num_blocks
+
+ # Optional optimization: shrink interval based on first update mask
+ if shrink_interval_with_mask:
+ idx_sequence = torch.arange(num_blocks, dtype=torch.int64)
+ update_mask = update_mask[0]
+ update_mask_idx = idx_sequence[update_mask]
+ last_update_idx = update_mask_idx[-1].item()
+ terminal_flag = last_update_idx + 1
+
+ # Each interval defines which frames to process in the current forward pass
+ for curr_mask in update_mask:
+ # Extend terminal flag if current mask has updates beyond current terminal
+ if terminal_flag < num_blocks and curr_mask[terminal_flag]:
+ terminal_flag += 1
+ # Create interval: [start, end) where start ensures we don't exceed model capacity
+ valid_interval.append((max(terminal_flag - base_num_blocks, 0), terminal_flag))
+
+ # Convert lists to tensors for efficient processing
+ step_update_mask = torch.stack(update_mask, dim=0)
+ step_index = torch.stack(step_index, dim=0)
+ step_matrix = torch.stack(step_matrix, dim=0)
+
+ # Each block's schedule is replicated to all frames within that block
+ if causal_block_size > 1:
+ # Expand each block to causal_block_size frames
+ step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
+ step_index = step_index.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
+ step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
+ # Scale intervals from block-level to frame-level
+ valid_interval = [(s * causal_block_size, e * causal_block_size) for s, e in valid_interval]
+
+ return step_matrix, step_index, step_update_mask, valid_interval
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ video: List[Image.Image],
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 544,
+ width: int = 960,
+ num_frames: int = 120,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 6.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ overlap_history: Optional[int] = None,
+ addnoise_condition: float = 0,
+ base_num_frames: int = 97,
+ ar_step: int = 0,
+ causal_block_size: Optional[int] = None,
+ fps: int = 24,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ video (`List[Image.Image]`):
+ The video to guide the video generation.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, defaults to `544`):
+ The height of the generated video.
+ width (`int`, defaults to `960`):
+ The width of the generated video.
+ num_frames (`int`, defaults to `120`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `6.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality. (**6.0 for T2V**, **5.0 for I2V**)
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`SkyReelsV2PipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to `512`):
+ The maximum sequence length of the prompt.
+ overlap_history (`int`, *optional*, defaults to `None`):
+ Number of frames to overlap for smooth transitions in long videos. If `None`, the pipeline assumes
+ short video generation mode, and no overlap is applied. 17 and 37 are recommended to set.
+ addnoise_condition (`float`, *optional*, defaults to `0`):
+ This is used to help smooth the long video generation by adding some noise to the clean condition. Too
+ large noise can cause the inconsistency as well. 20 is a recommended value, and you may try larger
+ ones, but it is recommended to not exceed 50.
+ base_num_frames (`int`, *optional*, defaults to `97`):
+ 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**)
+ ar_step (`int`, *optional*, defaults to `0`):
+ Controls asynchronous inference (0 for synchronous mode) You can set `ar_step=5` to enable asynchronous
+ inference. When asynchronous inference, `causal_block_size=5` is recommended while it is not supposed
+ to be set for synchronous generation. Asynchronous inference will take more steps to diffuse the whole
+ sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous
+ inference may improve the instruction following and visual consistent performance.
+ causal_block_size (`int`, *optional*, defaults to `None`):
+ The number of frames in each block/chunk. Recommended when using asynchronous inference (when ar_step >
+ 0)
+ fps (`int`, *optional*, defaults to `24`):
+ Frame rate of the generated video
+
+ Examples:
+
+ Returns:
+ [`~SkyReelsV2PipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ video,
+ latents,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ overlap_history,
+ num_frames,
+ base_num_frames,
+ )
+
+ if addnoise_condition > 60:
+ logger.warning(
+ f"The value of 'addnoise_condition' is too large ({addnoise_condition}) and may cause inconsistencies in long video generation. A value of 20 is recommended."
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ if latents is None:
+ video_original = self.video_processor.preprocess_video(video, height=height, width=width).to(
+ device, dtype=torch.float32
+ )
+
+ if causal_block_size is None:
+ causal_block_size = self.transformer.config.num_frame_per_block
+ else:
+ self.transformer._set_ar_attention(causal_block_size)
+
+ fps_embeds = [fps] * prompt_embeds.shape[0]
+ fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
+
+ # Long video generation
+ accumulated_latents = None
+ overlap_history_latent_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ base_latent_num_frames = (
+ (base_num_frames - 1) // self.vae_scale_factor_temporal + 1
+ if base_num_frames is not None
+ else num_latent_frames
+ )
+ n_iter = (
+ 1
+ + (num_latent_frames - base_latent_num_frames - 1)
+ // (base_latent_num_frames - overlap_history_latent_frames)
+ + 1
+ )
+ for long_video_iter in range(n_iter):
+ logger.debug(f"Processing iteration {long_video_iter + 1}/{n_iter} for long video generation...")
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents, current_num_latent_frames, prefix_video_latents, prefix_video_latents_frames = (
+ self.prepare_latents(
+ video_original,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents if long_video_iter == 0 else None,
+ video_latents=accumulated_latents, # Pass latents directly instead of decoded video
+ overlap_history=overlap_history,
+ base_latent_num_frames=base_latent_num_frames,
+ causal_block_size=causal_block_size,
+ overlap_history_latent_frames=overlap_history_latent_frames,
+ long_video_iter=long_video_iter,
+ )
+ )
+
+ if prefix_video_latents_frames > 0:
+ latents[:, :, :prefix_video_latents_frames, :, :] = prefix_video_latents.to(transformer_dtype)
+
+ # 4. Prepare sample schedulers and timestep matrix
+ sample_schedulers = []
+ for _ in range(current_num_latent_frames):
+ sample_scheduler = deepcopy(self.scheduler)
+ sample_scheduler.set_timesteps(num_inference_steps, device=device)
+ sample_schedulers.append(sample_scheduler)
+ step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
+ current_num_latent_frames,
+ timesteps,
+ current_num_latent_frames,
+ ar_step,
+ prefix_video_latents_frames,
+ causal_block_size,
+ )
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(step_matrix)
+
+ with self.progress_bar(total=len(step_matrix)) as progress_bar:
+ for i, t in enumerate(step_matrix):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ valid_interval_start, valid_interval_end = valid_interval[i]
+ latent_model_input = (
+ latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone()
+ )
+ timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone()
+
+ if addnoise_condition > 0 and valid_interval_start < prefix_video_latents_frames:
+ noise_factor = 0.001 * addnoise_condition
+ latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :] = (
+ latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :]
+ * (1.0 - noise_factor)
+ + torch.randn_like(
+ latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :]
+ )
+ * noise_factor
+ )
+ timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ enable_diffusion_forcing=True,
+ fps=fps_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ if self.do_classifier_free_guidance:
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ enable_diffusion_forcing=True,
+ fps=fps_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ update_mask_i = step_update_mask[i]
+ for idx in range(valid_interval_start, valid_interval_end):
+ if update_mask_i[idx].item():
+ latents[:, :, idx, :, :] = sample_schedulers[idx].step(
+ noise_pred[:, :, idx - valid_interval_start, :, :],
+ t[idx],
+ latents[:, :, idx, :, :],
+ return_dict=False,
+ )[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(step_matrix) - 1 or (
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+ ):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if accumulated_latents is None:
+ accumulated_latents = latents
+ else:
+ # Keep overlap frames for conditioning but don't include them in final output
+ accumulated_latents = torch.cat(
+ [accumulated_latents, latents[:, :, overlap_history_latent_frames:]], dim=2
+ )
+
+ latents = accumulated_latents
+
+ self._current_timestep = None
+
+ # Final decoding step - convert latents to pixels
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video_generated = self.vae.decode(latents, return_dict=False)[0]
+ video = torch.cat([video_original, video_generated], dim=2)
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return SkyReelsV2PipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py
new file mode 100644
index 000000000000..d61b687eadc3
--- /dev/null
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py
@@ -0,0 +1,745 @@
+# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import PIL
+import regex as re
+import torch
+from transformers import AutoTokenizer, CLIPProcessor, CLIPVisionModelWithProjection, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import SkyReelsV2LoraLoaderMixin
+from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel
+from ...schedulers import UniPCMultistepScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import SkyReelsV2PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """\
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import (
+ ... SkyReelsV2ImageToVideoPipeline,
+ ... UniPCMultistepScheduler,
+ ... AutoencoderKLWan,
+ ... )
+ >>> from diffusers.utils import export_to_video
+ >>> from PIL import Image
+
+ >>> # Load the pipeline
+ >>> # Available models:
+ >>> # - Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers
+ >>> # - Skywork/SkyReels-V2-I2V-14B-540P-Diffusers
+ >>> # - Skywork/SkyReels-V2-I2V-14B-720P-Diffusers
+ >>> vae = AutoencoderKLWan.from_pretrained(
+ ... "Skywork/SkyReels-V2-I2V-14B-720P-Diffusers",
+ ... subfolder="vae",
+ ... torch_dtype=torch.float32,
+ ... )
+ >>> pipe = SkyReelsV2ImageToVideoPipeline.from_pretrained(
+ ... "Skywork/SkyReels-V2-I2V-14B-720P-Diffusers",
+ ... vae=vae,
+ ... torch_dtype=torch.bfloat16,
+ ... )
+ >>> flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+ >>> image = Image.open("path/to/image.png")
+
+ >>> output = pipe(
+ ... image=image,
+ ... prompt=prompt,
+ ... num_inference_steps=50,
+ ... height=544,
+ ... width=960,
+ ... guidance_scale=5.0, # 6.0 for T2V, 5.0 for I2V
+ ... num_frames=97,
+ ... ).frames[0]
+ >>> export_to_video(output, "video.mp4", fps=24, quality=8)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class SkyReelsV2ImageToVideoPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin):
+ r"""
+ Pipeline for Image-to-Video (i2v) generation using SkyReels-V2.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ image_encoder ([`CLIPVisionModelWithProjection`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
+ specifically the
+ [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large)
+ variant.
+ transformer ([`SkyReelsV2Transformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ image_encoder: CLIPVisionModelWithProjection,
+ image_processor: CLIPProcessor,
+ transformer: SkyReelsV2Transformer3DModel,
+ vae: AutoencoderKLWan,
+ scheduler: UniPCMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ image_encoder=image_encoder,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_processor=image_processor,
+ )
+
+ self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+ self.image_processor = image_processor
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_image
+ def encode_image(
+ self,
+ image: PipelineImageInput,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+ image = self.image_processor(images=image, return_tensors="pt").to(device)
+ image_embeds = self.image_encoder(**image, output_hidden_states=True)
+ return image_embeds.hidden_states[-2]
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ image,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ image_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if image is not None and image_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ if image is None and image_embeds is None:
+ raise ValueError(
+ "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined."
+ )
+ if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
+ raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ def prepare_latents(
+ self,
+ image: PipelineImageInput,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ last_image: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+
+ shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ image = image.unsqueeze(2)
+ if last_image is None:
+ video_condition = torch.cat(
+ [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
+ )
+ else:
+ last_image = last_image.unsqueeze(2)
+ video_condition = torch.cat(
+ [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
+ dim=2,
+ )
+ video_condition = video_condition.to(device=device, dtype=self.vae.dtype)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+
+ if isinstance(generator, list):
+ latent_condition = [
+ retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
+ ]
+ latent_condition = torch.cat(latent_condition)
+ else:
+ latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
+ latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
+
+ latent_condition = latent_condition.to(dtype)
+ latent_condition = (latent_condition - latents_mean) * latents_std
+
+ mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
+
+ if last_image is None:
+ mask_lat_size[:, :, list(range(1, num_frames))] = 0
+ else:
+ mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
+ first_frame_mask = mask_lat_size[:, :, 0:1]
+ first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
+ mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
+ mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width)
+ mask_lat_size = mask_lat_size.transpose(1, 2)
+ mask_lat_size = mask_lat_size.to(latent_condition.device)
+
+ return latents, torch.concat([mask_lat_size, latent_condition], dim=1)
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 544,
+ width: int = 960,
+ num_frames: int = 97,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ image_embeds: Optional[torch.Tensor] = None,
+ last_image: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, defaults to `544`):
+ The height of the generated video.
+ width (`int`, defaults to `960`):
+ The width of the generated video.
+ num_frames (`int`, defaults to `97`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `negative_prompt` input argument.
+ image_embeds (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
+ image embeddings are generated from the `image` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to `512`):
+ The maximum sequence length of the prompt.
+
+ Examples:
+
+ Returns:
+ [`~SkyReelsV2PipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ image,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ image_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ # Encode image embedding
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ if image_embeds is None:
+ if last_image is None:
+ image_embeds = self.encode_image(image, device)
+ else:
+ image_embeds = self.encode_image([image, last_image], device)
+ image_embeds = image_embeds.repeat(batch_size, 1, 1)
+ image_embeds = image_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.vae.config.z_dim
+ image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
+ if last_image is not None:
+ last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
+ device, dtype=torch.float32
+ )
+ latents, condition = self.prepare_latents(
+ image,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ last_image,
+ )
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+ timestep = t.expand(latents.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states_image=image_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states_image=image_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return SkyReelsV2PipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/stable_audio/modeling_stable_audio.py b/src/diffusers/pipelines/stable_audio/modeling_stable_audio.py
index b8f8a705de21..07b382dfc49f 100644
--- a/src/diffusers/pipelines/stable_audio/modeling_stable_audio.py
+++ b/src/diffusers/pipelines/stable_audio/modeling_stable_audio.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Stability AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,6 @@
import torch
import torch.nn as nn
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
diff --git a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
index 5d773b614a5c..b7faf097ab0d 100644
--- a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
+++ b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Stability AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,11 +25,7 @@
from ...models import AutoencoderOobleck, StableAudioDiTModel
from ...models.embeddings import get_1d_rotary_pos_embed
from ...schedulers import EDMDPMSolverMultistepScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
from .modeling_stable_audio import StableAudioProjectionModel
@@ -134,6 +130,12 @@ def enable_vae_slicing(self):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
# Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing
@@ -142,6 +144,12 @@ def disable_vae_slicing(self):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def encode_prompt(
@@ -306,7 +314,7 @@ def encode_duration(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -526,8 +534,8 @@ def __call__(
num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
The number of waveforms to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -584,7 +592,7 @@ def __call__(
if audio_end_in_s - audio_start_in_s > max_audio_length_in_s:
raise ValueError(
- f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
+ f"The total audio length requested ({audio_end_in_s - audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'."
)
waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate)
@@ -616,7 +624,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
index 38f1c4314e4f..a6a60ad94be6 100644
--- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
+++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,7 +21,7 @@
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import is_torch_version, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
@@ -55,7 +55,7 @@
"""
-class StableCascadeDecoderPipeline(DiffusionPipeline):
+class StableCascadeDecoderPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
"""
Pipeline for generating images from the Stable Cascade model.
@@ -79,6 +79,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
width=int(24*10.67)=256 in order to match the training conditions.
"""
+ _last_supported_version = "0.35.2"
+
unet_name = "decoder"
text_encoder_name = "text_encoder"
model_cpu_offload_seq = "text_encoder->decoder->vqgan"
@@ -332,11 +334,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 0.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
- `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely
- linked to the text `prompt`, usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `decoder_guidance_scale` is defined as `w` of
+ equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by
+ setting `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are
+ closely linked to the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `decoder_guidance_scale` is less than `1`).
@@ -362,7 +364,7 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
@@ -524,9 +526,9 @@ def __call__(
latents = self.vqgan.config.scale_factor * latents
images = self.vqgan.decode(latents).sample.clamp(0, 1)
if output_type == "np":
- images = images.permute(0, 2, 3, 1).cpu().float().numpy() # float() as bfloat16-> numpy doesnt work
+ images = images.permute(0, 2, 3, 1).cpu().float().numpy() # float() as bfloat16-> numpy doesn't work
elif output_type == "pil":
- images = images.permute(0, 2, 3, 1).cpu().float().numpy() # float() as bfloat16-> numpy doesnt work
+ images = images.permute(0, 2, 3, 1).cpu().float().numpy() # float() as bfloat16-> numpy doesn't work
images = self.numpy_to_pil(images)
else:
images = latents
diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py
index 28a74ab83733..838b93faaa0c 100644
--- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py
+++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,7 +20,7 @@
from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import is_torch_version, replace_example_docstring
-from ..pipeline_utils import DiffusionPipeline
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
from .pipeline_stable_cascade import StableCascadeDecoderPipeline
from .pipeline_stable_cascade_prior import StableCascadePriorPipeline
@@ -42,7 +42,7 @@
"""
-class StableCascadeCombinedPipeline(DiffusionPipeline):
+class StableCascadeCombinedPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
"""
Combined Pipeline for text-to-image generation using Stable Cascade.
@@ -74,6 +74,8 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
"""
+ _last_supported_version = "0.35.2"
+
_load_connected_pipes = True
_optional_components = ["prior_feature_extractor", "prior_image_encoder"]
@@ -125,7 +127,7 @@ def __init__(
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
- def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
@@ -135,7 +137,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
- def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
@@ -212,11 +214,11 @@ def __call__(
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
prior_guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `prior_guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
- `prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked
- to the text `prompt`, usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `prior_guidance_scale` is defined as `w` of
+ equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by
+ setting `prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are
+ closely linked to the text `prompt`, usually at the expense of lower image quality.
prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 60):
The number of prior denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. For more specific timestep spacing, you can pass customized
@@ -226,18 +228,18 @@ def __call__(
the expense of slower inference. For more specific timestep spacing, you can pass customized
`timesteps`
decoder_guidance_scale (`float`, *optional*, defaults to 0.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
index 241c454e103e..29ad8b5429d7 100644
--- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
+++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,7 +25,7 @@
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import BaseOutput, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline
if is_torch_xla_available():
@@ -77,7 +77,7 @@ class StableCascadePriorPipelineOutput(BaseOutput):
negative_prompt_embeds_pooled: Union[torch.Tensor, np.ndarray]
-class StableCascadePriorPipeline(DiffusionPipeline):
+class StableCascadePriorPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
"""
Pipeline for generating image prior for Stable Cascade.
@@ -103,6 +103,8 @@ class StableCascadePriorPipeline(DiffusionPipeline):
Default resolution for multiple images generated.
"""
+ _last_supported_version = "0.35.2"
+
unet_name = "prior"
text_encoder_name = "text_encoder"
model_cpu_offload_seq = "image_encoder->text_encoder->prior"
@@ -409,11 +411,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 8.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
- `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely
- linked to the text `prompt`, usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `decoder_guidance_scale` is defined as `w` of
+ equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by
+ setting `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are
+ closely linked to the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `decoder_guidance_scale` is less than `1`).
@@ -442,7 +444,7 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
@@ -626,11 +628,11 @@ def __call__(
self.maybe_free_model_hooks()
if output_type == "np":
- latents = latents.cpu().float().numpy() # float() as bfloat16-> numpy doesnt work
- prompt_embeds = prompt_embeds.cpu().float().numpy() # float() as bfloat16-> numpy doesnt work
+ latents = latents.cpu().float().numpy() # float() as bfloat16-> numpy doesn't work
+ prompt_embeds = prompt_embeds.cpu().float().numpy() # float() as bfloat16-> numpy doesn't work
negative_prompt_embeds = (
negative_prompt_embeds.cpu().float().numpy() if negative_prompt_embeds is not None else None
- ) # float() as bfloat16-> numpy doesnt work
+ ) # float() as bfloat16-> numpy doesn't work
if not return_dict:
return (
diff --git a/src/diffusers/pipelines/stable_diffusion/README.md b/src/diffusers/pipelines/stable_diffusion/README.md
index 5b229fddadd5..164baeb0a4d3 100644
--- a/src/diffusers/pipelines/stable_diffusion/README.md
+++ b/src/diffusers/pipelines/stable_diffusion/README.md
@@ -10,7 +10,7 @@ The summary of the model is the following:
## Tips:
-- Stable Diffusion has the same architecture as [Latent Diffusion](https://arxiv.org/abs/2112.10752) but uses a frozen CLIP Text Encoder instead of training the text encoder jointly with the diffusion model.
+- Stable Diffusion has the same architecture as [Latent Diffusion](https://huggingface.co/papers/2112.10752) but uses a frozen CLIP Text Encoder instead of training the text encoder jointly with the diffusion model.
- An in-detail explanation of the Stable Diffusion model can be found under [Stable Diffusion with 🧨 Diffusers](https://huggingface.co/blog/stable_diffusion).
- If you don't want to rely on the Hugging Face Hub and having to pass a authentication token, you can
download the weights with `git lfs install; git clone https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5` and instead pass the local path to the cloned folder to `from_pretrained` as shown below.
@@ -28,7 +28,7 @@ download the weights with `git lfs install; git clone https://huggingface.co/sta
### Using Stable Diffusion without being logged into the Hub.
-If you want to download the model weights using a single Python line, you need to be logged in via `huggingface-cli login`.
+If you want to download the model weights using a single Python line, you need to be logged in via `hf auth login`.
```python
from diffusers import DiffusionPipeline
@@ -54,7 +54,7 @@ pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
### Text-to-Image with default PLMS scheduler
```python
-# make sure you're logged in with `huggingface-cli login`
+# make sure you're logged in with `hf auth login`
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
@@ -69,7 +69,7 @@ image.save("astronaut_rides_horse.png")
### Text-to-Image with DDIM scheduler
```python
-# make sure you're logged in with `huggingface-cli login`
+# make sure you're logged in with `hf auth login`
from diffusers import StableDiffusionPipeline, DDIMScheduler
scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
@@ -88,7 +88,7 @@ image.save("astronaut_rides_horse.png")
### Text-to-Image with K-LMS scheduler
```python
-# make sure you're logged in with `huggingface-cli login`
+# make sure you're logged in with `hf auth login`
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
lms = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
@@ -118,7 +118,7 @@ from diffusers import CycleDiffusionPipeline, DDIMScheduler
# load the scheduler. CycleDiffusion only supports stochastic schedulers.
# load the pipeline
-# make sure you're logged in with `huggingface-cli login`
+# make sure you're logged in with `hf auth login`
model_id_or_path = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")
diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py
index 8ce08aec66ca..b05a3ce2a7c9 100644
--- a/src/diffusers/pipelines/stable_diffusion/__init__.py
+++ b/src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -30,18 +30,11 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["clip_image_project_model"] = ["CLIPImageProjection"]
- _import_structure["pipeline_cycle_diffusion"] = ["CycleDiffusionPipeline"]
_import_structure["pipeline_stable_diffusion"] = ["StableDiffusionPipeline"]
- _import_structure["pipeline_stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"]
- _import_structure["pipeline_stable_diffusion_gligen"] = ["StableDiffusionGLIGENPipeline"]
- _import_structure["pipeline_stable_diffusion_gligen_text_image"] = ["StableDiffusionGLIGENTextImagePipeline"]
_import_structure["pipeline_stable_diffusion_img2img"] = ["StableDiffusionImg2ImgPipeline"]
_import_structure["pipeline_stable_diffusion_inpaint"] = ["StableDiffusionInpaintPipeline"]
- _import_structure["pipeline_stable_diffusion_inpaint_legacy"] = ["StableDiffusionInpaintPipelineLegacy"]
_import_structure["pipeline_stable_diffusion_instruct_pix2pix"] = ["StableDiffusionInstructPix2PixPipeline"]
_import_structure["pipeline_stable_diffusion_latent_upscale"] = ["StableDiffusionLatentUpscalePipeline"]
- _import_structure["pipeline_stable_diffusion_model_editing"] = ["StableDiffusionModelEditingPipeline"]
- _import_structure["pipeline_stable_diffusion_paradigms"] = ["StableDiffusionParadigmsPipeline"]
_import_structure["pipeline_stable_diffusion_upscale"] = ["StableDiffusionUpscalePipeline"]
_import_structure["pipeline_stable_unclip"] = ["StableUnCLIPPipeline"]
_import_structure["pipeline_stable_unclip_img2img"] = ["StableUnCLIPImg2ImgPipeline"]
diff --git a/src/diffusers/pipelines/stable_diffusion/clip_image_project_model.py b/src/diffusers/pipelines/stable_diffusion/clip_image_project_model.py
index 71f9d9714e6b..30dd90242d07 100644
--- a/src/diffusers/pipelines/stable_diffusion/clip_image_project_model.py
+++ b/src/diffusers/pipelines/stable_diffusion/clip_image_project_model.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The GLIGEN Authors and HuggingFace Team. All rights reserved.
+# Copyright 2025 The GLIGEN Authors and HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
index 4cc4eabd4a40..6c0221d2092a 100644
--- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -52,6 +52,8 @@
UnCLIPScheduler,
)
from ...utils import is_accelerate_available, logging
+from ...utils.constants import DIFFUSERS_REQUEST_TIMEOUT
+from ...utils.torch_utils import get_device
from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from ..paint_by_example import PaintByExampleImageEncoder
from ..pipeline_utils import DiffusionPipeline
@@ -349,8 +351,14 @@ def create_vae_diffusers_config(original_config, image_size: int):
_ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
- down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
- up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
+ down_block_types = [
+ "DownEncoderBlock2D" if image_size // 2**i not in vae_params["attn_resolutions"] else "AttnDownEncoderBlock2D"
+ for i, _ in enumerate(block_out_channels)
+ ]
+ up_block_types = [
+ "UpDecoderBlock2D" if image_size // 2**i not in vae_params["attn_resolutions"] else "AttnUpDecoderBlock2D"
+ for i, _ in enumerate(block_out_channels)
+ ][::-1]
config = {
"sample_size": image_size,
@@ -1265,7 +1273,7 @@ def download_from_original_stable_diffusion_ckpt(
checkpoint = safe_load(checkpoint_path_or_dict, device="cpu")
else:
if device is None:
- device = "cuda" if torch.cuda.is_available() else "cpu"
+ device = get_device()
checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
else:
checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
@@ -1324,7 +1332,7 @@ def download_from_original_stable_diffusion_ckpt(
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
if config_url is not None:
- original_config_file = BytesIO(requests.get(config_url).content)
+ original_config_file = BytesIO(requests.get(config_url, timeout=DIFFUSERS_REQUEST_TIMEOUT).content)
else:
with open(original_config_file, "r") as f:
original_config_file = f.read()
@@ -1835,7 +1843,7 @@ def download_controlnet_from_original_ckpt(
checkpoint[key] = f.get_tensor(key)
else:
if device is None:
- device = "cuda" if torch.cuda.is_available() else "cpu"
+ device = get_device()
checkpoint = torch.load(checkpoint_path, map_location=device)
else:
checkpoint = torch.load(checkpoint_path, map_location=device)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
index eaeb5f809c47..6befe77aa4b1 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -349,12 +349,8 @@ def __call__(
jit (`bool`, defaults to `False`):
Whether to run `pmap` versions of the generation and safety scoring functions.
-
-
- This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a
- future release.
-
-
+ > [!WARNING] > This argument exists because `__call__` is not yet end-to-end pmap-able. It will be
+ removed in a > future release.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
index c2d918156084..81656beba7e1 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -389,12 +389,8 @@ def __call__(
jit (`bool`, defaults to `False`):
Whether to run `pmap` versions of the generation and safety scoring functions.
-
-
- This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a
- future release.
-
-
+ > [!WARNING] > This argument exists because `__call__` is not yet end-to-end pmap-able. It will be
+ removed in a > future release.
Examples:
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
index abcba926160a..5938fe232a71 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -103,11 +103,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
r"""
Flax-based pipeline for text-guided image inpainting using Stable Diffusion.
-
-
- 🧪 This is an experimental feature!
-
-
+ > [!WARNING] > 🧪 This is an experimental feature!
This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
@@ -335,7 +331,7 @@ def _generate(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
@@ -435,12 +431,8 @@ def __call__(
jit (`bool`, defaults to `False`):
Whether to run `pmap` versions of the generation and safety scoring functions.
-
-
- This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a
- future release.
-
-
+ > [!WARNING] > This argument exists because `__call__` is not yet end-to-end pmap-able. It will be
+ removed in a > future release.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
index 9917276e0a1f..6ebe0986a1ab 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -294,11 +294,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
@@ -306,14 +306,14 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`np.random.RandomState`, *optional*):
One or a list of [numpy generator(s)](TODO) to make generation deterministic.
latents (`np.ndarray`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`np.ndarray`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -359,7 +359,7 @@ def __call__(
generator = np.random
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
@@ -383,11 +383,12 @@ def __call__(
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
- latents = latents * np.float64(self.scheduler.init_noise_sigma)
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
index 92c82d61b8f2..d63bf3bf4564 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -348,19 +348,19 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter will be modulated by `strength`.
guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`np.random.RandomState`, *optional*):
A np.random.RandomState to make generation deterministic.
prompt_embeds (`np.ndarray`, *optional*):
@@ -414,7 +414,7 @@ def __call__(
image = preprocess(image).cpu().numpy()
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
@@ -470,7 +470,7 @@ def __call__(
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
index ddd2e27dedaf..158bcabbebfd 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -360,25 +360,25 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`np.random.RandomState`, *optional*):
A np.random.RandomState to make generation deterministic.
latents (`np.ndarray`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`np.ndarray`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -427,7 +427,7 @@ def __call__(
self.scheduler.set_timesteps(num_inference_steps)
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
@@ -475,7 +475,7 @@ def __call__(
"Incorrect configuration settings! The config of `pipeline.unet` expects"
f" {unet_input_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
@@ -483,11 +483,11 @@ def __call__(
self.scheduler.set_timesteps(num_inference_steps)
# scale the initial noise by the standard deviation required by the scheduler
- latents = latents * np.float64(self.scheduler.init_noise_sigma)
+ latents = latents * self.scheduler.init_noise_sigma
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
index ef84cdd38b6d..a765163175a2 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -378,11 +378,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter will be modulated by `strength`.
guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
noise_level (`float`, defaults to 0.2):
Deteremines the amount of noise to add to the initial image before performing upscaling.
negative_prompt (`str` or `List[str]`, *optional*):
@@ -391,14 +391,14 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`np.random.RandomState`, *optional*):
A np.random.RandomState to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`np.ndarray`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -450,7 +450,7 @@ def __call__(
generator = np.random
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
@@ -481,7 +481,7 @@ def __call__(
timesteps = self.scheduler.timesteps
# Scale the initial noise by the standard deviation required by the scheduler
- latents = latents * np.float64(self.scheduler.init_noise_sigma)
+ latents = latents * self.scheduler.init_noise_sigma
# 5. Add noise to image
noise_level = np.array([noise_level]).astype(np.int64)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 6e93c34929de..cb97f18efeff 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -70,7 +70,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -608,7 +608,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -757,7 +757,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -836,8 +836,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -867,7 +867,7 @@ def __call__(
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
@@ -1034,7 +1034,8 @@ def __call__(
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ if hasattr(self.scheduler, "scale_model_input"):
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(
@@ -1053,7 +1054,7 @@ def __call__(
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
index f158c41cac53..e957c6661f87 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -414,7 +414,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -617,7 +617,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -684,8 +684,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
index e0268065a415..d47e2f0593dd 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -197,7 +197,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -293,8 +293,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -359,7 +359,7 @@ def __call__(
batch_size = image.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 901dcd6db012..95d3ab06f02a 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -639,7 +639,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -837,7 +837,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -922,8 +922,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 6f4e7f358952..148d7386a732 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -581,7 +581,7 @@ def run_safety_checker(self, image, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -660,7 +660,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -668,7 +668,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
@@ -859,7 +859,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -968,8 +968,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -1226,7 +1226,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
index 7857bc58a8ad..843d25d67c10 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The InstructPix2Pix Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 The InstructPix2Pix Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -221,8 +221,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -401,7 +401,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_image`: {num_channels_image} "
- f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
@@ -742,7 +742,7 @@ def run_safety_checker(self, image, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -910,7 +910,7 @@ def num_timesteps(self):
return self._num_timesteps
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
index c6967bc393b5..66d5ffa6b849 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -433,8 +433,8 @@ def __call__(
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -516,7 +516,7 @@ def __call__(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
@@ -600,7 +600,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_image`: {num_channels_image} "
- f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
index dae4540ebe00..f13cdc67073f 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,10 +24,6 @@
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
-from ...models.attention_processor import (
- AttnProcessor2_0,
- XFormersAttnProcessor,
-)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
from ...utils import (
@@ -404,7 +400,7 @@ def encode_prompt(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -529,21 +525,12 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
return latents
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
@torch.no_grad()
def __call__(
@@ -587,8 +574,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -677,7 +664,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
@@ -740,7 +727,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_image`: {num_channels_image} "
- f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
index be01e0acbf18..a134244e3ee4 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -492,7 +492,7 @@ def decode_latents(self, latents):
def prepare_prior_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the prior_scheduler step, since not all prior_schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other prior_schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.prior_scheduler.step).parameters.keys())
@@ -510,7 +510,7 @@ def prepare_prior_extra_step_kwargs(self, generator, eta):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -693,8 +693,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -774,7 +774,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
prior_do_classifier_free_guidance = prior_guidance_scale > 1.0
@@ -842,7 +842,7 @@ def __call__(
# done prior
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
index eac9945ff349..abb4cc3a05d5 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -454,7 +454,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -674,8 +674,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -753,7 +753,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py
index 3a0e86409e4a..65daafe01237 100644
--- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py
+++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
import torch.nn as nn
from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
-from ...utils import logging
+from ...utils import is_transformers_version, logging
logger = logging.get_logger(__name__)
@@ -46,6 +46,9 @@ def __init__(self, config: CLIPConfig):
self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
+ # Model requires post_init after transformers v4.57.3
+ if is_transformers_version(">", "4.57.3"):
+ self.post_init()
@torch.no_grad()
def forward(self, clip_input, images):
diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
index 571a4f2d7710..55d5023b6869 100644
--- a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
+++ b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py b/src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py
index 3fc6b3a3f8b0..ffd66792fe46 100644
--- a/src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py
+++ b/src/diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index 4618d384cbd7..660d9801df56 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
+# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,6 +25,7 @@
T5TokenizerFast,
)
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
@@ -184,7 +185,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
_optional_components = ["image_encoder", "feature_extractor"]
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "pooled_prompt_embeds"]
def __init__(
self,
@@ -247,7 +248,7 @@ def _get_t5_prompt_embeds(
return torch.zeros(
(
batch_size * num_images_per_prompt,
- self.tokenizer_max_length,
+ max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -335,7 +336,7 @@ def _get_clip_prompt_embeds(
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds
@@ -673,7 +674,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -830,11 +831,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -853,7 +854,7 @@ def __call__(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -923,6 +924,9 @@ def __call__(
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
@@ -1109,10 +1113,7 @@ def __call__(
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
- negative_pooled_prompt_embeds = callback_outputs.pop(
- "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
- )
+ pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
index 19bdc9792e23..9b11bc8781e7 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Stability AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -272,7 +272,7 @@ def _get_t5_prompt_embeds(
return torch.zeros(
(
batch_size * num_images_per_prompt,
- self.tokenizer_max_length,
+ max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -361,7 +361,7 @@ def _get_clip_prompt_embeds(
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds
@@ -734,7 +734,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -886,11 +886,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -909,7 +909,7 @@ def __call__(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
index c69fb90a4c5e..b947cbff0914 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Stability AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -278,7 +278,7 @@ def _get_t5_prompt_embeds(
return torch.zeros(
(
batch_size * num_images_per_prompt,
- self.tokenizer_max_length,
+ max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -367,7 +367,7 @@ def _get_clip_prompt_embeds(
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds
@@ -822,7 +822,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -984,7 +984,7 @@ def __call__(
1)`, or `(H, W)`.
mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
`Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
- latents tensor will ge generated by `mask_image`.
+ latents tensor will be generated by `mask_image`.
height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
@@ -1010,11 +1010,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -1033,7 +1033,7 @@ def __call__(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -1247,7 +1247,7 @@ def __call__(
# match the inpainting pipeline and will be updated with input + mask inpainting model later
if num_channels_transformer == 33:
- # default case for runwayml/stable-diffusion-inpainting
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if (
@@ -1258,7 +1258,7 @@ def __call__(
f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.transformer` or your `mask_image` or `image` input."
)
elif num_channels_transformer != 16:
diff --git a/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
index 351b146fb423..a1ff99b6aa34 100644
--- a/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
+++ b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -37,7 +37,7 @@
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -179,7 +179,9 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
return hidden_states
-class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin):
+class StableDiffusionAttendAndExcitePipeline(
+ DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin
+):
r"""
Pipeline for text-to-image generation using Stable Diffusion and Attend-and-Excite.
@@ -209,6 +211,8 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, StableDiffusionM
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
+ _last_supported_version = "0.33.1"
+
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
@@ -502,7 +506,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -794,8 +798,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -873,7 +877,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
@@ -1047,7 +1051,7 @@ def __call__(
class GaussianSmoothing(torch.nn.Module):
"""
Arguments:
- Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed seperately for each channel in the input
+ Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed separately for each channel in the input
using a depthwise convolution.
channels (int, sequence): Number of channels of the input tensors. Output will
have this number of channels as well.
diff --git a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
index 4b999662a6e7..65c25ffbe492 100644
--- a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
+++ b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
@@ -1,4 +1,4 @@
-# Copyright 2024 DiffEdit Authors and Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 DiffEdit Authors and Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -40,7 +40,7 @@
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -242,14 +242,14 @@ def preprocess_mask(mask, batch_size: int = 1):
class StableDiffusionDiffEditPipeline(
- DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
+ DeprecatedPipelineMixin,
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ StableDiffusionLoraLoaderMixin,
):
r"""
-
-
- This is an experimental feature!
-
-
+ > [!WARNING] > This is an experimental feature!
Pipeline for text-guided image inpainting using Stable Diffusion and DiffEdit.
@@ -282,6 +282,8 @@ class StableDiffusionDiffEditPipeline(
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
+ _last_supported_version = "0.33.1"
+
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor", "inverse_scheduler"]
_exclude_from_cpu_offload = ["safety_checker"]
@@ -618,7 +620,7 @@ def run_safety_checker(self, image, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -969,7 +971,7 @@ def generate_mask(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
@@ -1176,7 +1178,7 @@ def invert(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
@@ -1349,8 +1351,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -1422,7 +1424,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py
index 4bbb93e44a83..78b026684cfa 100644
--- a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py
+++ b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The GLIGEN Authors and HuggingFace Team. All rights reserved.
+# Copyright 2025 The GLIGEN Authors and HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -36,7 +36,7 @@
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -108,7 +108,7 @@
"""
-class StableDiffusionGLIGENPipeline(DiffusionPipeline, StableDiffusionMixin):
+class StableDiffusionGLIGENPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion with Grounded-Language-to-Image Generation (GLIGEN).
@@ -135,6 +135,8 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline, StableDiffusionMixin):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
+ _last_supported_version = "0.33.1"
+
_optional_components = ["safety_checker", "feature_extractor"]
model_cpu_offload_seq = "text_encoder->unet->vae"
_exclude_from_cpu_offload = ["safety_checker"]
@@ -415,7 +417,7 @@ def run_safety_checker(self, image, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -589,7 +591,7 @@ def __call__(
`gligen_phrases`. Otherwise, it is treated as a generation task on a blank input image.
gligen_scheduled_sampling_beta (`float`, defaults to 0.3):
Scheduled Sampling factor from [GLIGEN: Open-Set Grounded Text-to-Image
- Generation](https://arxiv.org/pdf/2301.07093.pdf). Scheduled Sampling factor is only varied for
+ Generation](https://huggingface.co/papers/2301.07093). Scheduled Sampling factor is only varied for
scheduled sampling during inference for improved quality and controllability.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
@@ -597,8 +599,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -628,7 +630,7 @@ def __call__(
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
@@ -669,7 +671,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py
index 86ef01784057..05cbad139d92 100644
--- a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py
+++ b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The GLIGEN Authors and HuggingFace Team. All rights reserved.
+# Copyright 2025 The GLIGEN Authors and HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -41,7 +41,7 @@
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.clip_image_project_model import CLIPImageProjection
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -160,7 +160,7 @@
"""
-class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionMixin):
+class StableDiffusionGLIGENTextImagePipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion with Grounded-Language-to-Image Generation (GLIGEN).
@@ -175,7 +175,7 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM
tokenizer ([`~transformers.CLIPTokenizer`]):
A `CLIPTokenizer` to tokenize text.
processor ([`~transformers.CLIPProcessor`]):
- A `CLIPProcessor` to procces reference image.
+ A `CLIPProcessor` to process reference image.
image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
Frozen image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
image_project ([`CLIPImageProjection`]):
@@ -193,6 +193,8 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
+ _last_supported_version = "0.33.1"
+
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
@@ -447,7 +449,7 @@ def run_safety_checker(self, image, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -775,7 +777,7 @@ def __call__(
`gligen_phrases`. Otherwise, it is treated as a generation task on a blank input image.
gligen_scheduled_sampling_beta (`float`, defaults to 0.3):
Scheduled Sampling factor from [GLIGEN: Open-Set Grounded Text-to-Image
- Generation](https://arxiv.org/pdf/2301.07093.pdf). Scheduled Sampling factor is only varied for
+ Generation](https://huggingface.co/papers/2301.07093). Scheduled Sampling factor is only varied for
scheduled sampling during inference for improved quality and controllability.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
@@ -783,8 +785,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -854,7 +856,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
index 1f29f577f8e0..feebd6adf8f8 100755
--- a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -42,7 +42,7 @@
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
@@ -64,7 +64,11 @@ def apply_model(self, *args, **kwargs):
class StableDiffusionKDiffusionPipeline(
- DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
+ DeprecatedPipelineMixin,
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ StableDiffusionLoraLoaderMixin,
):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
@@ -77,11 +81,7 @@ class StableDiffusionKDiffusionPipeline(
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
-
-
- This is an experimental pipeline and is likely to change in the future.
-
-
+ > [!WARNING] > This is an experimental pipeline and is likely to change in the future.
Args:
vae ([`AutoencoderKL`]):
@@ -105,6 +105,8 @@ class StableDiffusionKDiffusionPipeline(
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
+ _last_supported_version = "0.33.1"
+
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
@@ -123,7 +125,7 @@ def __init__(
super().__init__()
logger.info(
- f"{self.__class__} is an experimntal pipeline and is likely to change in the future. We recommend to use"
+ f"{self.__class__} is an experimental pipeline and is likely to change in the future. We recommend to use"
" this pipeline for fast experimentation / iteration if needed, but advice to rely on existing pipelines"
" as defined in https://huggingface.co/docs/diffusers/api/schedulers#implemented-schedulers for"
" production settings."
@@ -513,11 +515,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
@@ -525,15 +527,15 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -588,7 +590,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = True
if guidance_scale <= 1.0:
diff --git a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py
index c7c5bd9cff67..f9a8abfcc568 100644
--- a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,22 +33,18 @@
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, UNet2DConditionModel
-from ...models.attention_processor import (
- AttnProcessor2_0,
- FusedAttnProcessor2_0,
- XFormersAttnProcessor,
-)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -88,6 +84,7 @@ def apply_model(self, *args, **kwargs):
class StableDiffusionXLKDiffusionPipeline(
+ DeprecatedPipelineMixin,
DiffusionPipeline,
StableDiffusionMixin,
FromSingleFileMixin,
@@ -95,6 +92,8 @@ class StableDiffusionXLKDiffusionPipeline(
TextualInversionLoaderMixin,
IPAdapterMixin,
):
+ _last_supported_version = "0.33.1"
+
r"""
Pipeline for text-to-image generation using Stable Diffusion XL and k-diffusion.
@@ -542,22 +541,12 @@ def _get_add_time_ids(
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- FusedAttnProcessor2_0,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
@property
def guidance_scale(self):
@@ -568,7 +557,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -629,11 +618,11 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 5.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -649,7 +638,7 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py
index 702f3eda5816..c32121c88c9b 100644
--- a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py
+++ b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The Intel Labs Team Authors and the HuggingFace Team. All rights reserved.
+# Copyright 2025 The Intel Labs Team Authors and the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -37,7 +37,7 @@
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -73,7 +73,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -178,6 +178,7 @@ class LDM3DPipelineOutput(BaseOutput):
class StableDiffusionLDM3DPipeline(
+ DeprecatedPipelineMixin,
DiffusionPipeline,
StableDiffusionMixin,
TextualInversionLoaderMixin,
@@ -185,6 +186,8 @@ class StableDiffusionLDM3DPipeline(
StableDiffusionLoraLoaderMixin,
FromSingleFileMixin,
):
+ _last_supported_version = "0.33.1"
+
r"""
Pipeline for text-to-image and 3D generation using LDM3D.
@@ -573,7 +576,7 @@ def run_safety_checker(self, image, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -723,7 +726,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -800,8 +803,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -988,7 +991,7 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
index ccee6d47b47a..295095947a12 100644
--- a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
+++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
@@ -1,4 +1,4 @@
-# Copyright 2024 MultiDiffusion Authors and The HuggingFace Team. All rights reserved."
+# Copyright 2025 MultiDiffusion Authors and The HuggingFace Team. All rights reserved."
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -33,7 +33,7 @@
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -73,7 +73,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -156,12 +156,15 @@ def retrieve_timesteps(
class StableDiffusionPanoramaPipeline(
+ DeprecatedPipelineMixin,
DiffusionPipeline,
StableDiffusionMixin,
TextualInversionLoaderMixin,
StableDiffusionLoraLoaderMixin,
IPAdapterMixin,
):
+ _last_supported_version = "0.33.1"
+
r"""
Pipeline for text-to-image generation using MultiDiffusion.
@@ -587,7 +590,7 @@ def decode_latents_with_padding(self, latents: torch.Tensor, padding: int = 8) -
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -735,8 +738,8 @@ def get_views(
) -> List[Tuple[int, int, int, int]]:
"""
Generates a list of views based on the given parameters. Here, we define the mappings F_i (see Eq. 7 in the
- MultiDiffusion paper https://arxiv.org/abs/2302.08113). If panorama's height/width < window_size, num_blocks of
- height/width should return 1.
+ MultiDiffusion paper https://huggingface.co/papers/2302.08113). If panorama's height/width < window_size,
+ num_blocks of height/width should return 1.
Args:
panorama_height (int): The height of the panorama.
@@ -854,8 +857,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -962,7 +965,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
@@ -1054,7 +1057,7 @@ def __call__(
# Here, we iterate through different spatial crops of the latents and denoise them. These
# denoised (latent) crops are then averaged to produce the final latent
# for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the
- # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113
+ # MultiDiffusion paper for more details: https://huggingface.co/papers/2302.08113
# Batch views denoise
for j, batch_view in enumerate(views_batch):
vb_size = len(batch_view)
@@ -1113,7 +1116,7 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(
noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale
)
@@ -1144,7 +1147,7 @@ def __call__(
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1
- # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
+ # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://huggingface.co/papers/2302.08113
latents = torch.where(count > 0, value / count, value)
if callback_on_step_end is not None:
diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
index deae82eb8813..d334107b0703 100644
--- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
+++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
@@ -14,7 +14,7 @@
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
from . import StableDiffusionSafePipelineOutput
from .safety_checker import SafeStableDiffusionSafetyChecker
@@ -29,7 +29,9 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-class StableDiffusionPipelineSafe(DiffusionPipeline, StableDiffusionMixin, IPAdapterMixin):
+class StableDiffusionPipelineSafe(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin, IPAdapterMixin):
+ _last_supported_version = "0.33.1"
+
r"""
Pipeline based on the [`StableDiffusionPipeline`] for text-to-image generation using Safe Latent Diffusion.
@@ -358,7 +360,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -561,8 +563,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -632,7 +634,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/stable_diffusion_safe/safety_checker.py b/src/diffusers/pipelines/stable_diffusion_safe/safety_checker.py
index 338e4c65c500..1f6ad5f2a348 100644
--- a/src/diffusers/pipelines/stable_diffusion_safe/safety_checker.py
+++ b/src/diffusers/pipelines/stable_diffusion_safe/safety_checker.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
index e96422073b19..48add535a81d 100644
--- a/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
+++ b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Susung Hong and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Susung Hong and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -34,7 +34,7 @@
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -106,8 +106,12 @@ def __call__(
return hidden_states
-# Modified to get self-attention guidance scale in this paper (https://arxiv.org/pdf/2210.00939.pdf) as an input
-class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, IPAdapterMixin):
+# Modified to get self-attention guidance scale in this paper (https://huggingface.co/papers/2210.00939) as an input
+class StableDiffusionSAGPipeline(
+ DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, IPAdapterMixin
+):
+ _last_supported_version = "0.33.1"
+
r"""
Pipeline for text-to-image generation using Stable Diffusion.
@@ -476,7 +480,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -616,8 +620,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -681,11 +685,11 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# and `sag_scale` is` `s` of equation (16)
- # of the self-attention guidance paper: https://arxiv.org/pdf/2210.00939.pdf
+ # of the self-attention guidance paper: https://huggingface.co/papers/2210.00939
# `sag_scale = 0` means no self-attention guidance
do_self_attention_guidance = sag_scale > 0.0
@@ -802,7 +806,7 @@ def get_map_size(module, input, output):
if do_self_attention_guidance:
# classifier-free guidance produces two chunks of attention map
# and we only use unconditional one according to equation (25)
- # in https://arxiv.org/pdf/2210.00939.pdf
+ # in https://huggingface.co/papers/2210.00939
if do_classifier_free_guidance:
# DDIM-like prediction of x0
pred_x0 = self.pred_x0(latents, noise_pred_uncond, t)
@@ -876,7 +880,7 @@ def get_map_size(module, input, output):
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
def sag_masking(self, original_latents, attn_map, map_size, t, eps):
- # Same masking process as in SAG paper: https://arxiv.org/pdf/2210.00939.pdf
+ # Same masking process as in SAG paper: https://huggingface.co/papers/2210.00939
bh, hw1, hw2 = attn_map.shape
b, latent_channel, latent_h, latent_w = original_latents.shape
h = self.unet.config.attention_head_dim
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
index eb1030f3bb9d..3227fd9a08a4 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 9c69fe65fbdb..e969d2a21a99 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,11 +33,6 @@
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
-from ...models.attention_processor import (
- AttnProcessor2_0,
- FusedAttnProcessor2_0,
- XFormersAttnProcessor,
-)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -90,7 +85,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -598,7 +593,7 @@ def prepare_ip_adapter_image_embeds(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -750,22 +745,12 @@ def _get_add_time_ids(
return add_time_ids
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- FusedAttnProcessor2_0,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
@@ -811,7 +796,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -914,11 +899,11 @@ def __call__(
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
guidance_scale (`float`, *optional*, defaults to 5.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -929,15 +914,15 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -970,9 +955,10 @@ def __call__(
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
- [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
- Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -1229,7 +1215,7 @@ def __call__(
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 08d0b44d613d..8d1da8dc102c 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -34,10 +34,6 @@
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
-from ...models.attention_processor import (
- AttnProcessor2_0,
- XFormersAttnProcessor,
-)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -50,7 +46,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import randn_tensor
+from ...utils.torch_utils import empty_device_cache, randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import StableDiffusionXLPipelineOutput
@@ -93,7 +89,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -544,7 +540,7 @@ def encode_prompt(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -704,7 +700,7 @@ def prepare_latents(
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
- torch.cuda.empty_cache()
+ empty_device_cache()
image = image.to(device=device, dtype=dtype)
@@ -897,21 +893,12 @@ def _get_add_time_ids(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
@@ -957,7 +944,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -1074,11 +1061,11 @@ def __call__(
forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refine Image
Quality**](https://huggingface.co/docs/diffusers/using-diffusers/sdxl#refine-image-quality).
guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -1089,15 +1076,15 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -1130,9 +1117,10 @@ def __call__(
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
- [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
- Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -1420,7 +1408,7 @@ def denoising_value_valid(dnv):
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index 920caf4d24a1..54a1e311804c 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -35,10 +35,6 @@
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
-from ...models.attention_processor import (
- AttnProcessor2_0,
- XFormersAttnProcessor,
-)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -104,7 +100,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -648,7 +644,7 @@ def encode_prompt(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -741,7 +737,7 @@ def check_inputs(
if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image):
raise ValueError(
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
)
if not isinstance(mask_image, PIL.Image.Image):
raise ValueError(
@@ -749,7 +745,7 @@ def check_inputs(
f" {type(mask_image)}."
)
if output_type != "pil":
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
@@ -1002,21 +998,12 @@ def _get_add_time_ids(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
@@ -1062,7 +1049,7 @@ def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -1208,11 +1195,11 @@ def __call__(
forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -1243,15 +1230,15 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1501,7 +1488,7 @@ def denoising_value_valid(dnv):
# 8. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
- # default case for runwayml/stable-diffusion-inpainting
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
@@ -1509,7 +1496,7 @@ def denoising_value_valid(dnv):
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
- f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
@@ -1638,7 +1625,7 @@ def denoising_value_valid(dnv):
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
index aaffe8efa730..5e13362eb3d1 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Harutatsu Akiyama and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Harutatsu Akiyama and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,11 +22,6 @@
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
-from ...models.attention_processor import (
- AttnProcessor2_0,
- FusedAttnProcessor2_0,
- XFormersAttnProcessor,
-)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -104,7 +99,7 @@ def retrieve_latents(
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). See Section 3.4
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
@@ -427,7 +422,7 @@ def encode_prompt(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -590,22 +585,12 @@ def _get_add_time_ids(
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- FusedAttnProcessor2_0,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
@@ -667,11 +652,11 @@ def __call__(
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
guidance_scale (`float`, *optional*, defaults to 5.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
image_guidance_scale (`float`, *optional*, defaults to 1.5):
Image guidance scale is to push the generated image towards the initial image `image`. Image guidance
scale is enabled by setting `image_guidance_scale > 1`. Higher image guidance scale encourages to
@@ -687,15 +672,15 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -728,9 +713,10 @@ def __call__(
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
- [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
- Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -785,7 +771,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0 and image_guidance_scale >= 1.0
@@ -928,7 +914,7 @@ def __call__(
)
if do_classifier_free_guidance and guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
index 8c1af7863e63..6d9053faaec8 100644
--- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
+++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -369,7 +369,7 @@ def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -495,7 +495,7 @@ def __call__(
batch_size = image.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
self._guidance_scale = max_guidance_scale
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
index 6cd0e415e129..1ce6987114a7 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
@@ -1,4 +1,4 @@
-# Copyright 2024 TencentARC and The HuggingFace Team. All rights reserved.
+# Copyright 2025 TencentARC and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -123,7 +123,7 @@ def _preprocess_adapter_image(image, height, width):
image = torch.cat(image, dim=0)
else:
raise ValueError(
- f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but recive: {image[0].ndim}"
+ f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but receive: {image[0].ndim}"
)
return image
@@ -191,7 +191,7 @@ def retrieve_timesteps(
class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
- https://arxiv.org/abs/2302.08453
+ https://huggingface.co/papers/2302.08453
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
@@ -521,7 +521,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -680,7 +680,7 @@ def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -740,11 +740,11 @@ def __call__(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
@@ -752,15 +752,15 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
index 5eacb64d01e3..0ea3ba5046cf 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
@@ -1,4 +1,4 @@
-# Copyright 2024 TencentARC and The HuggingFace Team. All rights reserved.
+# Copyright 2025 TencentARC and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -34,15 +34,12 @@
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, ImageProjection, MultiAdapter, T2IAdapter, UNet2DConditionModel
-from ...models.attention_processor import (
- AttnProcessor2_0,
- XFormersAttnProcessor,
-)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
PIL_INTERPOLATION,
USE_PEFT_BACKEND,
+ deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -121,7 +118,7 @@ def _preprocess_adapter_image(image, height, width):
image = torch.cat(image, dim=0)
else:
raise ValueError(
- f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but recive: {image[0].ndim}"
+ f"Invalid image tensor! Expecting image tensor with 3 or 4 dimension, but receive: {image[0].ndim}"
)
return image
@@ -131,7 +128,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -223,7 +220,7 @@ class StableDiffusionXLAdapterPipeline(
):
r"""
Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
- https://arxiv.org/abs/2302.08453
+ https://huggingface.co/papers/2302.08453
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
@@ -624,7 +621,7 @@ def prepare_ip_adapter_image_embeds(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -779,21 +776,12 @@ def _get_add_time_ids(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
# Copied from diffusers.pipelines.t2i_adapter.pipeline_stable_diffusion_adapter.StableDiffusionAdapterPipeline._default_height_width
def _default_height_width(self, height, width, image):
@@ -859,7 +847,7 @@ def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
@@ -948,11 +936,11 @@ def __call__(
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
guidance_scale (`float`, *optional*, defaults to 5.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -963,15 +951,15 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -1010,9 +998,10 @@ def __call__(
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
- [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
- Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -1266,7 +1255,7 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
index 5c63d66e3133..3ce7b4d1990f 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,7 +33,7 @@
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
-from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
from . import TextToVideoSDPipelineOutput
@@ -68,8 +68,13 @@
class TextToVideoSDPipeline(
- DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
+ DeprecatedPipelineMixin,
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ StableDiffusionLoraLoaderMixin,
):
+ _last_supported_version = "0.33.1"
r"""
Pipeline for text-to-video generation.
@@ -349,7 +354,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -489,8 +494,8 @@ def __call__(
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -550,7 +555,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
index 006c7a79ce0d..9d0b7e3dbc32 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -34,7 +34,7 @@
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
-from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
from . import TextToVideoSDPipelineOutput
@@ -103,8 +103,13 @@ def retrieve_latents(
class VideoToVideoSDPipeline(
- DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
+ DeprecatedPipelineMixin,
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ StableDiffusionLoraLoaderMixin,
):
+ _last_supported_version = "0.33.1"
r"""
Pipeline for text-guided video-to-video generation.
@@ -385,7 +390,7 @@ def decode_latents(self, latents):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -553,8 +558,8 @@ def __call__(
The prompt or prompts to guide what to not include in video generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -609,7 +614,7 @@ def __call__(
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
index df85f470a80b..96316f8e91e5 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
@@ -23,8 +23,8 @@
scale_lora_layers,
unscale_lora_layers,
)
-from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ...utils.torch_utils import empty_device_cache, randn_tensor
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionSafetyChecker
@@ -296,12 +296,14 @@ def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_s
class TextToVideoZeroPipeline(
+ DeprecatedPipelineMixin,
DiffusionPipeline,
StableDiffusionMixin,
TextualInversionLoaderMixin,
StableDiffusionLoraLoaderMixin,
FromSingleFileMixin,
):
+ _last_supported_version = "0.33.1"
r"""
Pipeline for zero-shot text-to-video generation using Stable Diffusion.
@@ -588,8 +590,8 @@ def __call__(
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -610,17 +612,17 @@ def __call__(
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
motion_field_strength_x (`float`, *optional*, defaults to 12):
- Strength of motion in generated video along x-axis. See the [paper](https://arxiv.org/abs/2303.13439),
- Sect. 3.3.1.
+ Strength of motion in generated video along x-axis. See the
+ [paper](https://huggingface.co/papers/2303.13439), Sect. 3.3.1.
motion_field_strength_y (`float`, *optional*, defaults to 12):
- Strength of motion in generated video along y-axis. See the [paper](https://arxiv.org/abs/2303.13439),
- Sect. 3.3.1.
+ Strength of motion in generated video along y-axis. See the
+ [paper](https://huggingface.co/papers/2303.13439), Sect. 3.3.1.
t0 (`int`, *optional*, defaults to 44):
Timestep t0. Should be in the range [0, num_inference_steps - 1]. See the
- [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.
+ [paper](https://huggingface.co/papers/2303.13439), Sect. 3.3.1.
t1 (`int`, *optional*, defaults to 47):
Timestep t0. Should be in the range [t0 + 1, num_inference_steps - 1]. See the
- [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.
+ [paper](https://huggingface.co/papers/2303.13439), Sect. 3.3.1.
frame_ids (`List[int]`, *optional*):
Indexes of the frames that are being generated. This is used when generating longer videos
chunk-by-chunk.
@@ -663,7 +665,7 @@ def __call__(
batch_size = 1 if isinstance(prompt, str) else len(prompt)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
@@ -758,7 +760,7 @@ def __call__(
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
- torch.cuda.empty_cache()
+ empty_device_cache()
if output_type == "latent":
image = latents
@@ -797,7 +799,7 @@ def run_safety_checker(self, image, device, dtype):
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
index 339d5b3a6019..c8dce75e2671 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
@@ -19,23 +19,19 @@
from ...image_processor import VaeImageProcessor
from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
-from ...models.attention_processor import (
- AttnProcessor2_0,
- FusedAttnProcessor2_0,
- XFormersAttnProcessor,
-)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
BaseOutput,
+ deprecate,
is_invisible_watermark_available,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
if is_invisible_watermark_available():
@@ -323,7 +319,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Flawed](https://huggingface.co/papers/2305.08891).
Args:
noise_cfg (`torch.Tensor`):
@@ -346,11 +342,13 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
class TextToVideoZeroSDXLPipeline(
+ DeprecatedPipelineMixin,
DiffusionPipeline,
StableDiffusionMixin,
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
):
+ _last_supported_version = "0.33.1"
r"""
Pipeline for zero-shot text-to-video generation using Stable Diffusion XL.
@@ -439,7 +437,7 @@ def __init__(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -455,22 +453,12 @@ def prepare_extra_step_kwargs(self, generator, eta):
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
def upcast_vae(self):
- dtype = self.vae.dtype
- self.vae.to(dtype=torch.float32)
- use_torch_2_0_or_xformers = isinstance(
- self.vae.decoder.mid_block.attentions[0].processor,
- (
- AttnProcessor2_0,
- XFormersAttnProcessor,
- FusedAttnProcessor2_0,
- ),
+ deprecate(
+ "upcast_vae",
+ "1.0.0",
+ "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.",
)
- # if xformers or torch_2_0 is used attention block does not need
- # to be in float32 which can save lots of memory
- if use_torch_2_0_or_xformers:
- self.vae.post_quant_conv.to(dtype)
- self.vae.decoder.conv_in.to(dtype)
- self.vae.decoder.mid_block.to(dtype)
+ self.vae.to(dtype=torch.float32)
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
def _get_add_time_ids(
@@ -929,7 +917,7 @@ def backward_loop(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if do_classifier_free_guidance and guidance_rescale > 0.0:
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
@@ -1009,11 +997,11 @@ def __call__(
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -1024,8 +1012,8 @@ def __call__(
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
@@ -1049,13 +1037,13 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
motion_field_strength_x (`float`, *optional*, defaults to 12):
- Strength of motion in generated video along x-axis. See the [paper](https://arxiv.org/abs/2303.13439),
- Sect. 3.3.1.
+ Strength of motion in generated video along x-axis. See the
+ [paper](https://huggingface.co/papers/2303.13439), Sect. 3.3.1.
motion_field_strength_y (`float`, *optional*, defaults to 12):
- Strength of motion in generated video along y-axis. See the [paper](https://arxiv.org/abs/2303.13439),
- Sect. 3.3.1.
+ Strength of motion in generated video along y-axis. See the
+ [paper](https://huggingface.co/papers/2303.13439), Sect. 3.3.1.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1074,9 +1062,10 @@ def __call__(
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
guidance_rescale (`float`, *optional*, defaults to 0.7):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
- [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
- Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
@@ -1093,10 +1082,10 @@ def __call__(
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
t0 (`int`, *optional*, defaults to 44):
Timestep t0. Should be in the range [0, num_inference_steps - 1]. See the
- [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.
+ [paper](https://huggingface.co/papers/2303.13439), Sect. 3.3.1.
t1 (`int`, *optional*, defaults to 47):
Timestep t0. Should be in the range [t0 + 1, num_inference_steps - 1]. See the
- [paper](https://arxiv.org/abs/2303.13439), Sect. 3.3.1.
+ [paper](https://huggingface.co/papers/2303.13439), Sect. 3.3.1.
Returns:
[`~pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoSDXLPipelineOutput`] or
@@ -1153,7 +1142,7 @@ def __call__(
)
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py
index bf42d44f74c1..bbb9b0eb3ab2 100644
--- a/src/diffusers/pipelines/unclip/pipeline_unclip.py
+++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Kakao Brain and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Kakao Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@
from ...schedulers import UnCLIPScheduler
from ...utils import is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
from .text_proj import UnCLIPTextProjModel
@@ -38,7 +38,7 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-class UnCLIPPipeline(DiffusionPipeline):
+class UnCLIPPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
"""
Pipeline for text-to-image generation using unCLIP.
@@ -69,6 +69,7 @@ class UnCLIPPipeline(DiffusionPipeline):
"""
+ _last_supported_version = "0.33.1"
_exclude_from_cpu_offload = ["prior"]
prior: PriorTransformer
diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
index 8fa0a848f7e7..31710a000e0a 100644
--- a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
+++ b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Kakao Brain and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Kakao Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,7 +29,7 @@
from ...schedulers import UnCLIPScheduler
from ...utils import is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
from .text_proj import UnCLIPTextProjModel
@@ -43,7 +43,7 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-class UnCLIPImageVariationPipeline(DiffusionPipeline):
+class UnCLIPImageVariationPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
"""
Pipeline to generate image variations from an input image using UnCLIP.
@@ -73,6 +73,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
Scheduler used in the super resolution denoising process (a modified [`DDPMScheduler`]).
"""
+ _last_supported_version = "0.33.1"
decoder: UNet2DConditionModel
text_proj: UnCLIPTextProjModel
text_encoder: CLIPTextModelWithProjection
diff --git a/src/diffusers/pipelines/unclip/text_proj.py b/src/diffusers/pipelines/unclip/text_proj.py
index 5a86d0c08a8d..5e04e48ba621 100644
--- a/src/diffusers/pipelines/unclip/text_proj.py
+++ b/src/diffusers/pipelines/unclip/text_proj.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Kakao Brain and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Kakao Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@ class UnCLIPTextProjModel(ModelMixin, ConfigMixin):
Utility class for CLIP embeddings. Used to combine the image and text embeddings into a format usable by the
decoder.
- For more details, see the original paper: https://arxiv.org/abs/2204.06125 section 2.1
+ For more details, see the original paper: https://huggingface.co/papers/2204.06125 section 2.1
"""
@register_to_config
diff --git a/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py
index 75e5d43678d5..0ddcbf735770 100644
--- a/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py
+++ b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py
@@ -13,7 +13,7 @@
# Modified from ClipCaptionModel in https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py
class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
"""
- Text decoder model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to
+ Text decoder model for a image-text [UniDiffuser](https://huggingface.co/papers/2303.06555) model. This is used to
generate text from the UniDiffuser image-text embedding.
Parameters:
@@ -140,7 +140,7 @@ def forward(
input_ids (`torch.Tensor` of shape `(N, max_seq_len)`):
Text tokens to use for inference.
prefix_embeds (`torch.Tensor` of shape `(N, prefix_length, 768)`):
- Prefix embedding to preprend to the embedded tokens.
+ Prefix embedding to prepend to the embedded tokens.
attention_mask (`torch.Tensor` of shape `(N, prefix_length + max_seq_len, 768)`, *optional*):
Attention mask for the prefix embedding.
labels (`torch.Tensor`, *optional*):
diff --git a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py
index 1e285a9670e2..2a04ec2e4030 100644
--- a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py
+++ b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py
@@ -832,7 +832,7 @@ def forward(
class UniDiffuserModel(ModelMixin, ConfigMixin):
"""
- Transformer model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is a
+ Transformer model for a image-text [UniDiffuser](https://huggingface.co/papers/2303.06555) model. This is a
modification of [`UTransformer2DModel`] with input and output heads for the VAE-embedded latent image, the
CLIP-embedded image, and the CLIP-embedded prompt (see paper for more details).
diff --git a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
index 66d7404fb9a5..f9298d5b86f8 100644
--- a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
+++ b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
@@ -28,7 +28,7 @@
)
from ...utils.outputs import BaseOutput
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline
from .modeling_text_decoder import UniDiffuserTextDecoder
from .modeling_uvit import UniDiffuserModel
@@ -62,7 +62,7 @@ class ImageTextPipelineOutput(BaseOutput):
text: Optional[Union[List[str], List[List[str]]]]
-class UniDiffuserPipeline(DiffusionPipeline):
+class UniDiffuserPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
r"""
Pipeline for a bimodal image-text model which supports unconditional text and image generation, text-conditioned
image generation, image-conditioned text generation, and joint image-text generation.
@@ -96,6 +96,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
original UniDiffuser paper uses the [`DPMSolverMultistepScheduler`] scheduler.
"""
+ _last_supported_version = "0.33.1"
# TODO: support for moving submodules for components with enable_model_cpu_offload
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae->text_decoder"
@@ -153,7 +154,7 @@ def __init__(
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -231,6 +232,12 @@ def enable_vae_slicing(self):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
# Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing
@@ -239,6 +246,12 @@ def disable_vae_slicing(self):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
# Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_tiling
@@ -248,6 +261,12 @@ def enable_vae_tiling(self):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
# Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_tiling
@@ -256,6 +275,12 @@ def disable_vae_tiling(self):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
# Functions to manually set the mode
@@ -803,7 +828,7 @@ def _split(self, x, height, width):
def _combine(self, img_vae, img_clip):
r"""
- Combines a latent iamge img_vae of shape (B, C, H, W) and a CLIP-embedded image img_clip of shape (B, 1,
+ Combines a latent image img_vae of shape (B, C, H, W) and a CLIP-embedded image img_clip of shape (B, 1,
clip_img_dim) into a single tensor of shape (B, C * H * W + clip_img_dim).
"""
img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1))
@@ -1154,8 +1179,8 @@ def __call__(
`text` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are
supplied, `min(num_images_per_prompt, num_prompts_per_image)` samples are generated.
eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
@@ -1243,7 +1268,7 @@ def __call__(
reduce_text_emb_dim = self.text_intermediate_dim < self.text_encoder_hidden_size or self.mode != "text2img"
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
# Note that this differs from the formulation in the unidiffusers paper!
do_classifier_free_guidance = guidance_scale > 1.0
diff --git a/src/diffusers/pipelines/visualcloze/__init__.py b/src/diffusers/pipelines/visualcloze/__init__.py
new file mode 100644
index 000000000000..ab765a1bbad9
--- /dev/null
+++ b/src/diffusers/pipelines/visualcloze/__init__.py
@@ -0,0 +1,52 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_visualcloze_combined"] = ["VisualClozePipeline"]
+ _import_structure["pipeline_visualcloze_generation"] = ["VisualClozeGenerationPipeline"]
+
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_visualcloze_combined import VisualClozePipeline
+ from .pipeline_visualcloze_generation import VisualClozeGenerationPipeline
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py b/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py
new file mode 100644
index 000000000000..91a54e1ae82f
--- /dev/null
+++ b/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py
@@ -0,0 +1,440 @@
+# Copyright 2025 VisualCloze team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from PIL import Image
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+
+from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...models.autoencoders import AutoencoderKL
+from ...models.transformers import FluxTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ..flux.pipeline_flux_fill import FluxFillPipeline as VisualClozeUpsamplingPipeline
+from ..flux.pipeline_output import FluxPipelineOutput
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_visualcloze_generation import VisualClozeGenerationPipeline
+
+
+if is_torch_xla_available():
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import VisualClozePipeline
+ >>> from diffusers.utils import load_image
+
+ >>> image_paths = [
+ ... # in-context examples
+ ... [
+ ... load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg"
+ ... ),
+ ... load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg"
+ ... ),
+ ... ],
+ ... # query with the target image
+ ... [
+ ... load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg"
+ ... ),
+ ... None, # No image needed for the target image
+ ... ],
+ ... ]
+ >>> task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding."
+ >>> content_prompt = "Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape. The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible. Its plumage is a mix of dark brown and golden hues, with intricate feather details. The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere. The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field, soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background, tranquil, majestic, wildlife photography."
+ >>> pipe = VisualClozePipeline.from_pretrained(
+ ... "VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> image = pipe(
+ ... task_prompt=task_prompt,
+ ... content_prompt=content_prompt,
+ ... image=image_paths,
+ ... upsampling_width=1344,
+ ... upsampling_height=768,
+ ... upsampling_strength=0.4,
+ ... guidance_scale=30,
+ ... num_inference_steps=30,
+ ... max_sequence_length=512,
+ ... generator=torch.Generator("cpu").manual_seed(0),
+ ... ).images[0][0]
+ >>> image.save("visualcloze.png")
+ ```
+"""
+
+
+class VisualClozePipeline(
+ DiffusionPipeline,
+ FluxLoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+):
+ r"""
+ The VisualCloze pipeline for image generation with visual context. Reference:
+ https://github.com/lzyhha/VisualCloze/tree/main. This pipeline is designed to generate images based on visual
+ in-context examples.
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ resolution (`int`, *optional*, defaults to 384):
+ The resolution of each image when concatenating images from the query and in-context examples.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ resolution: int = 384,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.generation_pipe = VisualClozeGenerationPipeline(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ resolution=resolution,
+ )
+ self.upsampling_pipe = VisualClozeUpsamplingPipeline(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ def check_inputs(
+ self,
+ image,
+ task_prompt,
+ content_prompt,
+ upsampling_height,
+ upsampling_width,
+ strength,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if upsampling_height is not None and upsampling_height % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`upsampling_height`has to be divisible by {self.vae_scale_factor * 2} but are {upsampling_height}. Dimensions will be resized accordingly"
+ )
+ if upsampling_width is not None and upsampling_width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`upsampling_width` have to be divisible by {self.vae_scale_factor * 2} but are {upsampling_width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ # Validate prompt inputs
+ if (task_prompt is not None or content_prompt is not None) and prompt_embeds is not None:
+ raise ValueError("Cannot provide both text `task_prompt` + `content_prompt` and `prompt_embeds`. ")
+
+ if task_prompt is None and content_prompt is None and prompt_embeds is None:
+ raise ValueError("Must provide either `task_prompt` + `content_prompt` or pre-computed `prompt_embeds`. ")
+
+ # Validate prompt types and consistency
+ if task_prompt is None:
+ raise ValueError("`task_prompt` is missing.")
+
+ if task_prompt is not None and not isinstance(task_prompt, (str, list)):
+ raise ValueError(f"`task_prompt` must be str or list, got {type(task_prompt)}")
+
+ if content_prompt is not None and not isinstance(content_prompt, (str, list)):
+ raise ValueError(f"`content_prompt` must be str or list, got {type(content_prompt)}")
+
+ if isinstance(task_prompt, list) or isinstance(content_prompt, list):
+ if not isinstance(task_prompt, list) or not isinstance(content_prompt, list):
+ raise ValueError(
+ f"`task_prompt` and `content_prompt` must both be lists, or both be of type str or None, "
+ f"got {type(task_prompt)} and {type(content_prompt)}"
+ )
+ if len(content_prompt) != len(task_prompt):
+ raise ValueError("`task_prompt` and `content_prompt` must have the same length whe they are lists.")
+
+ for sample in image:
+ if not isinstance(sample, list) or not isinstance(sample[0], list):
+ raise ValueError("Each sample in the batch must have a 2D list of images.")
+ if len({len(row) for row in sample}) != 1:
+ raise ValueError("Each in-context example and query should contain the same number of images.")
+ if not any(img is None for img in sample[-1]):
+ raise ValueError("There are no targets in the query, which should be represented as None.")
+ for row in sample[:-1]:
+ if any(img is None for img in row):
+ raise ValueError("Images are missing in in-context examples.")
+
+ # Validate embeddings
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ # Validate sequence length
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"max_sequence_length cannot exceed 512, got {max_sequence_length}")
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ task_prompt: Union[str, List[str]] = None,
+ content_prompt: Union[str, List[str]] = None,
+ image: Optional[torch.FloatTensor] = None,
+ upsampling_height: Optional[int] = None,
+ upsampling_width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 30.0,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ upsampling_strength: float = 1.0,
+ ):
+ r"""
+ Function invoked when calling the VisualCloze pipeline for generation.
+
+ Args:
+ task_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to define the task intention.
+ content_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to define the content or caption of the target image to be generated.
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
+ upsampling_height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image (i.e., output image) after upsampling via SDEdit. By
+ default, the image is upsampled by a factor of three, and the base resolution is determined by the
+ resolution parameter of the pipeline. When only one of `upsampling_height` or `upsampling_width` is
+ specified, the other will be automatically set based on the aspect ratio.
+ upsampling_width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image (i.e., output image) after upsampling via SDEdit. By
+ default, the image is upsampled by a factor of three, and the base resolution is determined by the
+ resolution parameter of the pipeline. When only one of `upsampling_height` or `upsampling_width` is
+ specified, the other will be automatically set based on the aspect ratio.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 30.0):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+ upsampling_strength (`float`, *optional*, defaults to 1.0):
+ Indicates extent to transform the reference `image` when upsampling the results. Must be between 0 and
+ 1. The generated image is used as a starting point and more noise is added the higher the
+ `upsampling_strength`. The number of denoising steps depends on the amount of noise initially added.
+ When `upsampling_strength` is 1, added noise is maximum and the denoising process runs for the full
+ number of iterations specified in `num_inference_steps`. A value of 0 skips the upsampling step and
+ output the results at the resolution of `self.resolution`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ generation_output = self.generation_pipe(
+ task_prompt=task_prompt,
+ content_prompt=content_prompt,
+ image=image,
+ num_inference_steps=num_inference_steps,
+ sigmas=sigmas,
+ guidance_scale=guidance_scale,
+ num_images_per_prompt=num_images_per_prompt,
+ generator=generator,
+ latents=latents,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ joint_attention_kwargs=joint_attention_kwargs,
+ callback_on_step_end=callback_on_step_end,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ output_type=output_type if upsampling_strength == 0 else "pil",
+ )
+ if upsampling_strength == 0:
+ if not return_dict:
+ return (generation_output,)
+
+ return FluxPipelineOutput(images=generation_output)
+
+ # Upsampling the generated images
+ # 1. Prepare the input images and prompts
+ if not isinstance(content_prompt, (list)):
+ content_prompt = [content_prompt]
+ n_target_per_sample = []
+ upsampling_image = []
+ upsampling_mask = []
+ upsampling_prompt = []
+ upsampling_generator = generator if isinstance(generator, (torch.Generator,)) else []
+ for i in range(len(generation_output.images)):
+ n_target_per_sample.append(len(generation_output.images[i]))
+ for image in generation_output.images[i]:
+ upsampling_image.append(image)
+ upsampling_mask.append(Image.new("RGB", image.size, (255, 255, 255)))
+ upsampling_prompt.append(
+ content_prompt[i % len(content_prompt)] if content_prompt[i % len(content_prompt)] else ""
+ )
+ if not isinstance(generator, (torch.Generator,)):
+ upsampling_generator.append(generator[i % len(content_prompt)])
+
+ # 2. Apply the denosing loop
+ upsampling_output = self.upsampling_pipe(
+ prompt=upsampling_prompt,
+ image=upsampling_image,
+ mask_image=upsampling_mask,
+ height=upsampling_height,
+ width=upsampling_width,
+ strength=upsampling_strength,
+ num_inference_steps=num_inference_steps,
+ sigmas=sigmas,
+ guidance_scale=guidance_scale,
+ generator=upsampling_generator,
+ output_type=output_type,
+ joint_attention_kwargs=joint_attention_kwargs,
+ callback_on_step_end=callback_on_step_end,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+ image = upsampling_output.images
+
+ output = []
+ if output_type == "pil":
+ # Each sample in the batch may have multiple output images. When returning as PIL images,
+ # these images cannot be concatenated. Therefore, for each sample,
+ # a list is used to represent all the output images.
+ output = []
+ start = 0
+ for n in n_target_per_sample:
+ output.append(image[start : start + n])
+ start += n
+ else:
+ output = image
+
+ if not return_dict:
+ return (output,)
+
+ return FluxPipelineOutput(images=output)
diff --git a/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py b/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py
new file mode 100644
index 000000000000..e12995106bcf
--- /dev/null
+++ b/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py
@@ -0,0 +1,977 @@
+# Copyright 2025 VisualCloze team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+
+from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...models.autoencoders import AutoencoderKL
+from ...models.transformers import FluxTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..flux.pipeline_flux_fill import calculate_shift, retrieve_latents, retrieve_timesteps
+from ..flux.pipeline_output import FluxPipelineOutput
+from ..pipeline_utils import DiffusionPipeline
+from .visualcloze_utils import VisualClozeProcessor
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import VisualClozeGenerationPipeline, FluxFillPipeline as VisualClozeUpsamplingPipeline
+ >>> from diffusers.utils import load_image
+ >>> from PIL import Image
+
+ >>> image_paths = [
+ ... # in-context examples
+ ... [
+ ... load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg"
+ ... ),
+ ... load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg"
+ ... ),
+ ... ],
+ ... # query with the target image
+ ... [
+ ... load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg"
+ ... ),
+ ... None, # No image needed for the target image
+ ... ],
+ ... ]
+ >>> task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding."
+ >>> content_prompt = "Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape. The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible. Its plumage is a mix of dark brown and golden hues, with intricate feather details. The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere. The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field, soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background, tranquil, majestic, wildlife photography."
+ >>> pipe = VisualClozeGenerationPipeline.from_pretrained(
+ ... "VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> image = pipe(
+ ... task_prompt=task_prompt,
+ ... content_prompt=content_prompt,
+ ... image=image_paths,
+ ... guidance_scale=30,
+ ... num_inference_steps=30,
+ ... max_sequence_length=512,
+ ... generator=torch.Generator("cpu").manual_seed(0),
+ ... ).images[0][0]
+
+ >>> # optional, upsampling the generated image
+ >>> pipe_upsample = VisualClozeUpsamplingPipeline.from_pipe(pipe)
+ >>> pipe_upsample.to("cuda")
+
+ >>> mask_image = Image.new("RGB", image.size, (255, 255, 255))
+
+ >>> image = pipe_upsample(
+ ... image=image,
+ ... mask_image=mask_image,
+ ... prompt=content_prompt,
+ ... width=1344,
+ ... height=768,
+ ... strength=0.4,
+ ... guidance_scale=30,
+ ... num_inference_steps=30,
+ ... max_sequence_length=512,
+ ... generator=torch.Generator("cpu").manual_seed(0),
+ ... ).images[0]
+
+ >>> image.save("visualcloze.png")
+ ```
+"""
+
+
+class VisualClozeGenerationPipeline(
+ DiffusionPipeline,
+ FluxLoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+):
+ r"""
+ The VisualCloze pipeline for image generation with visual context. Reference:
+ https://github.com/lzyhha/VisualCloze/tree/main This pipeline is designed to generate images based on visual
+ in-context examples.
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ resolution (`int`, *optional*, defaults to 384):
+ The resolution of each image when concatenating images from the query and in-context examples.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ resolution: int = 384,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.resolution = resolution
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+ self.image_processor = VisualClozeProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels, resolution=resolution
+ )
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ # Modified from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ layout_prompt: Union[str, List[str]],
+ task_prompt: Union[str, List[str]],
+ content_prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ layout_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to define the number of in-context examples and the number of images involved in
+ the task.
+ task_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to define the task intention.
+ content_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to define the content or caption of the target image to be generated.
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ if isinstance(layout_prompt, str):
+ layout_prompt = [layout_prompt]
+ task_prompt = [task_prompt]
+ content_prompt = [content_prompt]
+
+ def _preprocess(prompt, content=False):
+ if prompt is not None:
+ return f"The last image of the last row depicts: {prompt}" if content else prompt
+ else:
+ return ""
+
+ prompt = [
+ f"{_preprocess(layout_prompt[i])} {_preprocess(task_prompt[i])} {_preprocess(content_prompt[i], content=True)}".strip()
+ for i in range(len(layout_prompt))
+ ]
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ def check_inputs(
+ self,
+ image,
+ task_prompt,
+ content_prompt,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ # Validate prompt inputs
+ if (task_prompt is not None or content_prompt is not None) and prompt_embeds is not None:
+ raise ValueError("Cannot provide both text `task_prompt` + `content_prompt` and `prompt_embeds`. ")
+
+ if task_prompt is None and content_prompt is None and prompt_embeds is None:
+ raise ValueError("Must provide either `task_prompt` + `content_prompt` or pre-computed `prompt_embeds`. ")
+
+ # Validate prompt types and consistency
+ if task_prompt is None:
+ raise ValueError("`task_prompt` is missing.")
+
+ if task_prompt is not None and not isinstance(task_prompt, (str, list)):
+ raise ValueError(f"`task_prompt` must be str or list, got {type(task_prompt)}")
+
+ if content_prompt is not None and not isinstance(content_prompt, (str, list)):
+ raise ValueError(f"`content_prompt` must be str or list, got {type(content_prompt)}")
+
+ if isinstance(task_prompt, list) or isinstance(content_prompt, list):
+ if not isinstance(task_prompt, list) or not isinstance(content_prompt, list):
+ raise ValueError(
+ f"`task_prompt` and `content_prompt` must both be lists, or both be of type str or None, "
+ f"got {type(task_prompt)} and {type(content_prompt)}"
+ )
+ if len(content_prompt) != len(task_prompt):
+ raise ValueError("`task_prompt` and `content_prompt` must have the same length whe they are lists.")
+
+ for sample in image:
+ if not isinstance(sample, list) or not isinstance(sample[0], list):
+ raise ValueError("Each sample in the batch must have a 2D list of images.")
+ if len({len(row) for row in sample}) != 1:
+ raise ValueError("Each in-context example and query should contain the same number of images.")
+ if not any(img is None for img in sample[-1]):
+ raise ValueError("There are no targets in the query, which should be represented as None.")
+ for row in sample[:-1]:
+ if any(img is None for img in row):
+ raise ValueError("Images are missing in in-context examples.")
+
+ # Validate embeddings
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ # Validate sequence length
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"max_sequence_length cannot exceed 512, got {max_sequence_length}")
+
+ @staticmethod
+ def _prepare_latent_image_ids(image, vae_scale_factor, device, dtype):
+ latent_image_ids = []
+
+ for idx, img in enumerate(image, start=1):
+ img = img.squeeze(0)
+ channels, height, width = img.shape
+
+ num_patches_h = height // vae_scale_factor // 2
+ num_patches_w = width // vae_scale_factor // 2
+
+ patch_ids = torch.zeros(num_patches_h, num_patches_w, 3, device=device, dtype=dtype)
+ patch_ids[..., 0] = idx
+ patch_ids[..., 1] = torch.arange(num_patches_h, device=device, dtype=dtype)[:, None]
+ patch_ids[..., 2] = torch.arange(num_patches_w, device=device, dtype=dtype)[None, :]
+
+ patch_ids = patch_ids.reshape(-1, 3)
+ latent_image_ids.append(patch_ids)
+
+ return torch.cat(latent_image_ids, dim=0)
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ def _unpack_latents(latents, sizes, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ start = 0
+ unpacked_latents = []
+ for i in range(len(sizes)):
+ cur_size = sizes[i]
+ height = cur_size[0][0] // vae_scale_factor
+ width = sum([size[1] for size in cur_size]) // vae_scale_factor
+
+ end = start + (height * width) // 4
+
+ cur_latents = latents[:, start:end]
+ cur_latents = cur_latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ cur_latents = cur_latents.permute(0, 3, 1, 4, 2, 5)
+ cur_latents = cur_latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ unpacked_latents.append(cur_latents)
+
+ start = end
+
+ return unpacked_latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_tiling()
+
+ def _prepare_latents(self, image, mask, gen, vae_scale_factor, device, dtype):
+ """Helper function to prepare latents for a single batch."""
+ # Concatenate images and masks along width dimension
+ image = [torch.cat(img, dim=3).to(device=device, dtype=dtype) for img in image]
+ mask = [torch.cat(m, dim=3).to(device=device, dtype=dtype) for m in mask]
+
+ # Generate latent image IDs
+ latent_image_ids = self._prepare_latent_image_ids(image, vae_scale_factor, device, dtype)
+
+ # For initial encoding, use actual images
+ image_latent = [self._encode_vae_image(img, gen) for img in image]
+ masked_image_latent = [img.clone() for img in image_latent]
+
+ for i in range(len(image_latent)):
+ # Rearrange latents and masks for patch processing
+ num_channels_latents, height, width = image_latent[i].shape[1:]
+ image_latent[i] = self._pack_latents(image_latent[i], 1, num_channels_latents, height, width)
+ masked_image_latent[i] = self._pack_latents(masked_image_latent[i], 1, num_channels_latents, height, width)
+
+ # Rearrange masks for patch processing
+ num_channels_latents, height, width = mask[i].shape[1:]
+ mask[i] = mask[i].view(
+ 1,
+ num_channels_latents,
+ height // vae_scale_factor,
+ vae_scale_factor,
+ width // vae_scale_factor,
+ vae_scale_factor,
+ )
+ mask[i] = mask[i].permute(0, 1, 3, 5, 2, 4)
+ mask[i] = mask[i].reshape(
+ 1,
+ num_channels_latents * (vae_scale_factor**2),
+ height // vae_scale_factor,
+ width // vae_scale_factor,
+ )
+ mask[i] = self._pack_latents(
+ mask[i],
+ 1,
+ num_channels_latents * (vae_scale_factor**2),
+ height // vae_scale_factor,
+ width // vae_scale_factor,
+ )
+
+ # Concatenate along batch dimension
+ image_latent = torch.cat(image_latent, dim=1)
+ masked_image_latent = torch.cat(masked_image_latent, dim=1)
+ mask = torch.cat(mask, dim=1)
+
+ return image_latent, masked_image_latent, mask, latent_image_ids
+
+ def prepare_latents(
+ self,
+ input_image,
+ input_mask,
+ timestep,
+ batch_size,
+ dtype,
+ device,
+ generator,
+ vae_scale_factor,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ # Process each batch
+ masked_image_latents = []
+ image_latents = []
+ masks = []
+ latent_image_ids = []
+
+ for i in range(len(input_image)):
+ _image_latent, _masked_image_latent, _mask, _latent_image_ids = self._prepare_latents(
+ input_image[i],
+ input_mask[i],
+ generator if isinstance(generator, torch.Generator) else generator[i],
+ vae_scale_factor,
+ device,
+ dtype,
+ )
+ masked_image_latents.append(_masked_image_latent)
+ image_latents.append(_image_latent)
+ masks.append(_mask)
+ latent_image_ids.append(_latent_image_ids)
+
+ # Concatenate all batches
+ masked_image_latents = torch.cat(masked_image_latents, dim=0)
+ image_latents = torch.cat(image_latents, dim=0)
+ masks = torch.cat(masks, dim=0)
+
+ # Handle batch size expansion
+ if batch_size > masked_image_latents.shape[0]:
+ if batch_size % masked_image_latents.shape[0] == 0:
+ # Expand batches by repeating
+ additional_image_per_prompt = batch_size // masked_image_latents.shape[0]
+ masked_image_latents = torch.cat([masked_image_latents] * additional_image_per_prompt, dim=0)
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ masks = torch.cat([masks] * additional_image_per_prompt, dim=0)
+ else:
+ raise ValueError(
+ f"Cannot expand batch size from {masked_image_latents.shape[0]} to {batch_size}. "
+ "Batch sizes must be multiples of each other."
+ )
+
+ # Add noise to latents
+ noises = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.scale_noise(image_latents, timestep, noises).to(dtype=dtype)
+
+ # Combine masked latents with masks
+ masked_image_latents = torch.cat((masked_image_latents, masks), dim=-1).to(dtype=dtype)
+
+ return latents, masked_image_latents, latent_image_ids[0]
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ task_prompt: Union[str, List[str]] = None,
+ content_prompt: Union[str, List[str]] = None,
+ image: Optional[torch.FloatTensor] = None,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 30.0,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the VisualCloze pipeline for generation.
+
+ Args:
+ task_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to define the task intention.
+ content_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to define the content or caption of the target image to be generated.
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 30.0):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ image,
+ task_prompt,
+ content_prompt,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ processor_output = self.image_processor.preprocess(
+ task_prompt, content_prompt, image, vae_scale_factor=self.vae_scale_factor
+ )
+
+ # 2. Define call parameters
+ if processor_output["task_prompt"] is not None and isinstance(processor_output["task_prompt"], str):
+ batch_size = 1
+ elif processor_output["task_prompt"] is not None and isinstance(processor_output["task_prompt"], list):
+ batch_size = len(processor_output["task_prompt"])
+
+ device = self._execution_device
+
+ # 3. Prepare prompt embeddings
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
+ layout_prompt=processor_output["layout_prompt"],
+ task_prompt=processor_output["task_prompt"],
+ content_prompt=processor_output["content_prompt"],
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ # Calculate sequence length and shift factor
+ image_seq_len = sum(
+ (size[0] // self.vae_scale_factor // 2) * (size[1] // self.vae_scale_factor // 2)
+ for sample in processor_output["image_size"][0]
+ for size in sample
+ )
+
+ # Calculate noise schedule parameters
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+
+ # Get timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1.0, device)
+
+ # 5. Prepare latent variables
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ latents, masked_image_latents, latent_image_ids = self.prepare_latents(
+ processor_output["init_image"],
+ processor_output["mask"],
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ vae_scale_factor=self.vae_scale_factor,
+ )
+
+ # Calculate warmup steps
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # Prepare guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # Broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ latent_model_input = torch.cat((latents, masked_image_latents), dim=2)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # Compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # Some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # Call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ # XLA optimization
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ # 7. Post-process the image
+ # Crop the target image
+ # Since the generated image is a concatenation of the conditional and target regions,
+ # we need to extract only the target regions based on their positions
+ image = []
+ if output_type == "latent":
+ image = latents
+ else:
+ for b in range(len(latents)):
+ cur_image_size = processor_output["image_size"][b % batch_size]
+ cur_target_position = processor_output["target_position"][b % batch_size]
+ cur_latent = self._unpack_latents(latents[b].unsqueeze(0), cur_image_size, self.vae_scale_factor)[-1]
+ cur_latent = (cur_latent / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ cur_image = self.vae.decode(cur_latent, return_dict=False)[0]
+ cur_image = self.image_processor.postprocess(cur_image, output_type=output_type)[0]
+
+ start = 0
+ cropped = []
+ for i, size in enumerate(cur_image_size[-1]):
+ if cur_target_position[i]:
+ if output_type == "pil":
+ cropped.append(cur_image.crop((start, 0, start + size[1], size[0])))
+ else:
+ cropped.append(cur_image[0 : size[0], start : start + size[1]])
+ start += size[1]
+ image.append(cropped)
+ if output_type != "pil":
+ image = np.concatenate([arr[None] for sub_image in image for arr in sub_image], axis=0)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/visualcloze/visualcloze_utils.py b/src/diffusers/pipelines/visualcloze/visualcloze_utils.py
new file mode 100644
index 000000000000..efe5dff47623
--- /dev/null
+++ b/src/diffusers/pipelines/visualcloze/visualcloze_utils.py
@@ -0,0 +1,251 @@
+# Copyright 2025 VisualCloze team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from PIL import Image
+
+from ...image_processor import VaeImageProcessor
+
+
+class VisualClozeProcessor(VaeImageProcessor):
+ """
+ Image processor for the VisualCloze pipeline.
+
+ This processor handles the preprocessing of images for visual cloze tasks, including resizing, normalization, and
+ mask generation.
+
+ Args:
+ resolution (int, optional):
+ Target resolution for processing images. Each image will be resized to this resolution before being
+ concatenated to avoid the out-of-memory error. Defaults to 384.
+ *args: Additional arguments passed to [~image_processor.VaeImageProcessor]
+ **kwargs: Additional keyword arguments passed to [~image_processor.VaeImageProcessor]
+ """
+
+ def __init__(self, *args, resolution: int = 384, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.resolution = resolution
+
+ def preprocess_image(
+ self, input_images: List[List[Optional[Image.Image]]], vae_scale_factor: int
+ ) -> Tuple[List[List[torch.Tensor]], List[List[List[int]]], List[int]]:
+ """
+ Preprocesses input images for the VisualCloze pipeline.
+
+ This function handles the preprocessing of input images by:
+ 1. Resizing and cropping images to maintain consistent dimensions
+ 2. Converting images to the Tensor format for the VAE
+ 3. Normalizing pixel values
+ 4. Tracking image sizes and positions of target images
+
+ Args:
+ input_images (List[List[Optional[Image.Image]]]):
+ A nested list of PIL Images where:
+ - Outer list represents different samples, including in-context examples and the query
+ - Inner list contains images for the task
+ - In the last row, condition images are provided and the target images are placed as None
+ vae_scale_factor (int):
+ The scale factor used by the VAE for resizing images
+
+ Returns:
+ Tuple containing:
+ - List[List[torch.Tensor]]: Preprocessed images in tensor format
+ - List[List[List[int]]]: Dimensions of each processed image [height, width]
+ - List[int]: Target positions indicating which images are to be generated
+ """
+ n_samples, n_task_images = len(input_images), len(input_images[0])
+ divisible = 2 * vae_scale_factor
+
+ processed_images: List[List[Image.Image]] = [[] for _ in range(n_samples)]
+ resize_size: List[Optional[Tuple[int, int]]] = [None for _ in range(n_samples)]
+ target_position: List[int] = []
+
+ # Process each sample
+ for i in range(n_samples):
+ # Determine size from first non-None image
+ for j in range(n_task_images):
+ if input_images[i][j] is not None:
+ aspect_ratio = input_images[i][j].width / input_images[i][j].height
+ target_area = self.resolution * self.resolution
+ new_h = int((target_area / aspect_ratio) ** 0.5)
+ new_w = int(new_h * aspect_ratio)
+
+ new_w = max(new_w // divisible, 1) * divisible
+ new_h = max(new_h // divisible, 1) * divisible
+ resize_size[i] = (new_w, new_h)
+ break
+
+ # Process all images in the sample
+ for j in range(n_task_images):
+ if input_images[i][j] is not None:
+ target = self._resize_and_crop(input_images[i][j], resize_size[i][0], resize_size[i][1])
+ processed_images[i].append(target)
+ if i == n_samples - 1:
+ target_position.append(0)
+ else:
+ blank = Image.new("RGB", resize_size[i] or (self.resolution, self.resolution), (0, 0, 0))
+ processed_images[i].append(blank)
+ if i == n_samples - 1:
+ target_position.append(1)
+
+ # Ensure consistent width for multiple target images when there are multiple target images
+ if len(target_position) > 1 and sum(target_position) > 1:
+ new_w = resize_size[n_samples - 1][0] or 384
+ for i in range(len(processed_images)):
+ for j in range(len(processed_images[i])):
+ if processed_images[i][j] is not None:
+ new_h = int(processed_images[i][j].height * (new_w / processed_images[i][j].width))
+ new_w = int(new_w / 16) * 16
+ new_h = int(new_h / 16) * 16
+ processed_images[i][j] = self._resize_and_crop(processed_images[i][j], new_h, new_w)
+
+ # Convert to tensors and normalize
+ image_sizes = []
+ for i in range(len(processed_images)):
+ image_sizes.append([[img.height, img.width] for img in processed_images[i]])
+ for j, image in enumerate(processed_images[i]):
+ image = self.pil_to_numpy(image)
+ image = self.numpy_to_pt(image)
+ image = self.normalize(image)
+ processed_images[i][j] = image
+
+ return processed_images, image_sizes, target_position
+
+ def preprocess_mask(
+ self, input_images: List[List[Image.Image]], target_position: List[int]
+ ) -> List[List[torch.Tensor]]:
+ """
+ Generate masks for the VisualCloze pipeline.
+
+ Args:
+ input_images (List[List[Image.Image]]):
+ Processed images from preprocess_image
+ target_position (List[int]):
+ Binary list marking the positions of target images (1 for target, 0 for condition)
+
+ Returns:
+ List[List[torch.Tensor]]:
+ A nested list of mask tensors (1 for target positions, 0 for condition images)
+ """
+ mask = []
+ for i, row in enumerate(input_images):
+ if i == len(input_images) - 1: # Query row
+ row_masks = [
+ torch.full((1, 1, row[0].shape[2], row[0].shape[3]), fill_value=m) for m in target_position
+ ]
+ else: # In-context examples
+ row_masks = [
+ torch.full((1, 1, row[0].shape[2], row[0].shape[3]), fill_value=0) for _ in target_position
+ ]
+ mask.append(row_masks)
+ return mask
+
+ def preprocess_image_upsampling(
+ self,
+ input_images: List[List[Image.Image]],
+ height: int,
+ width: int,
+ ) -> Tuple[List[List[Image.Image]], List[List[List[int]]]]:
+ """Process images for the upsampling stage in the VisualCloze pipeline.
+
+ Args:
+ input_images: Input image to process
+ height: Target height
+ width: Target width
+
+ Returns:
+ Tuple of processed image and its size
+ """
+ image = self.resize(input_images[0][0], height, width)
+ image = self.pil_to_numpy(image) # to np
+ image = self.numpy_to_pt(image) # to pt
+ image = self.normalize(image)
+
+ input_images[0][0] = image
+ image_sizes = [[[height, width]]]
+ return input_images, image_sizes
+
+ def preprocess_mask_upsampling(self, input_images: List[List[Image.Image]]) -> List[List[torch.Tensor]]:
+ return [[torch.ones((1, 1, input_images[0][0].shape[2], input_images[0][0].shape[3]))]]
+
+ def get_layout_prompt(self, size: Tuple[int, int]) -> str:
+ layout_instruction = (
+ f"A grid layout with {size[0]} rows and {size[1]} columns, displaying {size[0] * size[1]} images arranged side by side.",
+ )
+ return layout_instruction
+
+ def preprocess(
+ self,
+ task_prompt: Union[str, List[str]],
+ content_prompt: Union[str, List[str]],
+ input_images: Optional[List[List[List[Optional[str]]]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ upsampling: bool = False,
+ vae_scale_factor: int = 16,
+ ) -> Dict:
+ """Process visual cloze inputs.
+
+ Args:
+ task_prompt: Task description(s)
+ content_prompt: Content description(s)
+ input_images: List of images or None for the target images
+ height: Optional target height for upsampling stage
+ width: Optional target width for upsampling stage
+ upsampling: Whether this is in the upsampling processing stage
+
+ Returns:
+ Dictionary containing processed images, masks, prompts and metadata
+ """
+ if isinstance(task_prompt, str):
+ task_prompt = [task_prompt]
+ content_prompt = [content_prompt]
+ input_images = [input_images]
+
+ output = {
+ "init_image": [],
+ "mask": [],
+ "task_prompt": task_prompt if not upsampling else [None for _ in range(len(task_prompt))],
+ "content_prompt": content_prompt,
+ "layout_prompt": [],
+ "target_position": [],
+ "image_size": [],
+ }
+ for i in range(len(task_prompt)):
+ if upsampling:
+ layout_prompt = None
+ else:
+ layout_prompt = self.get_layout_prompt((len(input_images[i]), len(input_images[i][0])))
+
+ if upsampling:
+ cur_processed_images, cur_image_size = self.preprocess_image_upsampling(
+ input_images[i], height=height, width=width
+ )
+ cur_mask = self.preprocess_mask_upsampling(cur_processed_images)
+ else:
+ cur_processed_images, cur_image_size, cur_target_position = self.preprocess_image(
+ input_images[i], vae_scale_factor=vae_scale_factor
+ )
+ cur_mask = self.preprocess_mask(cur_processed_images, cur_target_position)
+
+ output["target_position"].append(cur_target_position)
+
+ output["image_size"].append(cur_image_size)
+ output["init_image"].append(cur_processed_images)
+ output["mask"].append(cur_mask)
+ output["layout_prompt"].append(layout_prompt)
+
+ return output
diff --git a/src/diffusers/pipelines/wan/__init__.py b/src/diffusers/pipelines/wan/__init__.py
index 80916a8a1e10..ad51a52f9242 100644
--- a/src/diffusers/pipelines/wan/__init__.py
+++ b/src/diffusers/pipelines/wan/__init__.py
@@ -23,7 +23,9 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_wan"] = ["WanPipeline"]
+ _import_structure["pipeline_wan_animate"] = ["WanAnimatePipeline"]
_import_structure["pipeline_wan_i2v"] = ["WanImageToVideoPipeline"]
+ _import_structure["pipeline_wan_vace"] = ["WanVACEPipeline"]
_import_structure["pipeline_wan_video2video"] = ["WanVideoToVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -34,9 +36,10 @@
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_wan import WanPipeline
+ from .pipeline_wan_animate import WanAnimatePipeline
from .pipeline_wan_i2v import WanImageToVideoPipeline
+ from .pipeline_wan_vace import WanVACEPipeline
from .pipeline_wan_video2video import WanVideoToVideoPipeline
-
else:
import sys
diff --git a/src/diffusers/pipelines/wan/image_processor.py b/src/diffusers/pipelines/wan/image_processor.py
new file mode 100644
index 000000000000..b1594d08630f
--- /dev/null
+++ b/src/diffusers/pipelines/wan/image_processor.py
@@ -0,0 +1,185 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+
+from ...configuration_utils import register_to_config
+from ...image_processor import VaeImageProcessor
+from ...utils import PIL_INTERPOLATION
+
+
+class WanAnimateImageProcessor(VaeImageProcessor):
+ r"""
+ Image processor to preprocess the reference (character) image for the Wan Animate model.
+
+ Args:
+ do_resize (`bool`, *optional*, defaults to `True`):
+ Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
+ `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
+ vae_scale_factor (`int`, *optional*, defaults to `8`):
+ VAE (spatial) scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of
+ this factor.
+ vae_latent_channels (`int`, *optional*, defaults to `16`):
+ VAE latent channels.
+ spatial_patch_size (`Tuple[int, int]`, *optional*, defaults to `(2, 2)`):
+ The spatial patch size used by the diffusion transformer. For Wan models, this is typically (2, 2).
+ resample (`str`, *optional*, defaults to `lanczos`):
+ Resampling filter to use when resizing the image.
+ do_normalize (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the image to [-1,1].
+ do_binarize (`bool`, *optional*, defaults to `False`):
+ Whether to binarize the image to 0/1.
+ do_convert_rgb (`bool`, *optional*, defaults to be `False`):
+ Whether to convert the images to RGB format.
+ do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
+ Whether to convert the images to grayscale format.
+ fill_color (`str` or `float` or `Tuple[float, ...]`, *optional*, defaults to `None`):
+ An optional fill color when `resize_mode` is set to `"fill"`. This will fill the empty space with that
+ color instead of filling with data from the image. Any valid `color` argument to `PIL.Image.new` is valid;
+ if `None`, will default to filling with data from `image`.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ do_resize: bool = True,
+ vae_scale_factor: int = 8,
+ vae_latent_channels: int = 16,
+ spatial_patch_size: Tuple[int, int] = (2, 2),
+ resample: str = "lanczos",
+ reducing_gap: int = None,
+ do_normalize: bool = True,
+ do_binarize: bool = False,
+ do_convert_rgb: bool = False,
+ do_convert_grayscale: bool = False,
+ fill_color: Optional[Union[str, float, Tuple[float, ...]]] = 0,
+ ):
+ super().__init__()
+ if do_convert_rgb and do_convert_grayscale:
+ raise ValueError(
+ "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
+ " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
+ " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
+ )
+
+ def _resize_and_fill(
+ self,
+ image: PIL.Image.Image,
+ width: int,
+ height: int,
+ ) -> PIL.Image.Image:
+ r"""
+ Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
+ the image within the dimensions, filling empty with data from image.
+
+ Args:
+ image (`PIL.Image.Image`):
+ The image to resize and fill.
+ width (`int`):
+ The width to resize the image to.
+ height (`int`):
+ The height to resize the image to.
+
+ Returns:
+ `PIL.Image.Image`:
+ The resized and filled image.
+ """
+
+ ratio = width / height
+ src_ratio = image.width / image.height
+ fill_with_image_data = self.config.fill_color is None
+ fill_color = self.config.fill_color or 0
+
+ src_w = width if ratio < src_ratio else image.width * height // image.height
+ src_h = height if ratio >= src_ratio else image.height * width // image.width
+
+ resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample])
+ res = PIL.Image.new("RGB", (width, height), color=fill_color)
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+
+ if fill_with_image_data:
+ if ratio < src_ratio:
+ fill_height = height // 2 - src_h // 2
+ if fill_height > 0:
+ res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
+ res.paste(
+ resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
+ box=(0, fill_height + src_h),
+ )
+ elif ratio > src_ratio:
+ fill_width = width // 2 - src_w // 2
+ if fill_width > 0:
+ res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
+ res.paste(
+ resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
+ box=(fill_width + src_w, 0),
+ )
+
+ return res
+
+ def get_default_height_width(
+ self,
+ image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ ) -> Tuple[int, int]:
+ r"""
+ Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
+
+ Args:
+ image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
+ The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
+ should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
+ tensor, it should have shape `[batch, channels, height, width]`.
+ height (`Optional[int]`, *optional*, defaults to `None`):
+ The height of the preprocessed image. If `None`, the height of the `image` input will be used.
+ width (`Optional[int]`, *optional*, defaults to `None`):
+ The width of the preprocessed image. If `None`, the width of the `image` input will be used.
+
+ Returns:
+ `Tuple[int, int]`:
+ A tuple containing the height and width, both resized to the nearest integer multiple of
+ `vae_scale_factor * spatial_patch_size`.
+ """
+
+ if height is None:
+ if isinstance(image, PIL.Image.Image):
+ height = image.height
+ elif isinstance(image, torch.Tensor):
+ height = image.shape[2]
+ else:
+ height = image.shape[1]
+
+ if width is None:
+ if isinstance(image, PIL.Image.Image):
+ width = image.width
+ elif isinstance(image, torch.Tensor):
+ width = image.shape[3]
+ else:
+ width = image.shape[2]
+
+ max_area = width * height
+ aspect_ratio = height / width
+ mod_value_h = self.config.vae_scale_factor * self.config.spatial_patch_size[0]
+ mod_value_w = self.config.vae_scale_factor * self.config.spatial_patch_size[1]
+
+ # Try to preserve the aspect ratio
+ height = round(np.sqrt(max_area * aspect_ratio)) // mod_value_h * mod_value_h
+ width = round(np.sqrt(max_area / aspect_ratio)) // mod_value_w * mod_value_w
+
+ return height, width
diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py
index 3294e9a56a07..78fe71ea9138 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan.py
@@ -15,7 +15,6 @@
import html
from typing import Any, Callable, Dict, List, Optional, Union
-import ftfy
import regex as re
import torch
from transformers import AutoTokenizer, UMT5EncoderModel
@@ -24,7 +23,7 @@
from ...loaders import WanLoraLoaderMixin
from ...models import AutoencoderKLWan, WanTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
@@ -40,6 +39,9 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+if is_ftfy_available():
+ import ftfy
+
EXAMPLE_DOC_STRING = """
Examples:
@@ -110,18 +112,31 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ transformer_2 ([`WanTransformer3DModel`], *optional*):
+ Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables
+ two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise
+ stages. If not provided, only `transformer` is used.
+ boundary_ratio (`float`, *optional*, defaults to `None`):
+ Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
+ The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
+ `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+ boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
"""
- model_cpu_offload_seq = "text_encoder->transformer->vae"
+ model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+ _optional_components = ["transformer", "transformer_2"]
def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
- transformer: WanTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
+ transformer: Optional[WanTransformer3DModel] = None,
+ transformer_2: Optional[WanTransformer3DModel] = None,
+ boundary_ratio: Optional[float] = None,
+ expand_timesteps: bool = False, # Wan2.2 ti2v
):
super().__init__()
@@ -131,10 +146,12 @@ def __init__(
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
+ transformer_2=transformer_2,
)
-
- self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
- self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.register_to_config(boundary_ratio=boundary_ratio)
+ self.register_to_config(expand_timesteps=expand_timesteps)
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
def _get_t5_prompt_embeds(
@@ -268,6 +285,7 @@ def check_inputs(
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
+ guidance_scale_2=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -300,6 +318,9 @@ def check_inputs(
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+ if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+ raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
+
def prepare_latents(
self,
batch_size: int,
@@ -367,6 +388,7 @@ def __call__(
num_frames: int = 81,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
+ guidance_scale_2: Optional[float] = None,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
@@ -386,8 +408,10 @@ def __call__(
Args:
prompt (`str` or `List[str]`, *optional*):
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
- instead.
+ The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (`guidance_scale` < `1`).
height (`int`, defaults to `480`):
The height in pixels of the generated image.
width (`int`, defaults to `832`):
@@ -398,11 +422,15 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `5.0`):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ guidance_scale_2 (`float`, *optional*, defaults to `None`):
+ Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
+ `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
+ and the pipeline's `boundary_ratio` are not None.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -415,7 +443,7 @@ def __call__(
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
- output_type (`str`, *optional*, defaults to `"pil"`):
+ output_type (`str`, *optional*, defaults to `"np"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
@@ -432,8 +460,9 @@ def __call__(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
- autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
- The dtype to use for the torch.amp.autocast.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
+ truncated. If the prompt is shorter, it will be padded to this length.
Examples:
@@ -456,6 +485,7 @@ def __call__(
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
+ guidance_scale_2,
)
if num_frames % self.vae_scale_factor_temporal != 1:
@@ -465,7 +495,11 @@ def __call__(
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)
+ if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+ guidance_scale_2 = guidance_scale
+
self._guidance_scale = guidance_scale
+ self._guidance_scale_2 = guidance_scale_2
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
@@ -492,7 +526,7 @@ def __call__(
device=device,
)
- transformer_dtype = self.transformer.dtype
+ transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)
if negative_prompt_embeds is not None:
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
@@ -502,7 +536,11 @@ def __call__(
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
- num_channels_latents = self.transformer.config.in_channels
+ num_channels_latents = (
+ self.transformer.config.in_channels
+ if self.transformer is not None
+ else self.transformer_2.config.in_channels
+ )
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
@@ -515,36 +553,61 @@ def __call__(
latents,
)
+ mask = torch.ones(latents.shape, dtype=torch.float32, device=device)
+
# 6. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
+ if self.config.boundary_ratio is not None:
+ boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+ else:
+ boundary_timestep = None
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
- latent_model_input = latents.to(transformer_dtype)
- timestep = t.expand(latents.shape[0])
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- timestep=timestep,
- encoder_hidden_states=prompt_embeds,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ if boundary_timestep is None or t >= boundary_timestep:
+ # wan2.1 or high-noise stage in wan2.2
+ current_model = self.transformer
+ current_guidance_scale = guidance_scale
+ else:
+ # low-noise stage in wan2.2
+ current_model = self.transformer_2
+ current_guidance_scale = guidance_scale_2
- if self.do_classifier_free_guidance:
- noise_uncond = self.transformer(
+ latent_model_input = latents.to(transformer_dtype)
+ if self.config.expand_timesteps:
+ # seq_len: num_latent_frames * latent_height//2 * latent_width//2
+ temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
+ # batch_size, seq_len
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+ else:
+ timestep = t.expand(latents.shape[0])
+
+ with current_model.cache_context("cond"):
+ noise_pred = current_model(
hidden_states=latent_model_input,
timestep=timestep,
- encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
- noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ if self.do_classifier_free_guidance:
+ with current_model.cache_context("uncond"):
+ noise_uncond = current_model(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_animate.py b/src/diffusers/pipelines/wan/pipeline_wan_animate.py
new file mode 100644
index 000000000000..c7c983b2f7d4
--- /dev/null
+++ b/src/diffusers/pipelines/wan/pipeline_wan_animate.py
@@ -0,0 +1,1204 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from copy import deepcopy
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import PIL
+import regex as re
+import torch
+import torch.nn.functional as F
+from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import WanLoraLoaderMixin
+from ...models import AutoencoderKLWan, WanAnimateTransformer3DModel
+from ...schedulers import UniPCMultistepScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .image_processor import WanAnimateImageProcessor
+from .pipeline_output import WanPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> import numpy as np
+ >>> from diffusers import WanAnimatePipeline
+ >>> from diffusers.utils import export_to_video, load_image, load_video
+
+ >>> model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
+ >>> pipe = WanAnimatePipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+ >>> # Optionally upcast the Wan VAE to FP32
+ >>> pipe.vae.to(torch.float32)
+ >>> pipe.to("cuda")
+
+ >>> # Load the reference character image
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
+ ... )
+
+ >>> # Load pose and face videos (preprocessed from reference video)
+ >>> # Note: Videos should be preprocessed to extract pose keypoints and face features
+ >>> # Refer to the Wan-Animate preprocessing documentation for details
+ >>> pose_video = load_video("path/to/pose_video.mp4")
+ >>> face_video = load_video("path/to/face_video.mp4")
+
+ >>> # CFG is generally not used for Wan Animate
+ >>> prompt = (
+ ... "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
+ ... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
+ ... )
+
+ >>> # Animation mode: Animate the character with the motion from pose/face videos
+ >>> output = pipe(
+ ... image=image,
+ ... pose_video=pose_video,
+ ... face_video=face_video,
+ ... prompt=prompt,
+ ... height=height,
+ ... width=width,
+ ... segment_frame_length=77, # Frame length of each inference segment
+ ... guidance_scale=1.0,
+ ... num_inference_steps=20,
+ ... mode="animate",
+ ... ).frames[0]
+ >>> export_to_video(output, "output_animation.mp4", fps=30)
+
+ >>> # Replacement mode: Replace a character in the background video
+ >>> # Requires additional background_video and mask_video inputs
+ >>> background_video = load_video("path/to/background_video.mp4")
+ >>> mask_video = load_video("path/to/mask_video.mp4") # Black areas preserved, white areas generated
+ >>> output = pipe(
+ ... image=image,
+ ... pose_video=pose_video,
+ ... face_video=face_video,
+ ... background_video=background_video,
+ ... mask_video=mask_video,
+ ... prompt=prompt,
+ ... height=height,
+ ... width=width,
+ ... segment_frame_length=77, # Frame length of each inference segment
+ ... guidance_scale=1.0,
+ ... num_inference_steps=20,
+ ... mode="replace",
+ ... ).frames[0]
+ >>> export_to_video(output, "output_replacement.mp4", fps=30)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class WanAnimatePipeline(DiffusionPipeline, WanLoraLoaderMixin):
+ r"""
+ Pipeline for unified character animation and replacement using Wan-Animate.
+
+ WanAnimatePipeline takes a character image, pose video, and face video as input, and generates a video in two
+ modes:
+
+ 1. **Animation mode**: The model generates a video of the character image that mimics the human motion in the input
+ pose and face videos. The character is animated based on the provided motion controls, creating a new animated
+ video of the character.
+
+ 2. **Replacement mode**: The model replaces a character in a background video with the provided character image,
+ using the pose and face videos for motion control. This mode requires additional `background_video` and
+ `mask_video` inputs. The mask video should have black regions where the original content should be preserved and
+ white regions where the new character should be generated.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.WanLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ image_encoder ([`CLIPVisionModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically
+ the
+ [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large)
+ variant.
+ transformer ([`WanAnimateTransformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ image_processor ([`CLIPImageProcessor`]):
+ Image processor for preprocessing images before encoding.
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ vae: AutoencoderKLWan,
+ scheduler: UniPCMultistepScheduler,
+ image_processor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModel,
+ transformer: WanAnimateTransformer3DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ image_encoder=image_encoder,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_processor=image_processor,
+ )
+
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+ self.video_processor_for_mask = VideoProcessor(
+ vae_scale_factor=self.vae_scale_factor_spatial, do_normalize=False, do_convert_grayscale=True
+ )
+ # In case self.transformer is None (e.g. for some pipeline tests)
+ spatial_patch_size = self.transformer.config.patch_size[-2:] if self.transformer is not None else (2, 2)
+ self.vae_image_processor = WanAnimateImageProcessor(
+ vae_scale_factor=self.vae_scale_factor_spatial,
+ spatial_patch_size=spatial_patch_size,
+ resample="bilinear",
+ fill_color=0,
+ )
+ self.image_processor = image_processor
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_image
+ def encode_image(
+ self,
+ image: PipelineImageInput,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+ image = self.image_processor(images=image, return_tensors="pt").to(device)
+ image_embeds = self.image_encoder(**image, output_hidden_states=True)
+ return image_embeds.hidden_states[-2]
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ image,
+ pose_video,
+ face_video,
+ background_video,
+ mask_video,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ image_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ mode=None,
+ prev_segment_conditioning_frames=None,
+ ):
+ if image is not None and image_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ if image is None and image_embeds is None:
+ raise ValueError(
+ "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined."
+ )
+ if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
+ raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
+ if pose_video is None:
+ raise ValueError("Provide `pose_video`. Cannot leave `pose_video` undefined.")
+ if face_video is None:
+ raise ValueError("Provide `face_video`. Cannot leave `face_video` undefined.")
+ if not isinstance(pose_video, list) or not isinstance(face_video, list):
+ raise ValueError("`pose_video` and `face_video` must be lists of PIL images.")
+ if len(pose_video) == 0 or len(face_video) == 0:
+ raise ValueError("`pose_video` and `face_video` must contain at least one frame.")
+ if mode == "replace" and (background_video is None or mask_video is None):
+ raise ValueError(
+ "Provide `background_video` and `mask_video`. Cannot leave both `background_video` and `mask_video`"
+ " undefined when mode is `replace`."
+ )
+ if mode == "replace" and (not isinstance(background_video, list) or not isinstance(mask_video, list)):
+ raise ValueError("`background_video` and `mask_video` must be lists of PIL images when mode is `replace`.")
+
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found"
+ f" {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ if mode is not None and (not isinstance(mode, str) or mode not in ("animate", "replace")):
+ raise ValueError(
+ f"`mode` has to be of type `str` and in ('animate', 'replace') but its type is {type(mode)} and value is {mode}"
+ )
+
+ if prev_segment_conditioning_frames is not None and (
+ not isinstance(prev_segment_conditioning_frames, int) or prev_segment_conditioning_frames not in (1, 5)
+ ):
+ raise ValueError(
+ f"`prev_segment_conditioning_frames` has to be of type `int` and 1 or 5 but its type is"
+ f" {type(prev_segment_conditioning_frames)} and value is {prev_segment_conditioning_frames}"
+ )
+
+ def get_i2v_mask(
+ self,
+ batch_size: int,
+ latent_t: int,
+ latent_h: int,
+ latent_w: int,
+ mask_len: int = 1,
+ mask_pixel_values: Optional[torch.Tensor] = None,
+ dtype: Optional[torch.dtype] = None,
+ device: Union[str, torch.device] = "cuda",
+ ) -> torch.Tensor:
+ # mask_pixel_values shape (if supplied): [B, C = 1, T, latent_h, latent_w]
+ if mask_pixel_values is None:
+ mask_lat_size = torch.zeros(
+ batch_size, 1, (latent_t - 1) * 4 + 1, latent_h, latent_w, dtype=dtype, device=device
+ )
+ else:
+ mask_lat_size = mask_pixel_values.clone().to(device=device, dtype=dtype)
+ mask_lat_size[:, :, :mask_len] = 1
+ first_frame_mask = mask_lat_size[:, :, 0:1]
+ # Repeat first frame mask self.vae_scale_factor_temporal (= 4) times in the frame dimension
+ first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
+ mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:]], dim=2)
+ mask_lat_size = mask_lat_size.view(
+ batch_size, -1, self.vae_scale_factor_temporal, latent_h, latent_w
+ ).transpose(1, 2) # [B, C = 1, 4 * T_lat, H_lat, W_lat] --> [B, C = 4, T_lat, H_lat, W_lat]
+
+ return mask_lat_size
+
+ def prepare_reference_image_latents(
+ self,
+ image: torch.Tensor,
+ batch_size: int = 1,
+ sample_mode: int = "argmax",
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ ) -> torch.Tensor:
+ # image shape: (B, C, H, W) or (B, C, T, H, W)
+ dtype = dtype or self.vae.dtype
+ if image.ndim == 4:
+ # Add a singleton frame dimension after the channels dimension
+ image = image.unsqueeze(2)
+
+ _, _, _, height, width = image.shape
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+
+ # Encode image to latents using VAE
+ image = image.to(device=device, dtype=dtype)
+ if isinstance(generator, list):
+ # Like in prepare_latents, assume len(generator) == batch_size
+ ref_image_latents = [
+ retrieve_latents(self.vae.encode(image), generator=g, sample_mode=sample_mode) for g in generator
+ ]
+ ref_image_latents = torch.cat(ref_image_latents)
+ else:
+ ref_image_latents = retrieve_latents(self.vae.encode(image), generator, sample_mode)
+ # Standardize latents in preparation for Wan VAE encode
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(ref_image_latents.device, ref_image_latents.dtype)
+ )
+ latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ ref_image_latents.device, ref_image_latents.dtype
+ )
+ ref_image_latents = (ref_image_latents - latents_mean) * latents_recip_std
+ # Handle the case where we supply one image and one generator, but batch_size > 1 (e.g. generating multiple
+ # videos per prompt)
+ if ref_image_latents.shape[0] == 1 and batch_size > 1:
+ ref_image_latents = ref_image_latents.expand(batch_size, -1, -1, -1, -1)
+
+ # Prepare I2V mask in latent space and prepend to the reference image latents along channel dim
+ reference_image_mask = self.get_i2v_mask(batch_size, 1, latent_height, latent_width, 1, None, dtype, device)
+ reference_image_latents = torch.cat([reference_image_mask, ref_image_latents], dim=1)
+
+ return reference_image_latents
+
+ def prepare_prev_segment_cond_latents(
+ self,
+ prev_segment_cond_video: Optional[torch.Tensor] = None,
+ background_video: Optional[torch.Tensor] = None,
+ mask_video: Optional[torch.Tensor] = None,
+ batch_size: int = 1,
+ segment_frame_length: int = 77,
+ start_frame: int = 0,
+ height: int = 720,
+ width: int = 1280,
+ prev_segment_cond_frames: int = 1,
+ task: str = "animate",
+ interpolation_mode: str = "bicubic",
+ sample_mode: str = "argmax",
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ ) -> torch.Tensor:
+ # prev_segment_cond_video shape: (B, C, T, H, W) in pixel space if supplied
+ # background_video shape: (B, C, T, H, W) (same as prev_segment_cond_video shape)
+ # mask_video shape: (B, 1, T, H, W) (same as prev_segment_cond_video, but with only 1 channel)
+ dtype = dtype or self.vae.dtype
+ if prev_segment_cond_video is None:
+ if task == "replace":
+ prev_segment_cond_video = background_video[:, :, :prev_segment_cond_frames].to(dtype)
+ else:
+ cond_frames_shape = (batch_size, 3, prev_segment_cond_frames, height, width) # In pixel space
+ prev_segment_cond_video = torch.zeros(cond_frames_shape, dtype=dtype, device=device)
+
+ data_batch_size, channels, _, segment_height, segment_width = prev_segment_cond_video.shape
+ num_latent_frames = (segment_frame_length - 1) // self.vae_scale_factor_temporal + 1
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+ if segment_height != height or segment_width != width:
+ print(
+ f"Interpolating prev segment cond video from ({segment_width}, {segment_height}) to ({width}, {height})"
+ )
+ # Perform a 4D (spatial) rather than a 5D (spatiotemporal) reshape, following the original code
+ prev_segment_cond_video = prev_segment_cond_video.transpose(1, 2).flatten(0, 1) # [B * T, C, H, W]
+ prev_segment_cond_video = F.interpolate(
+ prev_segment_cond_video, size=(height, width), mode=interpolation_mode
+ )
+ prev_segment_cond_video = prev_segment_cond_video.unflatten(0, (batch_size, -1)).transpose(1, 2)
+
+ # Fill the remaining part of the cond video segment with zeros (if animating) or the background video (if
+ # replacing).
+ if task == "replace":
+ remaining_segment = background_video[:, :, prev_segment_cond_frames:].to(dtype)
+ else:
+ remaining_segment_frames = segment_frame_length - prev_segment_cond_frames
+ remaining_segment = torch.zeros(
+ batch_size, channels, remaining_segment_frames, height, width, dtype=dtype, device=device
+ )
+
+ # Prepend the conditioning frames from the previous segment to the remaining segment video in the frame dim
+ prev_segment_cond_video = prev_segment_cond_video.to(dtype=dtype)
+ full_segment_cond_video = torch.cat([prev_segment_cond_video, remaining_segment], dim=2)
+
+ if isinstance(generator, list):
+ if data_batch_size == len(generator):
+ prev_segment_cond_latents = [
+ retrieve_latents(self.vae.encode(full_segment_cond_video[i].unsqueeze(0)), g, sample_mode)
+ for i, g in enumerate(generator)
+ ]
+ elif data_batch_size == 1:
+ # Like prepare_latents, assume len(generator) == batch_size
+ prev_segment_cond_latents = [
+ retrieve_latents(self.vae.encode(full_segment_cond_video), g, sample_mode) for g in generator
+ ]
+ else:
+ raise ValueError(
+ f"The batch size of the prev segment video should be either {len(generator)} or 1 but is"
+ f" {data_batch_size}"
+ )
+ prev_segment_cond_latents = torch.cat(prev_segment_cond_latents)
+ else:
+ prev_segment_cond_latents = retrieve_latents(
+ self.vae.encode(full_segment_cond_video), generator, sample_mode
+ )
+ # Standardize latents in preparation for Wan VAE encode
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(prev_segment_cond_latents.device, prev_segment_cond_latents.dtype)
+ )
+ latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ prev_segment_cond_latents.device, prev_segment_cond_latents.dtype
+ )
+ prev_segment_cond_latents = (prev_segment_cond_latents - latents_mean) * latents_recip_std
+
+ # Prepare I2V mask
+ if task == "replace":
+ mask_video = 1 - mask_video
+ mask_video = mask_video.permute(0, 2, 1, 3, 4)
+ mask_video = mask_video.flatten(0, 1)
+ mask_video = F.interpolate(mask_video, size=(latent_height, latent_width), mode="nearest")
+ mask_pixel_values = mask_video.unflatten(0, (batch_size, -1))
+ mask_pixel_values = mask_pixel_values.permute(0, 2, 1, 3, 4) # output shape: [B, C = 1, T, H_lat, W_lat]
+ else:
+ mask_pixel_values = None
+ prev_segment_cond_mask = self.get_i2v_mask(
+ batch_size,
+ num_latent_frames,
+ latent_height,
+ latent_width,
+ mask_len=prev_segment_cond_frames if start_frame > 0 else 0,
+ mask_pixel_values=mask_pixel_values,
+ dtype=dtype,
+ device=device,
+ )
+
+ # Prepend cond I2V mask to prev segment cond latents along channel dimension
+ prev_segment_cond_latents = torch.cat([prev_segment_cond_mask, prev_segment_cond_latents], dim=1)
+ return prev_segment_cond_latents
+
+ def prepare_pose_latents(
+ self,
+ pose_video: torch.Tensor,
+ batch_size: int = 1,
+ sample_mode: int = "argmax",
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ ) -> torch.Tensor:
+ # pose_video shape: (B, C, T, H, W)
+ pose_video = pose_video.to(device=device, dtype=dtype if dtype is not None else self.vae.dtype)
+ if isinstance(generator, list):
+ pose_latents = [
+ retrieve_latents(self.vae.encode(pose_video), generator=g, sample_mode=sample_mode) for g in generator
+ ]
+ pose_latents = torch.cat(pose_latents)
+ else:
+ pose_latents = retrieve_latents(self.vae.encode(pose_video), generator, sample_mode)
+ # Standardize latents in preparation for Wan VAE encode
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(pose_latents.device, pose_latents.dtype)
+ )
+ latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ pose_latents.device, pose_latents.dtype
+ )
+ pose_latents = (pose_latents - latents_mean) * latents_recip_std
+ if pose_latents.shape[0] == 1 and batch_size > 1:
+ pose_latents = pose_latents.expand(batch_size, -1, -1, -1, -1)
+ return pose_latents
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 720,
+ width: int = 1280,
+ num_frames: int = 77,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+
+ shape = (batch_size, num_channels_latents, num_latent_frames + 1, latent_height, latent_width)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ return latents
+
+ def pad_video_frames(self, frames: List[Any], num_target_frames: int) -> List[Any]:
+ """
+ Pads an array-like video `frames` to `num_target_frames` using a "reflect"-like strategy. The frame dimension
+ is assumed to be the first dimension. In the 1D case, we can visualize this strategy as follows:
+
+ pad_video_frames([1, 2, 3, 4, 5], 10) -> [1, 2, 3, 4, 5, 4, 3, 2, 1, 2]
+ """
+ idx = 0
+ flip = False
+ target_frames = []
+ while len(target_frames) < num_target_frames:
+ target_frames.append(deepcopy(frames[idx]))
+ if flip:
+ idx -= 1
+ else:
+ idx += 1
+ if idx == 0 or idx == len(frames) - 1:
+ flip = not flip
+
+ return target_frames
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ pose_video: List[PIL.Image.Image],
+ face_video: List[PIL.Image.Image],
+ background_video: Optional[List[PIL.Image.Image]] = None,
+ mask_video: Optional[List[PIL.Image.Image]] = None,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 720,
+ width: int = 1280,
+ segment_frame_length: int = 77,
+ num_inference_steps: int = 20,
+ mode: str = "animate",
+ prev_segment_conditioning_frames: int = 1,
+ motion_encode_batch_size: Optional[int] = None,
+ guidance_scale: float = 1.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ image_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input character image to condition the generation on. Must be an image, a list of images or a
+ `torch.Tensor`.
+ pose_video (`List[PIL.Image.Image]`):
+ The input pose video to condition the generation on. Must be a list of PIL images.
+ face_video (`List[PIL.Image.Image]`):
+ The input face video to condition the generation on. Must be a list of PIL images.
+ background_video (`List[PIL.Image.Image]`, *optional*):
+ When mode is `"replace"`, the input background video to condition the generation on. Must be a list of
+ PIL images.
+ mask_video (`List[PIL.Image.Image]`, *optional*):
+ When mode is `"replace"`, the input mask video to condition the generation on. Must be a list of PIL
+ images.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ mode (`str`, defaults to `"animation"`):
+ The mode of the generation. Choose between `"animate"` and `"replace"`.
+ prev_segment_conditioning_frames (`int`, defaults to `1`):
+ The number of frames from the previous video segment to be used for temporal guidance. Recommended to
+ be 1 or 5. In general, should be 4N + 1, where N is a non-negative integer.
+ motion_encode_batch_size (`int`, *optional*):
+ The batch size for batched encoding of the face video via the motion encoder. This allows trading off
+ inference speed for lower memory usage by setting a smaller batch size. Will default to
+ `self.transformer.config.motion_encoder_batch_size` if not set.
+ height (`int`, defaults to `720`):
+ The height of the generated video.
+ width (`int`, defaults to `1280`):
+ The width of the generated video.
+ segment_frame_length (`int`, defaults to `77`):
+ The number of frames in each generated video segment. The total frames of video generated will be equal
+ to the number of frames in `pose_video`; we will generate the video in segments until we have hit this
+ length. In general, should be 4N + 1, where N is a non-negative integer.
+ num_inference_steps (`int`, defaults to `20`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `1.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality. By default, CFG is not used in Wan
+ Animate inference.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `negative_prompt` input argument.
+ image_embeds (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
+ image embeddings are generated from the `image` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
+ truncated. If the prompt is shorter, it will be padded to this length.
+
+ Examples:
+
+ Returns:
+ [`~WanPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
+ the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ image,
+ pose_video,
+ face_video,
+ background_video,
+ mask_video,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ image_embeds,
+ callback_on_step_end_tensor_inputs,
+ mode,
+ prev_segment_conditioning_frames,
+ )
+
+ if segment_frame_length % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`segment_frame_length - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the"
+ f" nearest number."
+ )
+ segment_frame_length = (
+ segment_frame_length // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ )
+ segment_frame_length = max(segment_frame_length, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # As we generate in segments of `segment_frame_length`, set the target frame length to be the least multiple
+ # of the effective segment length greater than or equal to the length of `pose_video`.
+ cond_video_frames = len(pose_video)
+ effective_segment_length = segment_frame_length - prev_segment_conditioning_frames
+ last_segment_frames = (cond_video_frames - prev_segment_conditioning_frames) % effective_segment_length
+ if last_segment_frames == 0:
+ num_padding_frames = 0
+ else:
+ num_padding_frames = effective_segment_length - last_segment_frames
+ num_target_frames = cond_video_frames + num_padding_frames
+ num_segments = num_target_frames // effective_segment_length
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Preprocess and encode the reference (character) image
+ image_height, image_width = self.video_processor.get_default_height_width(image)
+ if image_height != height or image_width != width:
+ logger.warning(f"Reshaping reference image from ({image_width}, {image_height}) to ({width}, {height})")
+ image_pixels = self.vae_image_processor.preprocess(image, height=height, width=width, resize_mode="fill").to(
+ device, dtype=torch.float32
+ )
+
+ # Get CLIP features from the reference image
+ if image_embeds is None:
+ image_embeds = self.encode_image(image, device)
+ image_embeds = image_embeds.repeat(batch_size * num_videos_per_prompt, 1, 1)
+ image_embeds = image_embeds.to(transformer_dtype)
+
+ # 5. Encode conditioning videos (pose, face)
+ pose_video = self.pad_video_frames(pose_video, num_target_frames)
+ face_video = self.pad_video_frames(face_video, num_target_frames)
+
+ # TODO: also support np.ndarray input (e.g. from decord like the original implementation?)
+ pose_video_width, pose_video_height = pose_video[0].size
+ if pose_video_height != height or pose_video_width != width:
+ logger.warning(
+ f"Reshaping pose video from ({pose_video_width}, {pose_video_height}) to ({width}, {height})"
+ )
+ pose_video = self.video_processor.preprocess_video(pose_video, height=height, width=width).to(
+ device, dtype=torch.float32
+ )
+
+ face_video_width, face_video_height = face_video[0].size
+ expected_face_size = self.transformer.config.motion_encoder_size
+ if face_video_width != expected_face_size or face_video_height != expected_face_size:
+ logger.warning(
+ f"Reshaping face video from ({face_video_width}, {face_video_height}) to ({expected_face_size},"
+ f" {expected_face_size})"
+ )
+ face_video = self.video_processor.preprocess_video(
+ face_video, height=expected_face_size, width=expected_face_size
+ ).to(device, dtype=torch.float32)
+
+ if mode == "replace":
+ background_video = self.pad_video_frames(background_video, num_target_frames)
+ mask_video = self.pad_video_frames(mask_video, num_target_frames)
+
+ background_video = self.video_processor.preprocess_video(background_video, height=height, width=width).to(
+ device, dtype=torch.float32
+ )
+ mask_video = self.video_processor_for_mask.preprocess_video(mask_video, height=height, width=width).to(
+ device, dtype=torch.float32
+ )
+
+ # 6. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 7. Prepare latent variables which stay constant for all inference segments
+ num_channels_latents = self.vae.config.z_dim
+
+ # Get VAE-encoded latents of the reference (character) image
+ reference_image_latents = self.prepare_reference_image_latents(
+ image_pixels, batch_size * num_videos_per_prompt, generator=generator, device=device
+ )
+
+ # 8. Loop over video inference segments
+ start = 0
+ end = segment_frame_length # Data space frames, not latent frames
+ all_out_frames = []
+ out_frames = None
+
+ for _ in range(num_segments):
+ assert start + prev_segment_conditioning_frames < cond_video_frames
+
+ # Sample noisy latents from prior for the current inference segment
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents=num_channels_latents,
+ height=height,
+ width=width,
+ num_frames=segment_frame_length,
+ dtype=torch.float32,
+ device=device,
+ generator=generator,
+ latents=latents if start == 0 else None, # Only use pre-calculated latents for first segment
+ )
+
+ pose_video_segment = pose_video[:, :, start:end]
+ face_video_segment = face_video[:, :, start:end]
+
+ face_video_segment = face_video_segment.expand(batch_size * num_videos_per_prompt, -1, -1, -1, -1)
+ face_video_segment = face_video_segment.to(dtype=transformer_dtype)
+
+ if start > 0:
+ prev_segment_cond_video = out_frames[:, :, -prev_segment_conditioning_frames:].clone().detach()
+ else:
+ prev_segment_cond_video = None
+
+ if mode == "replace":
+ background_video_segment = background_video[:, :, start:end]
+ mask_video_segment = mask_video[:, :, start:end]
+
+ background_video_segment = background_video_segment.expand(
+ batch_size * num_videos_per_prompt, -1, -1, -1, -1
+ )
+ mask_video_segment = mask_video_segment.expand(batch_size * num_videos_per_prompt, -1, -1, -1, -1)
+ else:
+ background_video_segment = None
+ mask_video_segment = None
+
+ pose_latents = self.prepare_pose_latents(
+ pose_video_segment, batch_size * num_videos_per_prompt, generator=generator, device=device
+ )
+ pose_latents = pose_latents.to(dtype=transformer_dtype)
+
+ prev_segment_cond_latents = self.prepare_prev_segment_cond_latents(
+ prev_segment_cond_video,
+ background_video=background_video_segment,
+ mask_video=mask_video_segment,
+ batch_size=batch_size * num_videos_per_prompt,
+ segment_frame_length=segment_frame_length,
+ start_frame=start,
+ height=height,
+ width=width,
+ prev_segment_cond_frames=prev_segment_conditioning_frames,
+ task=mode,
+ generator=generator,
+ device=device,
+ )
+
+ # Concatenate the reference latents in the frame dimension
+ reference_latents = torch.cat([reference_image_latents, prev_segment_cond_latents], dim=2)
+
+ # 8.1 Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ # Concatenate the reference image + prev segment conditioning in the channel dim
+ latent_model_input = torch.cat([latents, reference_latents], dim=1).to(transformer_dtype)
+ timestep = t.expand(latents.shape[0])
+
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states_image=image_embeds,
+ pose_hidden_states=pose_latents,
+ face_pixel_values=face_video_segment,
+ motion_encode_batch_size=motion_encode_batch_size,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ # Blank out face for unconditional guidance (set all pixels to -1)
+ face_pixel_values_uncond = face_video_segment * 0 - 1
+ with self.transformer.cache_context("uncond"):
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states_image=image_embeds,
+ pose_hidden_states=pose_latents,
+ face_pixel_values=face_pixel_values_uncond,
+ motion_encode_batch_size=motion_encode_batch_size,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ latents = latents.to(self.vae.dtype)
+ # Destandardize latents in preparation for Wan VAE decoding
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(
+ 1, self.vae.config.z_dim, 1, 1, 1
+ ).to(latents.device, latents.dtype)
+ latents = latents / latents_recip_std + latents_mean
+ # Skip the first latent frame (used for conditioning)
+ out_frames = self.vae.decode(latents[:, :, 1:], return_dict=False)[0]
+
+ if start > 0:
+ out_frames = out_frames[:, :, prev_segment_conditioning_frames:]
+ all_out_frames.append(out_frames)
+
+ start += effective_segment_length
+ end += effective_segment_length
+
+ # Reset scheduler timesteps / state for next denoising loop
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ self._current_timestep = None
+ assert start + prev_segment_conditioning_frames >= cond_video_frames
+
+ if not output_type == "latent":
+ video = torch.cat(all_out_frames, dim=2)[:, :, :cond_video_frames]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return WanPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
index fd1d90849a66..b7fd0b05980f 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
@@ -15,7 +15,6 @@
import html
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-import ftfy
import PIL
import regex as re
import torch
@@ -26,7 +25,7 @@
from ...loaders import WanLoraLoaderMixin
from ...models import AutoencoderKLWan, WanTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
@@ -42,6 +41,9 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+if is_ftfy_available():
+ import ftfy
+
EXAMPLE_DOC_STRING = """
Examples:
```python
@@ -147,20 +149,33 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ transformer_2 ([`WanTransformer3DModel`], *optional*):
+ Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
+ `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
+ `transformer` is used.
+ boundary_ratio (`float`, *optional*, defaults to `None`):
+ Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
+ The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
+ `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+ boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
"""
- model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
+ model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+ _optional_components = ["transformer", "transformer_2", "image_encoder", "image_processor"]
def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
- image_encoder: CLIPVisionModel,
- image_processor: CLIPImageProcessor,
- transformer: WanTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
+ image_processor: CLIPImageProcessor = None,
+ image_encoder: CLIPVisionModel = None,
+ transformer: WanTransformer3DModel = None,
+ transformer_2: WanTransformer3DModel = None,
+ boundary_ratio: Optional[float] = None,
+ expand_timesteps: bool = False,
):
super().__init__()
@@ -172,10 +187,12 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
image_processor=image_processor,
+ transformer_2=transformer_2,
)
+ self.register_to_config(boundary_ratio=boundary_ratio, expand_timesteps=expand_timesteps)
- self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
- self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
self.image_processor = image_processor
@@ -321,10 +338,21 @@ def check_inputs(
width,
prompt_embeds=None,
negative_prompt_embeds=None,
+ image_embeds=None,
callback_on_step_end_tensor_inputs=None,
+ guidance_scale_2=None,
):
- if not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
- raise ValueError("`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is" f" {type(image)}")
+ if image is not None and image_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ if image is None and image_embeds is None:
+ raise ValueError(
+ "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined."
+ )
+ if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
+ raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -356,6 +384,12 @@ def check_inputs(
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+ if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+ raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
+
+ if self.config.boundary_ratio is not None and image_embeds is not None:
+ raise ValueError("Cannot forward `image_embeds` when the pipeline's `boundary_ratio` is not configured.")
+
def prepare_latents(
self,
image: PipelineImageInput,
@@ -368,6 +402,7 @@ def prepare_latents(
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
+ last_image: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
latent_height = height // self.vae_scale_factor_spatial
@@ -385,11 +420,22 @@ def prepare_latents(
else:
latents = latents.to(device=device, dtype=dtype)
- image = image.unsqueeze(2)
- video_condition = torch.cat(
- [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
- )
- video_condition = video_condition.to(device=device, dtype=dtype)
+ image = image.unsqueeze(2) # [batch_size, channels, 1, height, width]
+
+ if self.config.expand_timesteps:
+ video_condition = image
+
+ elif last_image is None:
+ video_condition = torch.cat(
+ [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
+ )
+ else:
+ last_image = last_image.unsqueeze(2)
+ video_condition = torch.cat(
+ [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
+ dim=2,
+ )
+ video_condition = video_condition.to(device=device, dtype=self.vae.dtype)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
@@ -409,10 +455,22 @@ def prepare_latents(
latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
+ latent_condition = latent_condition.to(dtype)
latent_condition = (latent_condition - latents_mean) * latents_std
+ if self.config.expand_timesteps:
+ first_frame_mask = torch.ones(
+ 1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device
+ )
+ first_frame_mask[:, :, 0] = 0
+ return latents, latent_condition, first_frame_mask
+
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
- mask_lat_size[:, :, list(range(1, num_frames))] = 0
+
+ if last_image is None:
+ mask_lat_size[:, :, list(range(1, num_frames))] = 0
+ else:
+ mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
@@ -458,11 +516,14 @@ def __call__(
num_frames: int = 81,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
+ guidance_scale_2: Optional[float] = None,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
+ image_embeds: Optional[torch.Tensor] = None,
+ last_image: Optional[torch.Tensor] = None,
output_type: Optional[str] = "np",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -495,11 +556,15 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `5.0`):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ guidance_scale_2 (`float`, *optional*, defaults to `None`):
+ Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
+ `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
+ and the pipeline's `boundary_ratio` are not None.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -512,7 +577,13 @@ def __call__(
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
- output_type (`str`, *optional*, defaults to `"pil"`):
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `negative_prompt` input argument.
+ image_embeds (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
+ image embeddings are generated from the `image` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
@@ -529,12 +600,10 @@ def __call__(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
- max_sequence_length (`int`, *optional*, defaults to `512`):
- The maximum sequence length of the prompt.
- shift (`float`, *optional*, defaults to `5.0`):
- The shift of the flow.
- autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
- The dtype to use for the torch.amp.autocast.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
+ truncated. If the prompt is shorter, it will be padded to this length.
+
Examples:
Returns:
@@ -556,7 +625,9 @@ def __call__(
width,
prompt_embeds,
negative_prompt_embeds,
+ image_embeds,
callback_on_step_end_tensor_inputs,
+ guidance_scale_2,
)
if num_frames % self.vae_scale_factor_temporal != 1:
@@ -566,7 +637,11 @@ def __call__(
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)
+ if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+ guidance_scale_2 = guidance_scale
+
self._guidance_scale = guidance_scale
+ self._guidance_scale_2 = guidance_scale_2
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
@@ -594,14 +669,20 @@ def __call__(
)
# Encode image embedding
- transformer_dtype = self.transformer.dtype
+ transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)
if negative_prompt_embeds is not None:
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
- image_embeds = self.encode_image(image, device)
- image_embeds = image_embeds.repeat(batch_size, 1, 1)
- image_embeds = image_embeds.to(transformer_dtype)
+ # only wan 2.1 i2v transformer accepts image_embeds
+ if self.transformer is not None and self.transformer.config.image_dim is not None:
+ if image_embeds is None:
+ if last_image is None:
+ image_embeds = self.encode_image(image, device)
+ else:
+ image_embeds = self.encode_image([image, last_image], device)
+ image_embeds = image_embeds.repeat(batch_size, 1, 1)
+ image_embeds = image_embeds.to(transformer_dtype)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -610,7 +691,12 @@ def __call__(
# 5. Prepare latent variables
num_channels_latents = self.vae.config.z_dim
image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
- latents, condition = self.prepare_latents(
+ if last_image is not None:
+ last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
+ device, dtype=torch.float32
+ )
+
+ latents_outputs = self.prepare_latents(
image,
batch_size * num_videos_per_prompt,
num_channels_latents,
@@ -621,40 +707,72 @@ def __call__(
device,
generator,
latents,
+ last_image,
)
+ if self.config.expand_timesteps:
+ # wan 2.2 5b i2v use firt_frame_mask to mask timesteps
+ latents, condition, first_frame_mask = latents_outputs
+ else:
+ latents, condition = latents_outputs
# 6. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
+ if self.config.boundary_ratio is not None:
+ boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+ else:
+ boundary_timestep = None
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
- latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
- timestep = t.expand(latents.shape[0])
-
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- timestep=timestep,
- encoder_hidden_states=prompt_embeds,
- encoder_hidden_states_image=image_embeds,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
- if self.do_classifier_free_guidance:
- noise_uncond = self.transformer(
+ if boundary_timestep is None or t >= boundary_timestep:
+ # wan2.1 or high-noise stage in wan2.2
+ current_model = self.transformer
+ current_guidance_scale = guidance_scale
+ else:
+ # low-noise stage in wan2.2
+ current_model = self.transformer_2
+ current_guidance_scale = guidance_scale_2
+
+ if self.config.expand_timesteps:
+ latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
+ latent_model_input = latent_model_input.to(transformer_dtype)
+
+ # seq_len: num_latent_frames * (latent_height // patch_size) * (latent_width // patch_size)
+ temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
+ # batch_size, seq_len
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+ else:
+ latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+ timestep = t.expand(latents.shape[0])
+
+ with current_model.cache_context("cond"):
+ noise_pred = current_model(
hidden_states=latent_model_input,
timestep=timestep,
- encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
encoder_hidden_states_image=image_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
- noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ if self.do_classifier_free_guidance:
+ with current_model.cache_context("uncond"):
+ noise_uncond = current_model(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states_image=image_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
@@ -678,6 +796,9 @@ def __call__(
self._current_timestep = None
+ if self.config.expand_timesteps:
+ latents = (1 - first_frame_mask) * condition + first_frame_mask * latents
+
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
latents_mean = (
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_vace.py b/src/diffusers/pipelines/wan/pipeline_wan_vace.py
new file mode 100644
index 000000000000..351ae2e70563
--- /dev/null
+++ b/src/diffusers/pipelines/wan/pipeline_wan_vace.py
@@ -0,0 +1,1045 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import PIL.Image
+import regex as re
+import torch
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import WanLoraLoaderMixin
+from ...models import AutoencoderKLWan, WanVACETransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import WanPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> import PIL.Image
+ >>> from diffusers import AutoencoderKLWan, WanVACEPipeline
+ >>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
+ >>> from diffusers.utils import export_to_video, load_image
+ def prepare_video_and_mask(first_img: PIL.Image.Image, last_img: PIL.Image.Image, height: int, width: int, num_frames: int):
+ first_img = first_img.resize((width, height))
+ last_img = last_img.resize((width, height))
+ frames = []
+ frames.append(first_img)
+ # Ideally, this should be 127.5 to match original code, but they perform computation on numpy arrays
+ # whereas we are passing PIL images. If you choose to pass numpy arrays, you can set it to 127.5 to
+ # match the original code.
+ frames.extend([PIL.Image.new("RGB", (width, height), (128, 128, 128))] * (num_frames - 2))
+ frames.append(last_img)
+ mask_black = PIL.Image.new("L", (width, height), 0)
+ mask_white = PIL.Image.new("L", (width, height), 255)
+ mask = [mask_black, *[mask_white] * (num_frames - 2), mask_black]
+ return frames, mask
+
+ >>> # Available checkpoints: Wan-AI/Wan2.1-VACE-1.3B-diffusers, Wan-AI/Wan2.1-VACE-14B-diffusers
+ >>> model_id = "Wan-AI/Wan2.1-VACE-1.3B-diffusers"
+ >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+ >>> pipe = WanVACEPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+ >>> flow_shift = 3.0 # 5.0 for 720P, 3.0 for 480P
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
+ >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+ >>> first_frame = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png"
+ ... )
+ >>> last_frame = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png>>> "
+ ... )
+
+ >>> height = 512
+ >>> width = 512
+ >>> num_frames = 81
+ >>> video, mask = prepare_video_and_mask(first_frame, last_frame, height, width, num_frames)
+
+ >>> output = pipe(
+ ... video=video,
+ ... mask=mask,
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... height=height,
+ ... width=width,
+ ... num_frames=num_frames,
+ ... num_inference_steps=30,
+ ... guidance_scale=5.0,
+ ... generator=torch.Generator().manual_seed(42),
+ ... ).frames[0]
+ >>> export_to_video(output, "output.mp4", fps=16)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
+ r"""
+ Pipeline for controllable generation using Wan.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ transformer ([`WanVACETransformer3DModel`], *optional*):
+ Conditional Transformer to denoise the input latents during the high-noise stage. In two-stage denoising,
+ `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
+ `transformer` or `transformer_2` must be provided.
+ transformer_2 ([`WanVACETransformer3DModel`], *optional*):
+ Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
+ `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
+ `transformer` or `transformer_2` must be provided.
+ boundary_ratio (`float`, *optional*, defaults to `None`):
+ Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
+ The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
+ `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+ boundary_timestep. If `None`, only the available transformer is used for the entire denoising process.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+ _optional_components = ["transformer", "transformer_2"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ vae: AutoencoderKLWan,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ transformer: WanVACETransformer3DModel = None,
+ transformer_2: WanVACETransformer3DModel = None,
+ boundary_ratio: Optional[float] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ transformer_2=transformer_2,
+ scheduler=scheduler,
+ )
+ self.register_to_config(boundary_ratio=boundary_ratio)
+ self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ video=None,
+ mask=None,
+ reference_images=None,
+ guidance_scale_2=None,
+ ):
+ if self.transformer is not None:
+ base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+ elif self.transformer_2 is not None:
+ base = self.vae_scale_factor_spatial * self.transformer_2.config.patch_size[1]
+ else:
+ raise ValueError(
+ "`transformer` or `transformer_2` component must be set in order to run inference with this pipeline"
+ )
+
+ if height % base != 0 or width % base != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+ if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+ raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ if video is not None:
+ if mask is not None:
+ if len(video) != len(mask):
+ raise ValueError(
+ f"Length of `video` {len(video)} and `mask` {len(mask)} do not match. Please make sure that"
+ " they have the same length."
+ )
+ if reference_images is not None:
+ is_pil_image = isinstance(reference_images, PIL.Image.Image)
+ is_list_of_pil_images = isinstance(reference_images, list) and all(
+ isinstance(ref_img, PIL.Image.Image) for ref_img in reference_images
+ )
+ is_list_of_list_of_pil_images = isinstance(reference_images, list) and all(
+ isinstance(ref_img, list) and all(isinstance(ref_img_, PIL.Image.Image) for ref_img_ in ref_img)
+ for ref_img in reference_images
+ )
+ if not (is_pil_image or is_list_of_pil_images or is_list_of_list_of_pil_images):
+ raise ValueError(
+ "`reference_images` has to be of type `PIL.Image.Image` or `list` of `PIL.Image.Image`, or "
+ "`list` of `list` of `PIL.Image.Image`, but is {type(reference_images)}"
+ )
+ if is_list_of_list_of_pil_images and len(reference_images) != 1:
+ raise ValueError(
+ "The pipeline only supports generating one video at a time at the moment. When passing a list "
+ "of list of reference images, where the outer list corresponds to the batch size and the inner "
+ "list corresponds to list of conditioning images per video, please make sure to only pass "
+ "one inner list of reference images (i.e., `[[, , ...]]`"
+ )
+ elif mask is not None:
+ raise ValueError("`mask` can only be passed if `video` is passed as well.")
+
+ def preprocess_conditions(
+ self,
+ video: Optional[List[PipelineImageInput]] = None,
+ mask: Optional[List[PipelineImageInput]] = None,
+ reference_images: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], List[List[PIL.Image.Image]]]] = None,
+ batch_size: int = 1,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ ):
+ if video is not None:
+ base = self.vae_scale_factor_spatial * (
+ self.transformer.config.patch_size[1]
+ if self.transformer is not None
+ else self.transformer_2.config.patch_size[1]
+ )
+ video_height, video_width = self.video_processor.get_default_height_width(video[0])
+
+ if video_height * video_width > height * width:
+ scale = min(width / video_width, height / video_height)
+ video_height, video_width = int(video_height * scale), int(video_width * scale)
+
+ if video_height % base != 0 or video_width % base != 0:
+ logger.warning(
+ f"Video height and width should be divisible by {base}, but got {video_height} and {video_width}. "
+ )
+ video_height = (video_height // base) * base
+ video_width = (video_width // base) * base
+
+ assert video_height * video_width <= height * width
+
+ video = self.video_processor.preprocess_video(video, video_height, video_width)
+ image_size = (video_height, video_width) # Use the height/width of video (with possible rescaling)
+ else:
+ video = torch.zeros(batch_size, 3, num_frames, height, width, dtype=dtype, device=device)
+ image_size = (height, width) # Use the height/width provider by user
+
+ if mask is not None:
+ mask = self.video_processor.preprocess_video(mask, image_size[0], image_size[1])
+ mask = torch.clamp((mask + 1) / 2, min=0, max=1)
+ else:
+ mask = torch.ones_like(video)
+
+ video = video.to(dtype=dtype, device=device)
+ mask = mask.to(dtype=dtype, device=device)
+
+ # Make a list of list of images where the outer list corresponds to video batch size and the inner list
+ # corresponds to list of conditioning images per video
+ if reference_images is None or isinstance(reference_images, PIL.Image.Image):
+ reference_images = [[reference_images] for _ in range(video.shape[0])]
+ elif isinstance(reference_images, (list, tuple)) and isinstance(next(iter(reference_images)), PIL.Image.Image):
+ reference_images = [reference_images]
+ elif (
+ isinstance(reference_images, (list, tuple))
+ and isinstance(next(iter(reference_images)), list)
+ and isinstance(next(iter(reference_images[0])), PIL.Image.Image)
+ ):
+ reference_images = reference_images
+ else:
+ raise ValueError(
+ "`reference_images` has to be of type `PIL.Image.Image` or `list` of `PIL.Image.Image`, or "
+ "`list` of `list` of `PIL.Image.Image`, but is {type(reference_images)}"
+ )
+
+ if video.shape[0] != len(reference_images):
+ raise ValueError(
+ f"Batch size of `video` {video.shape[0]} and length of `reference_images` {len(reference_images)} does not match."
+ )
+
+ ref_images_lengths = [len(reference_images_batch) for reference_images_batch in reference_images]
+ if any(l != ref_images_lengths[0] for l in ref_images_lengths):
+ raise ValueError(
+ f"All batches of `reference_images` should have the same length, but got {ref_images_lengths}. Support for this "
+ "may be added in the future."
+ )
+
+ reference_images_preprocessed = []
+ for i, reference_images_batch in enumerate(reference_images):
+ preprocessed_images = []
+ for j, image in enumerate(reference_images_batch):
+ if image is None:
+ continue
+ image = self.video_processor.preprocess(image, None, None)
+ img_height, img_width = image.shape[-2:]
+ scale = min(image_size[0] / img_height, image_size[1] / img_width)
+ new_height, new_width = int(img_height * scale), int(img_width * scale)
+ resized_image = torch.nn.functional.interpolate(
+ image, size=(new_height, new_width), mode="bilinear", align_corners=False
+ ).squeeze(0) # [C, H, W]
+ top = (image_size[0] - new_height) // 2
+ left = (image_size[1] - new_width) // 2
+ canvas = torch.ones(3, *image_size, device=device, dtype=dtype)
+ canvas[:, top : top + new_height, left : left + new_width] = resized_image
+ preprocessed_images.append(canvas)
+ reference_images_preprocessed.append(preprocessed_images)
+
+ return video, mask, reference_images_preprocessed
+
+ def prepare_video_latents(
+ self,
+ video: torch.Tensor,
+ mask: torch.Tensor,
+ reference_images: Optional[List[List[torch.Tensor]]] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ device: Optional[torch.device] = None,
+ ) -> torch.Tensor:
+ device = device or self._execution_device
+
+ if isinstance(generator, list):
+ # TODO: support this
+ raise ValueError("Passing a list of generators is not yet supported. This may be supported in the future.")
+
+ if reference_images is None:
+ # For each batch of video, we set no re
+ # ference image (as one or more can be passed by user)
+ reference_images = [[None] for _ in range(video.shape[0])]
+ else:
+ if video.shape[0] != len(reference_images):
+ raise ValueError(
+ f"Batch size of `video` {video.shape[0]} and length of `reference_images` {len(reference_images)} does not match."
+ )
+
+ if video.shape[0] != 1:
+ # TODO: support this
+ raise ValueError(
+ "Generating with more than one video is not yet supported. This may be supported in the future."
+ )
+
+ vae_dtype = self.vae.dtype
+ video = video.to(dtype=vae_dtype)
+
+ latents_mean = torch.tensor(self.vae.config.latents_mean, device=device, dtype=torch.float32).view(
+ 1, self.vae.config.z_dim, 1, 1, 1
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=device, dtype=torch.float32).view(
+ 1, self.vae.config.z_dim, 1, 1, 1
+ )
+
+ if mask is None:
+ latents = retrieve_latents(self.vae.encode(video), generator, sample_mode="argmax").unbind(0)
+ latents = ((latents.float() - latents_mean) * latents_std).to(vae_dtype)
+ else:
+ mask = torch.where(mask > 0.5, 1.0, 0.0).to(dtype=vae_dtype)
+ inactive = video * (1 - mask)
+ reactive = video * mask
+ inactive = retrieve_latents(self.vae.encode(inactive), generator, sample_mode="argmax")
+ reactive = retrieve_latents(self.vae.encode(reactive), generator, sample_mode="argmax")
+ inactive = ((inactive.float() - latents_mean) * latents_std).to(vae_dtype)
+ reactive = ((reactive.float() - latents_mean) * latents_std).to(vae_dtype)
+ latents = torch.cat([inactive, reactive], dim=1)
+
+ latent_list = []
+ for latent, reference_images_batch in zip(latents, reference_images):
+ for reference_image in reference_images_batch:
+ assert reference_image.ndim == 3
+ reference_image = reference_image.to(dtype=vae_dtype)
+ reference_image = reference_image[None, :, None, :, :] # [1, C, 1, H, W]
+ reference_latent = retrieve_latents(self.vae.encode(reference_image), generator, sample_mode="argmax")
+ reference_latent = ((reference_latent.float() - latents_mean) * latents_std).to(vae_dtype)
+ reference_latent = reference_latent.squeeze(0) # [C, 1, H, W]
+ reference_latent = torch.cat([reference_latent, torch.zeros_like(reference_latent)], dim=0)
+ latent = torch.cat([reference_latent.squeeze(0), latent], dim=1)
+ latent_list.append(latent)
+ return torch.stack(latent_list)
+
+ def prepare_masks(
+ self,
+ mask: torch.Tensor,
+ reference_images: Optional[List[torch.Tensor]] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ ) -> torch.Tensor:
+ if isinstance(generator, list):
+ # TODO: support this
+ raise ValueError("Passing a list of generators is not yet supported. This may be supported in the future.")
+
+ if reference_images is None:
+ # For each batch of video, we set no reference image (as one or more can be passed by user)
+ reference_images = [[None] for _ in range(mask.shape[0])]
+ else:
+ if mask.shape[0] != len(reference_images):
+ raise ValueError(
+ f"Batch size of `mask` {mask.shape[0]} and length of `reference_images` {len(reference_images)} does not match."
+ )
+
+ if mask.shape[0] != 1:
+ # TODO: support this
+ raise ValueError(
+ "Generating with more than one video is not yet supported. This may be supported in the future."
+ )
+
+ transformer_patch_size = (
+ self.transformer.config.patch_size[1]
+ if self.transformer is not None
+ else self.transformer_2.config.patch_size[1]
+ )
+
+ mask_list = []
+ for mask_, reference_images_batch in zip(mask, reference_images):
+ num_channels, num_frames, height, width = mask_.shape
+ new_num_frames = (num_frames + self.vae_scale_factor_temporal - 1) // self.vae_scale_factor_temporal
+ new_height = height // (self.vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size
+ new_width = width // (self.vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size
+ mask_ = mask_[0, :, :, :]
+ mask_ = mask_.view(
+ num_frames, new_height, self.vae_scale_factor_spatial, new_width, self.vae_scale_factor_spatial
+ )
+ mask_ = mask_.permute(2, 4, 0, 1, 3).flatten(0, 1) # [8x8, num_frames, new_height, new_width]
+ mask_ = torch.nn.functional.interpolate(
+ mask_.unsqueeze(0), size=(new_num_frames, new_height, new_width), mode="nearest-exact"
+ ).squeeze(0)
+ num_ref_images = len(reference_images_batch)
+ if num_ref_images > 0:
+ mask_padding = torch.zeros_like(mask_[:, :num_ref_images, :, :])
+ mask_ = torch.cat([mask_padding, mask_], dim=1)
+ mask_list.append(mask_)
+ return torch.stack(mask_list)
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ video: Optional[List[PipelineImageInput]] = None,
+ mask: Optional[List[PipelineImageInput]] = None,
+ reference_images: Optional[List[PipelineImageInput]] = None,
+ conditioning_scale: Union[float, List[float], torch.Tensor] = 1.0,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ guidance_scale_2: Optional[float] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ video (`List[PIL.Image.Image]`, *optional*):
+ The input video or videos to be used as a starting point for the generation. The video should be a list
+ of PIL images, a numpy array, or a torch tensor. Currently, the pipeline only supports generating one
+ video at a time.
+ mask (`List[PIL.Image.Image]`, *optional*):
+ The input mask defines which video regions to condition on and which to generate. Black areas in the
+ mask indicate conditioning regions, while white areas indicate regions for generation. The mask should
+ be a list of PIL images, a numpy array, or a torch tensor. Currently supports generating a single video
+ at a time.
+ reference_images (`List[PIL.Image.Image]`, *optional*):
+ A list of one or more reference images as extra conditioning for the generation. For example, if you
+ are trying to inpaint a video to change the character, you can pass reference images of the new
+ character here. Refer to the Diffusers [examples](https://github.com/huggingface/diffusers/pull/11582)
+ and original [user
+ guide](https://github.com/ali-vilab/VACE/blob/0897c6d055d7d9ea9e191dce763006664d9780f8/UserGuide.md)
+ for a full list of supported tasks and use cases.
+ conditioning_scale (`float`, `List[float]`, `torch.Tensor`, defaults to `1.0`):
+ The conditioning scale to be applied when adding the control conditioning latent stream to the
+ denoising latent stream in each control layer of the model. If a float is provided, it will be applied
+ uniformly to all layers. If a list or tensor is provided, it should have the same length as the number
+ of control layers in the model (`len(transformer.config.vace_layers)`).
+ height (`int`, defaults to `480`):
+ The height in pixels of the generated image.
+ width (`int`, defaults to `832`):
+ The width in pixels of the generated image.
+ num_frames (`int`, defaults to `81`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ guidance_scale_2 (`float`, *optional*, defaults to `None`):
+ Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
+ `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
+ and the pipeline's `boundary_ratio` are not None.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
+ truncated. If the prompt is shorter, it will be padded to this length.
+
+ Examples:
+
+ Returns:
+ [`~WanPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
+ the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # Simplification of implementation for now
+ if prompt is not None and not isinstance(prompt, str):
+ raise ValueError("Passing a list of prompts is not yet supported. This may be supported in the future.")
+ if num_videos_per_prompt != 1:
+ raise ValueError(
+ "Generating multiple videos per prompt is not yet supported. This may be supported in the future."
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ video,
+ mask,
+ reference_images,
+ guidance_scale_2,
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+ guidance_scale_2 = guidance_scale
+
+ self._guidance_scale = guidance_scale
+ self._guidance_scale_2 = guidance_scale_2
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ vae_dtype = self.vae.dtype
+ transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
+
+ vace_layers = (
+ self.transformer.config.vace_layers
+ if self.transformer is not None
+ else self.transformer_2.config.vace_layers
+ )
+ if isinstance(conditioning_scale, (int, float)):
+ conditioning_scale = [conditioning_scale] * len(vace_layers)
+ if isinstance(conditioning_scale, list):
+ if len(conditioning_scale) != len(vace_layers):
+ raise ValueError(
+ f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(vace_layers)}."
+ )
+ conditioning_scale = torch.tensor(conditioning_scale)
+ if isinstance(conditioning_scale, torch.Tensor):
+ if conditioning_scale.size(0) != len(vace_layers):
+ raise ValueError(
+ f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(vace_layers)}."
+ )
+ conditioning_scale = conditioning_scale.to(device=device, dtype=transformer_dtype)
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ video, mask, reference_images = self.preprocess_conditions(
+ video,
+ mask,
+ reference_images,
+ batch_size,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ )
+ num_reference_images = len(reference_images[0])
+
+ conditioning_latents = self.prepare_video_latents(video, mask, reference_images, generator, device)
+ mask = self.prepare_masks(mask, reference_images, generator)
+ conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
+ conditioning_latents = conditioning_latents.to(transformer_dtype)
+
+ num_channels_latents = (
+ self.transformer.config.in_channels
+ if self.transformer is not None
+ else self.transformer_2.config.in_channels
+ )
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames + num_reference_images * self.vae_scale_factor_temporal,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ if conditioning_latents.shape[2] != latents.shape[2]:
+ logger.warning(
+ "The number of frames in the conditioning latents does not match the number of frames to be generated. Generation quality may be affected."
+ )
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ if self.config.boundary_ratio is not None:
+ boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+ else:
+ boundary_timestep = None
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ if boundary_timestep is None or t >= boundary_timestep:
+ # wan2.1 or high-noise stage in wan2.2
+ current_model = self.transformer
+ current_guidance_scale = guidance_scale
+ else:
+ # low-noise stage in wan2.2
+ current_model = self.transformer_2
+ current_guidance_scale = guidance_scale_2
+
+ latent_model_input = latents.to(transformer_dtype)
+ timestep = t.expand(latents.shape[0])
+
+ with current_model.cache_context("cond"):
+ noise_pred = current_model(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ control_hidden_states=conditioning_latents,
+ control_hidden_states_scale=conditioning_scale,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ with current_model.cache_context("uncond"):
+ noise_uncond = current_model(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ control_hidden_states=conditioning_latents,
+ control_hidden_states_scale=conditioning_scale,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents[:, :, num_reference_images:]
+ latents = latents.to(vae_dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return WanPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py
index c72dd7f5f1eb..a976126da7fe 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py
@@ -16,7 +16,6 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
-import ftfy
import regex as re
import torch
from PIL import Image
@@ -26,7 +25,7 @@
from ...loaders import WanLoraLoaderMixin
from ...models import AutoencoderKLWan, WanTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
@@ -42,12 +41,15 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+if is_ftfy_available():
+ import ftfy
+
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
- >>> from diffusers.utils import export_to_video
+ >>> from diffusers.utils import export_to_video, load_video
>>> from diffusers import AutoencoderKLWan, WanVideoToVideoPipeline
>>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
@@ -417,12 +419,7 @@ def prepare_latents(
)
if latents is None:
- if isinstance(generator, list):
- init_latents = [
- retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
- ]
- else:
- init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
+ init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video]
init_latents = torch.cat(init_latents, dim=0).to(dtype)
@@ -439,7 +436,7 @@ def prepare_latents(
if hasattr(self.scheduler, "add_noise"):
latents = self.scheduler.add_noise(init_latents, noise, timestep)
else:
- latents = self.scheduelr.scale_noise(init_latents, timestep, noise)
+ latents = self.scheduler.scale_noise(init_latents, timestep, noise)
else:
latents = latents.to(device)
@@ -511,7 +508,7 @@ def __call__(
Args:
prompt (`str` or `List[str]`, *optional*):
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
height (`int`, defaults to `480`):
The height in pixels of the generated image.
@@ -523,11 +520,13 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `5.0`):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ strength (`float`, defaults to `0.8`):
+ Higher strength leads to more differences between original image and generated video.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -540,7 +539,7 @@ def __call__(
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
- output_type (`str`, *optional*, defaults to `"pil"`):
+ output_type (`str`, *optional*, defaults to `"np"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
@@ -557,8 +556,9 @@ def __call__(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
- autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
- The dtype to use for the torch.amp.autocast.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
+ truncated. If the prompt is shorter, it will be padded to this length.
Examples:
diff --git a/src/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py b/src/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py
index b2cf8cbc978c..5ab206b15176 100644
--- a/src/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py
+++ b/src/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py
@@ -1,5 +1,5 @@
# Copyright (c) 2022 Dominic Rampas MIT License
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py
index 6c06cc0e7303..77ae597655d1 100644
--- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py
+++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py
@@ -1,5 +1,5 @@
# Copyright (c) 2023 Dominic Rampas MIT License
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
index 9863c506d743..dbdd50871b43 100644
--- a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
+++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
@@ -1,5 +1,5 @@
# Copyright (c) 2023 Dominic Rampas MIT License
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,17 +14,16 @@
# limitations under the License.
import math
-from typing import Dict, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
+from ...models.attention import AttentionMixin
from ...models.attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
- AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
@@ -32,7 +31,7 @@
from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
-class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
+class WuerstchenPrior(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
unet_name = "prior"
_supports_gradient_checkpointing = True
@@ -61,66 +60,6 @@ def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dro
self.gradient_checkpointing = False
self.set_default_attn_processor()
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
index edc01f0d5c75..bbdb60471fd1 100644
--- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
+++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,7 +21,7 @@
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
from .modeling_paella_vq_model import PaellaVQModel
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
@@ -56,7 +56,7 @@
"""
-class WuerstchenDecoderPipeline(DiffusionPipeline):
+class WuerstchenDecoderPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
"""
Pipeline for generating images from the Wuerstchen model.
@@ -247,11 +247,11 @@ def __call__(
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 0.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
- `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely
- linked to the text `prompt`, usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `decoder_guidance_scale` is defined as `w` of
+ equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by
+ setting `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are
+ closely linked to the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `decoder_guidance_scale` is less than `1`).
@@ -263,7 +263,7 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
index 7819c8c0a0ef..c54c1fefe8fe 100644
--- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
+++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import deprecate, replace_example_docstring
-from ..pipeline_utils import DiffusionPipeline
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline
from .modeling_paella_vq_model import PaellaVQModel
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
from .modeling_wuerstchen_prior import WuerstchenPrior
@@ -40,7 +40,7 @@
"""
-class WuerstchenCombinedPipeline(DiffusionPipeline):
+class WuerstchenCombinedPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
"""
Combined Pipeline for text-to-image generation using Wuerstchen
@@ -68,6 +68,7 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
The scheduler to be used for prior pipeline.
"""
+ _last_supported_version = "0.33.1"
_load_connected_pipes = True
def __init__(
@@ -112,7 +113,7 @@ def __init__(
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
- def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
@@ -122,7 +123,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
- def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
+ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = None):
r"""
Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
@@ -190,11 +191,11 @@ def __call__(
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
prior_guidance_scale (`float`, *optional*, defaults to 4.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `prior_guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
- `prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked
- to the text `prompt`, usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `prior_guidance_scale` is defined as `w` of
+ equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by
+ setting `prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are
+ closely linked to the text `prompt`, usually at the expense of lower image quality.
prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 60):
The number of prior denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. For more specific timestep spacing, you can pass customized
@@ -210,18 +211,18 @@ def __call__(
Custom timesteps to use for the denoising process for the decoder. If not defined, equal spaced
`num_inference_steps` timesteps are used. Must be in descending order.
decoder_guidance_scale (`float`, *optional*, defaults to 0.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
index 8f6ba419721d..e138b6e805c8 100644
--- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
+++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -325,11 +325,11 @@ def __call__(
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 8.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
- `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely
- linked to the text `prompt`, usually at the expense of lower image quality.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `decoder_guidance_scale` is defined as `w` of
+ equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by
+ setting `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are
+ closely linked to the text `prompt`, usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `decoder_guidance_scale` is less than `1`).
@@ -348,7 +348,7 @@ def __call__(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
diff --git a/src/diffusers/pipelines/z_image/__init__.py b/src/diffusers/pipelines/z_image/__init__.py
new file mode 100644
index 000000000000..f4342713e3e9
--- /dev/null
+++ b/src/diffusers/pipelines/z_image/__init__.py
@@ -0,0 +1,52 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa: F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_output"] = ["ZImagePipelineOutput"]
+ _import_structure["pipeline_z_image"] = ["ZImagePipeline"]
+ _import_structure["pipeline_z_image_img2img"] = ["ZImageImg2ImgPipeline"]
+
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_output import ZImagePipelineOutput
+ from .pipeline_z_image import ZImagePipeline
+ from .pipeline_z_image_img2img import ZImageImg2ImgPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/z_image/pipeline_output.py b/src/diffusers/pipelines/z_image/pipeline_output.py
new file mode 100644
index 000000000000..69a320fc036a
--- /dev/null
+++ b/src/diffusers/pipelines/z_image/pipeline_output.py
@@ -0,0 +1,35 @@
+# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class ZImagePipelineOutput(BaseOutput):
+ """
+ Output class for Z-Image pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image.py b/src/diffusers/pipelines/z_image/pipeline_z_image.py
new file mode 100644
index 000000000000..82bdd7d361b7
--- /dev/null
+++ b/src/diffusers/pipelines/z_image/pipeline_z_image.py
@@ -0,0 +1,594 @@
+# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from transformers import AutoTokenizer, PreTrainedModel
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin
+from ...models.autoencoders import AutoencoderKL
+from ...models.transformers import ZImageTransformer2DModel
+from ...pipelines.pipeline_utils import DiffusionPipeline
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from .pipeline_output import ZImagePipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import ZImagePipeline
+
+ >>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch.
+ >>> # (1) Use flash attention 2
+ >>> # pipe.transformer.set_attention_backend("flash")
+ >>> # (2) Use flash attention 3
+ >>> # pipe.transformer.set_attention_backend("_flash_3")
+
+ >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。"
+ >>> image = pipe(
+ ... prompt,
+ ... height=1024,
+ ... width=1024,
+ ... num_inference_steps=9,
+ ... guidance_scale=0.0,
+ ... generator=torch.Generator("cuda").manual_seed(42),
+ ... ).images[0]
+ >>> image.save("zimage.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class ZImagePipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: PreTrainedModel,
+ tokenizer: AutoTokenizer,
+ transformer: ZImageTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ transformer=transformer,
+ )
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ )
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt_embeds: Optional[List[torch.FloatTensor]] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ ):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt_embeds = self._encode_prompt(
+ prompt=prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ )
+
+ if do_classifier_free_guidance:
+ if negative_prompt is None:
+ negative_prompt = ["" for _ in prompt]
+ else:
+ negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ assert len(prompt) == len(negative_prompt)
+ negative_prompt_embeds = self._encode_prompt(
+ prompt=negative_prompt,
+ device=device,
+ prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ )
+ else:
+ negative_prompt_embeds = []
+ return prompt_embeds, negative_prompt_embeds
+
+ def _encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[List[torch.FloatTensor]] = None,
+ max_sequence_length: int = 512,
+ ) -> List[torch.FloatTensor]:
+ device = device or self._execution_device
+
+ if prompt_embeds is not None:
+ return prompt_embeds
+
+ if isinstance(prompt, str):
+ prompt = [prompt]
+
+ for i, prompt_item in enumerate(prompt):
+ messages = [
+ {"role": "user", "content": prompt_item},
+ ]
+ prompt_item = self.tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ enable_thinking=True,
+ )
+ prompt[i] = prompt_item
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids.to(device)
+ prompt_masks = text_inputs.attention_mask.to(device).bool()
+
+ prompt_embeds = self.text_encoder(
+ input_ids=text_input_ids,
+ attention_mask=prompt_masks,
+ output_hidden_states=True,
+ ).hidden_states[-2]
+
+ embeddings_list = []
+
+ for i in range(len(prompt_embeds)):
+ embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
+
+ return embeddings_list
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 5.0,
+ cfg_normalization: bool = False,
+ cfg_truncation: float = 1.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[List[torch.FloatTensor]] = None,
+ negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ height (`int`, *optional*, defaults to 1024):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 1024):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ cfg_normalization (`bool`, *optional*, defaults to False):
+ Whether to apply configuration normalization.
+ cfg_truncation (`float`, *optional*, defaults to 1.0):
+ The truncation value for configuration.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`List[torch.FloatTensor]`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain
+ tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to 512):
+ Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+ height = height or 1024
+ width = width or 1024
+
+ vae_scale = self.vae_scale_factor * 2
+ if height % vae_scale != 0:
+ raise ValueError(
+ f"Height must be divisible by {vae_scale} (got {height}). "
+ f"Please adjust the height to a multiple of {vae_scale}."
+ )
+ if width % vae_scale != 0:
+ raise ValueError(
+ f"Width must be divisible by {vae_scale} (got {width}). "
+ f"Please adjust the width to a multiple of {vae_scale}."
+ )
+
+ device = self._execution_device
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+ self._cfg_normalization = cfg_normalization
+ self._cfg_truncation = cfg_truncation
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = len(prompt_embeds)
+
+ # If prompt_embeds is provided and prompt is None, skip encoding
+ if prompt_embeds is not None and prompt is None:
+ if self.do_classifier_free_guidance and negative_prompt_embeds is None:
+ raise ValueError(
+ "When `prompt_embeds` is provided without `prompt`, "
+ "`negative_prompt_embeds` must also be provided for classifier-free guidance."
+ )
+ else:
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.in_channels
+
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # Repeat prompt_embeds for num_images_per_prompt
+ if num_images_per_prompt > 1:
+ prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
+ if self.do_classifier_free_guidance and negative_prompt_embeds:
+ negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
+
+ actual_batch_size = batch_size * num_images_per_prompt
+ image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
+
+ # 5. Prepare timesteps
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ self.scheduler.sigma_min = 0.0
+ scheduler_kwargs = {"mu": mu}
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ **scheduler_kwargs,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0])
+ timestep = (1000 - timestep) / 1000
+ # Normalized time for time-aware config (0 at start, 1 at end)
+ t_norm = timestep[0].item()
+
+ # Handle cfg truncation
+ current_guidance_scale = self.guidance_scale
+ if (
+ self.do_classifier_free_guidance
+ and self._cfg_truncation is not None
+ and float(self._cfg_truncation) <= 1
+ ):
+ if t_norm > self._cfg_truncation:
+ current_guidance_scale = 0.0
+
+ # Run CFG only if configured AND scale is non-zero
+ apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
+
+ if apply_cfg:
+ latents_typed = latents.to(self.transformer.dtype)
+ latent_model_input = latents_typed.repeat(2, 1, 1, 1)
+ prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
+ timestep_model_input = timestep.repeat(2)
+ else:
+ latent_model_input = latents.to(self.transformer.dtype)
+ prompt_embeds_model_input = prompt_embeds
+ timestep_model_input = timestep
+
+ latent_model_input = latent_model_input.unsqueeze(2)
+ latent_model_input_list = list(latent_model_input.unbind(dim=0))
+
+ model_out_list = self.transformer(
+ latent_model_input_list, timestep_model_input, prompt_embeds_model_input, return_dict=False
+ )[0]
+
+ if apply_cfg:
+ # Perform CFG
+ pos_out = model_out_list[:actual_batch_size]
+ neg_out = model_out_list[actual_batch_size:]
+
+ noise_pred = []
+ for j in range(actual_batch_size):
+ pos = pos_out[j].float()
+ neg = neg_out[j].float()
+
+ pred = pos + current_guidance_scale * (pos - neg)
+
+ # Renormalization
+ if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
+ ori_pos_norm = torch.linalg.vector_norm(pos)
+ new_pos_norm = torch.linalg.vector_norm(pred)
+ max_new_norm = ori_pos_norm * float(self._cfg_normalization)
+ if new_pos_norm > max_new_norm:
+ pred = pred * (max_new_norm / new_pos_norm)
+
+ noise_pred.append(pred)
+
+ noise_pred = torch.stack(noise_pred, dim=0)
+ else:
+ noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
+
+ noise_pred = noise_pred.squeeze(2)
+ noise_pred = -noise_pred
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
+ assert latents.dtype == torch.float32
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = latents.to(self.vae.dtype)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return ZImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_img2img.py b/src/diffusers/pipelines/z_image/pipeline_z_image_img2img.py
new file mode 100644
index 000000000000..2b3e80a2082b
--- /dev/null
+++ b/src/diffusers/pipelines/z_image/pipeline_z_image_img2img.py
@@ -0,0 +1,709 @@
+# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from transformers import AutoTokenizer, PreTrainedModel
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin
+from ...models.autoencoders import AutoencoderKL
+from ...models.transformers import ZImageTransformer2DModel
+from ...pipelines.pipeline_utils import DiffusionPipeline
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from .pipeline_output import ZImagePipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import ZImageImg2ImgPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = ZImageImg2ImgPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+ >>> init_image = load_image(url).resize((1024, 1024))
+
+ >>> prompt = "A fantasy landscape with mountains and a river, detailed, vibrant colors"
+ >>> image = pipe(
+ ... prompt,
+ ... image=init_image,
+ ... strength=0.6,
+ ... num_inference_steps=9,
+ ... guidance_scale=0.0,
+ ... generator=torch.Generator("cuda").manual_seed(42),
+ ... ).images[0]
+ >>> image.save("zimage_img2img.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class ZImageImg2ImgPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
+ r"""
+ The ZImage pipeline for image-to-image generation.
+
+ Args:
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`PreTrainedModel`]):
+ A text encoder model to encode text prompts.
+ tokenizer ([`AutoTokenizer`]):
+ A tokenizer to tokenize text prompts.
+ transformer ([`ZImageTransformer2DModel`]):
+ A ZImage transformer model to denoise the encoded image latents.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: PreTrainedModel,
+ tokenizer: AutoTokenizer,
+ transformer: ZImageTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ scheduler=scheduler,
+ transformer=transformer,
+ )
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ )
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+
+ # Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt_embeds: Optional[List[torch.FloatTensor]] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ ):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt_embeds = self._encode_prompt(
+ prompt=prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ )
+
+ if do_classifier_free_guidance:
+ if negative_prompt is None:
+ negative_prompt = ["" for _ in prompt]
+ else:
+ negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ assert len(prompt) == len(negative_prompt)
+ negative_prompt_embeds = self._encode_prompt(
+ prompt=negative_prompt,
+ device=device,
+ prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ )
+ else:
+ negative_prompt_embeds = []
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[List[torch.FloatTensor]] = None,
+ max_sequence_length: int = 512,
+ ) -> List[torch.FloatTensor]:
+ device = device or self._execution_device
+
+ if prompt_embeds is not None:
+ return prompt_embeds
+
+ if isinstance(prompt, str):
+ prompt = [prompt]
+
+ for i, prompt_item in enumerate(prompt):
+ messages = [
+ {"role": "user", "content": prompt_item},
+ ]
+ prompt_item = self.tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ enable_thinking=True,
+ )
+ prompt[i] = prompt_item
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids.to(device)
+ prompt_masks = text_inputs.attention_mask.to(device).bool()
+
+ prompt_embeds = self.text_encoder(
+ input_ids=text_input_ids,
+ attention_mask=prompt_masks,
+ output_hidden_states=True,
+ ).hidden_states[-2]
+
+ embeddings_list = []
+
+ for i in range(len(prompt_embeds)):
+ embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
+
+ return embeddings_list
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ def prepare_latents(
+ self,
+ image,
+ timestep,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ # Encode the input image
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != num_channels_latents:
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ # Apply scaling (inverse of decoding: decode does latents/scaling_factor + shift_factor)
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+ else:
+ image_latents = image
+
+ # Handle batch size expansion
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+
+ # Add noise using flow matching scale_noise
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: PipelineImageInput = None,
+ strength: float = 0.6,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 5.0,
+ cfg_normalization: bool = False,
+ cfg_truncation: float = 1.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[List[torch.FloatTensor]] = None,
+ negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for image-to-image generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
+ list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
+ a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
+ strength (`float`, *optional*, defaults to 0.6):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ height (`int`, *optional*, defaults to 1024):
+ The height in pixels of the generated image. If not provided, uses the input image height.
+ width (`int`, *optional*, defaults to 1024):
+ The width in pixels of the generated image. If not provided, uses the input image width.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ cfg_normalization (`bool`, *optional*, defaults to False):
+ Whether to apply configuration normalization.
+ cfg_truncation (`float`, *optional*, defaults to 1.0):
+ The truncation value for configuration.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`List[torch.FloatTensor]`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain
+ tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to 512):
+ Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+ # 1. Check inputs and validate strength
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+ # 2. Preprocess image
+ init_image = self.image_processor.preprocess(image)
+ init_image = init_image.to(dtype=torch.float32)
+
+ # Get dimensions from the preprocessed image if not specified
+ if height is None:
+ height = init_image.shape[-2]
+ if width is None:
+ width = init_image.shape[-1]
+
+ vae_scale = self.vae_scale_factor * 2
+ if height % vae_scale != 0:
+ raise ValueError(
+ f"Height must be divisible by {vae_scale} (got {height}). "
+ f"Please adjust the height to a multiple of {vae_scale}."
+ )
+ if width % vae_scale != 0:
+ raise ValueError(
+ f"Width must be divisible by {vae_scale} (got {width}). "
+ f"Please adjust the width to a multiple of {vae_scale}."
+ )
+
+ device = self._execution_device
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+ self._cfg_normalization = cfg_normalization
+ self._cfg_truncation = cfg_truncation
+
+ # 3. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = len(prompt_embeds)
+
+ # If prompt_embeds is provided and prompt is None, skip encoding
+ if prompt_embeds is not None and prompt is None:
+ if self.do_classifier_free_guidance and negative_prompt_embeds is None:
+ raise ValueError(
+ "When `prompt_embeds` is provided without `prompt`, "
+ "`negative_prompt_embeds` must also be provided for classifier-free guidance."
+ )
+ else:
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.in_channels
+
+ # Repeat prompt_embeds for num_images_per_prompt
+ if num_images_per_prompt > 1:
+ prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
+ if self.do_classifier_free_guidance and negative_prompt_embeds:
+ negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
+
+ actual_batch_size = batch_size * num_images_per_prompt
+
+ # Calculate latent dimensions for image_seq_len
+ latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ image_seq_len = (latent_height // 2) * (latent_width // 2)
+
+ # 5. Prepare timesteps
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ self.scheduler.sigma_min = 0.0
+ scheduler_kwargs = {"mu": mu}
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ **scheduler_kwargs,
+ )
+
+ # 6. Adjust timesteps based on strength
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ latent_timestep = timesteps[:1].repeat(actual_batch_size)
+
+ # 7. Prepare latents from image
+ latents = self.prepare_latents(
+ init_image,
+ latent_timestep,
+ actual_batch_size,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds[0].dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 8. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0])
+ timestep = (1000 - timestep) / 1000
+ # Normalized time for time-aware config (0 at start, 1 at end)
+ t_norm = timestep[0].item()
+
+ # Handle cfg truncation
+ current_guidance_scale = self.guidance_scale
+ if (
+ self.do_classifier_free_guidance
+ and self._cfg_truncation is not None
+ and float(self._cfg_truncation) <= 1
+ ):
+ if t_norm > self._cfg_truncation:
+ current_guidance_scale = 0.0
+
+ # Run CFG only if configured AND scale is non-zero
+ apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
+
+ if apply_cfg:
+ latents_typed = latents.to(self.transformer.dtype)
+ latent_model_input = latents_typed.repeat(2, 1, 1, 1)
+ prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
+ timestep_model_input = timestep.repeat(2)
+ else:
+ latent_model_input = latents.to(self.transformer.dtype)
+ prompt_embeds_model_input = prompt_embeds
+ timestep_model_input = timestep
+
+ latent_model_input = latent_model_input.unsqueeze(2)
+ latent_model_input_list = list(latent_model_input.unbind(dim=0))
+
+ model_out_list = self.transformer(
+ latent_model_input_list,
+ timestep_model_input,
+ prompt_embeds_model_input,
+ )[0]
+
+ if apply_cfg:
+ # Perform CFG
+ pos_out = model_out_list[:actual_batch_size]
+ neg_out = model_out_list[actual_batch_size:]
+
+ noise_pred = []
+ for j in range(actual_batch_size):
+ pos = pos_out[j].float()
+ neg = neg_out[j].float()
+
+ pred = pos + current_guidance_scale * (pos - neg)
+
+ # Renormalization
+ if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
+ ori_pos_norm = torch.linalg.vector_norm(pos)
+ new_pos_norm = torch.linalg.vector_norm(pred)
+ max_new_norm = ori_pos_norm * float(self._cfg_normalization)
+ if new_pos_norm > max_new_norm:
+ pred = pred * (max_new_norm / new_pos_norm)
+
+ noise_pred.append(pred)
+
+ noise_pred = torch.stack(noise_pred, dim=0)
+ else:
+ noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
+
+ noise_pred = noise_pred.squeeze(2)
+ noise_pred = -noise_pred
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
+ assert latents.dtype == torch.float32
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = latents.to(self.vae.dtype)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return ZImagePipelineOutput(images=image)
diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py
index 4c8483a3d6ee..3ca867c12908 100644
--- a/src/diffusers/quantizers/__init__.py
+++ b/src/diffusers/quantizers/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,5 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
from .auto import DiffusersAutoQuantizer
from .base import DiffusersQuantizer
+from .pipe_quant_config import PipelineQuantizationConfig
diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py
index ce214ae7bc17..070bcd0b2151 100644
--- a/src/diffusers/quantizers/auto.py
+++ b/src/diffusers/quantizers/auto.py
@@ -21,9 +21,11 @@
from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
from .gguf import GGUFQuantizer
+from .modelopt import NVIDIAModelOptQuantizer
from .quantization_config import (
BitsAndBytesConfig,
GGUFQuantizationConfig,
+ NVIDIAModelOptConfig,
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
@@ -39,6 +41,7 @@
"gguf": GGUFQuantizer,
"quanto": QuantoQuantizer,
"torchao": TorchAoHfQuantizer,
+ "modelopt": NVIDIAModelOptQuantizer,
}
AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -47,6 +50,7 @@
"gguf": GGUFQuantizationConfig,
"quanto": QuantoConfig,
"torchao": TorchAoConfig,
+ "modelopt": NVIDIAModelOptConfig,
}
@@ -137,6 +141,9 @@ def merge_quantization_configs(
if isinstance(quantization_config, dict):
quantization_config = cls.from_dict(quantization_config)
+ if isinstance(quantization_config, NVIDIAModelOptConfig):
+ quantization_config.check_model_patching()
+
if warning_msg != "":
warnings.warn(warning_msg)
diff --git a/src/diffusers/quantizers/base.py b/src/diffusers/quantizers/base.py
index 1c75b5bef933..24fc724b4c88 100644
--- a/src/diffusers/quantizers/base.py
+++ b/src/diffusers/quantizers/base.py
@@ -199,7 +199,7 @@ def postprocess_model(self, model: "ModelMixin", **kwargs):
def dequantize(self, model):
"""
- Potentially dequantize the model to retrive the original model, with some loss in accuracy / performance. Note
+ Potentially dequantize the model to retrieve the original model, with some loss in accuracy / performance. Note
not all quantization schemes support this.
"""
model = self._dequantize(model)
@@ -209,25 +209,37 @@ def dequantize(self, model):
return model
+ def get_cuda_warm_up_factor(self):
+ """
+ The factor to be used in `caching_allocator_warmup` to get the number of bytes to pre-allocate to warm up cuda.
+ A factor of 2 means we allocate all bytes in the empty model (since we allocate in fp16), a factor of 4 means
+ we allocate half the memory of the weights residing in the empty model, etc...
+ """
+ # By default we return 4, i.e. half the model size (this corresponds to the case where the model is not
+ # really pre-processed, i.e. we do not have the info that weights are going to be 8 bits before actual
+ # weight loading)
+ return 4
+
def _dequantize(self, model):
raise NotImplementedError(
f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
)
@abstractmethod
- def _process_model_before_weight_loading(self, model, **kwargs):
- ...
+ def _process_model_before_weight_loading(self, model, **kwargs): ...
@abstractmethod
- def _process_model_after_weight_loading(self, model, **kwargs):
- ...
+ def _process_model_after_weight_loading(self, model, **kwargs): ...
@property
@abstractmethod
- def is_serializable(self):
- ...
+ def is_serializable(self): ...
@property
@abstractmethod
- def is_trainable(self):
- ...
+ def is_trainable(self): ...
+
+ @property
+ def is_compileable(self) -> bool:
+ """Flag indicating whether the quantized model can be compiled"""
+ return False
diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
index 689d8e4256c2..0dfdff019b79 100644
--- a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
+++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
@@ -564,6 +564,10 @@ def is_trainable(self) -> bool:
# Because we're mandating `bitsandbytes` 0.43.3.
return True
+ @property
+ def is_compileable(self) -> bool:
+ return True
+
def _dequantize(self, model):
from .utils import dequantize_and_replace
diff --git a/src/diffusers/quantizers/bitsandbytes/utils.py b/src/diffusers/quantizers/bitsandbytes/utils.py
index a9771b368a86..429aabb8fae6 100644
--- a/src/diffusers/quantizers/bitsandbytes/utils.py
+++ b/src/diffusers/quantizers/bitsandbytes/utils.py
@@ -49,7 +49,7 @@ def _replace_with_bnb_linear(
"""
Private method that wraps the recursion for module replacement.
- Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
+ Returns the converted model and a boolean that indicates if the conversion has been successful or not.
"""
for name, module in model.named_children():
if current_key_name is None:
@@ -121,8 +121,9 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
References:
* `bnb.nn.Linear8bit`: [LLM.int8(): 8-bit Matrix Multiplication for Transformers at
- Scale](https://arxiv.org/abs/2208.07339)
- * `bnb.nn.Linear4bit`: [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
+ Scale](https://huggingface.co/papers/2208.07339)
+ * `bnb.nn.Linear4bit`: [QLoRA: Efficient Finetuning of Quantized
+ LLMs](https://huggingface.co/papers/2305.14314)
Parameters:
model (`torch.nn.Module`):
@@ -139,10 +140,12 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
models by reducing the precision of the weights and activations, thus making models more efficient in terms
of both storage and computation.
"""
- model, has_been_replaced = _replace_with_bnb_linear(
- model, modules_to_not_convert, current_key_name, quantization_config
- )
+ model, _ = _replace_with_bnb_linear(model, modules_to_not_convert, current_key_name, quantization_config)
+ has_been_replaced = any(
+ isinstance(replaced_module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt))
+ for _, replaced_module in model.named_modules()
+ )
if not has_been_replaced:
logger.warning(
"You are loading your model in 8bit or 4bit but no linear modules were found in your model."
@@ -169,9 +172,11 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torc
if cls_name == "Params4bit":
output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
- logger.warning_once(
- f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
- )
+ msg = f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
+ if dtype:
+ msg = f"The model is going to be first dequantized in {output_tensor.dtype} and type-casted to {dtype}"
+ output_tensor = output_tensor.to(dtype)
+ logger.warning_once(msg)
return output_tensor
if state.SCB is None:
@@ -219,7 +224,7 @@ def _dequantize_and_replace(
performance drop compared to the original model before quantization - use it only for specific usecases such as
QLoRA adapters merging.
- Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
+ Returns the converted model and a boolean that indicates if the conversion has been successful or not.
"""
quant_method = quantization_config.quantization_method()
@@ -283,16 +288,18 @@ def dequantize_and_replace(
modules_to_not_convert=None,
quantization_config=None,
):
- model, has_been_replaced = _dequantize_and_replace(
+ model, _ = _dequantize_and_replace(
model,
dtype=model.dtype,
modules_to_not_convert=modules_to_not_convert,
quantization_config=quantization_config,
)
-
+ has_been_replaced = any(
+ isinstance(replaced_module, torch.nn.Linear) for _, replaced_module in model.named_modules()
+ )
if not has_been_replaced:
logger.warning(
- "For some reason the model has not been properly dequantized. You might see unexpected behavior."
+ "Some linear modules were not dequantized. This could lead to unexpected behaviour. Please check your model."
)
return model
diff --git a/src/diffusers/quantizers/gguf/gguf_quantizer.py b/src/diffusers/quantizers/gguf/gguf_quantizer.py
index 6da69c7bd60c..aa5ebf5711a3 100644
--- a/src/diffusers/quantizers/gguf/gguf_quantizer.py
+++ b/src/diffusers/quantizers/gguf/gguf_quantizer.py
@@ -49,7 +49,7 @@ def __init__(self, quantization_config, **kwargs):
def validate_environment(self, *args, **kwargs):
if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
raise ImportError(
- "Loading GGUF Parameters requires `accelerate` installed in your enviroment: `pip install 'accelerate>=0.26.0'`"
+ "Loading GGUF Parameters requires `accelerate` installed in your environment: `pip install 'accelerate>=0.26.0'`"
)
if not is_gguf_available() or is_gguf_version("<", "0.10.0"):
raise ImportError(
@@ -82,7 +82,7 @@ def check_quantized_param_shape(self, param_name, current_param, loaded_param):
inferred_shape = _quant_shape_from_byte_shape(loaded_param_shape, type_size, block_size)
if inferred_shape != current_param_shape:
raise ValueError(
- f"{param_name} has an expected quantized shape of: {inferred_shape}, but receieved shape: {loaded_param_shape}"
+ f"{param_name} has an expected quantized shape of: {inferred_shape}, but received shape: {loaded_param_shape}"
)
return True
@@ -146,13 +146,22 @@ def is_serializable(self):
def is_trainable(self) -> bool:
return False
+ @property
+ def is_compileable(self) -> bool:
+ return True
+
def _dequantize(self, model):
is_model_on_cpu = model.device.type == "cpu"
if is_model_on_cpu:
logger.info(
- "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
+ "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to accelerator. After dequantization, will move the model back to CPU again to preserve the previous device."
+ )
+ device = (
+ torch.accelerator.current_accelerator()
+ if hasattr(torch, "accelerator")
+ else torch.cuda.current_device()
)
- model.to(torch.cuda.current_device())
+ model.to(device)
model = _dequantize_gguf_and_restore_linear(model, self.modules_to_not_convert)
if is_model_on_cpu:
diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py
index effc39d8fe97..2fba9986e825 100644
--- a/src/diffusers/quantizers/gguf/utils.py
+++ b/src/diffusers/quantizers/gguf/utils.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team and City96. All rights reserved.
+# Copyright 2025 The HuggingFace Team and City96. All rights reserved.
# #
# # Licensed under the Apache License, Version 2.0 (the "License");
# # you may not use this file except in compliance with the License.
@@ -12,15 +12,15 @@
# # See the License for the specific language governing permissions and
# # limitations under the License.
-
import inspect
+import os
from contextlib import nullcontext
import gguf
import torch
import torch.nn as nn
-from ...utils import is_accelerate_available
+from ...utils import is_accelerate_available, is_kernels_available
if is_accelerate_available():
@@ -29,6 +29,82 @@
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+can_use_cuda_kernels = (
+ os.getenv("DIFFUSERS_GGUF_CUDA_KERNELS", "false").lower() in ["1", "true", "yes"]
+ and torch.cuda.is_available()
+ and torch.cuda.get_device_capability()[0] >= 7
+)
+if can_use_cuda_kernels and is_kernels_available():
+ from kernels import get_kernel
+
+ ops = get_kernel("Isotr0py/ggml")
+else:
+ ops = None
+
+UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16}
+STANDARD_QUANT_TYPES = {
+ gguf.GGMLQuantizationType.Q4_0,
+ gguf.GGMLQuantizationType.Q4_1,
+ gguf.GGMLQuantizationType.Q5_0,
+ gguf.GGMLQuantizationType.Q5_1,
+ gguf.GGMLQuantizationType.Q8_0,
+ gguf.GGMLQuantizationType.Q8_1,
+}
+KQUANT_TYPES = {
+ gguf.GGMLQuantizationType.Q2_K,
+ gguf.GGMLQuantizationType.Q3_K,
+ gguf.GGMLQuantizationType.Q4_K,
+ gguf.GGMLQuantizationType.Q5_K,
+ gguf.GGMLQuantizationType.Q6_K,
+}
+IMATRIX_QUANT_TYPES = {
+ gguf.GGMLQuantizationType.IQ1_M,
+ gguf.GGMLQuantizationType.IQ1_S,
+ gguf.GGMLQuantizationType.IQ2_XXS,
+ gguf.GGMLQuantizationType.IQ2_XS,
+ gguf.GGMLQuantizationType.IQ2_S,
+ gguf.GGMLQuantizationType.IQ3_XXS,
+ gguf.GGMLQuantizationType.IQ3_S,
+ gguf.GGMLQuantizationType.IQ4_XS,
+ gguf.GGMLQuantizationType.IQ4_NL,
+}
+# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
+# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
+# MMQ kernel for I-Matrix quantization.
+DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
+
+
+def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
+ # there is no need to call any kernel for fp16/bf16
+ if qweight_type in UNQUANTIZED_TYPES:
+ return x @ qweight.T
+
+ # TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
+ # contiguous batching and inefficient with diffusers' batching,
+ # so we disabled it now.
+
+ # elif qweight_type in MMVQ_QUANT_TYPES:
+ # y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
+ # elif qweight_type in MMQ_QUANT_TYPES:
+ # y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
+
+ # If there is no available MMQ kernel, fallback to dequantize
+ if qweight_type in DEQUANT_TYPES:
+ block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
+ shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
+ weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
+ y = x @ weight.to(x.dtype).T
+ else:
+ # Raise an error if the quantization type is not supported.
+ # Might be useful if llama.cpp adds a new quantization type.
+ # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
+ qweight_type = gguf.GGMLQuantizationType(qweight_type)
+ raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
+ return y.as_tensor()
+
+
# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
def _create_accelerate_new_hook(old_hook):
r"""
@@ -353,8 +429,64 @@ def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None):
return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)
+# this part from calcuis (gguf.org)
+# more info: https://github.com/calcuis/gguf-connector/blob/main/src/gguf_connector/quant2c.py
+
+
+def dequantize_blocks_IQ4_NL(blocks, block_size, type_size, dtype=None):
+ kvalues = torch.tensor(
+ [-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113],
+ dtype=torch.float32,
+ device=blocks.device,
+ )
+ n_blocks = blocks.shape[0]
+ d, qs = split_block_dims(blocks, 2)
+ d = d.view(torch.float16).to(dtype)
+ qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
+ [0, 4], device=blocks.device, dtype=torch.uint8
+ ).reshape((1, 1, 2, 1))
+ qs = (qs & 15).reshape((n_blocks, -1)).to(torch.int64)
+ kvalues = kvalues.view(1, 1, 16)
+ qs = qs.unsqueeze(-1)
+ qs = torch.gather(kvalues.expand(qs.shape[0], qs.shape[1], 16), 2, qs)
+ qs = qs.squeeze(-1).to(dtype)
+ return d * qs
+
+
+def dequantize_blocks_IQ4_XS(blocks, block_size, type_size, dtype=None):
+ kvalues = torch.tensor(
+ [-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113],
+ dtype=torch.float32,
+ device=blocks.device,
+ )
+ n_blocks = blocks.shape[0]
+ d, scales_h, scales_l, qs = split_block_dims(blocks, 2, 2, QK_K // 64)
+ d = d.view(torch.float16).to(dtype)
+ scales_h = scales_h.view(torch.int16)
+ scales_l = scales_l.reshape((n_blocks, -1, 1)) >> torch.tensor(
+ [0, 4], device=blocks.device, dtype=torch.uint8
+ ).reshape((1, 1, 2))
+ scales_h = scales_h.reshape((n_blocks, 1, -1)) >> torch.tensor(
+ [2 * i for i in range(QK_K // 32)], device=blocks.device, dtype=torch.uint8
+ ).reshape((1, -1, 1))
+ scales_l = scales_l.reshape((n_blocks, -1)) & 0x0F
+ scales_h = scales_h.reshape((n_blocks, -1)) & 0x03
+ scales = (scales_l | (scales_h << 4)) - 32
+ dl = (d * scales.to(dtype)).reshape((n_blocks, -1, 1))
+ shifts_q = torch.tensor([0, 4], device=blocks.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
+ qs = qs.reshape((n_blocks, -1, 1, 16)) >> shifts_q
+ qs = (qs & 15).reshape((n_blocks, -1, 32)).to(torch.int64)
+ kvalues = kvalues.view(1, 1, 1, 16)
+ qs = qs.unsqueeze(-1)
+ qs = torch.gather(kvalues.expand(qs.shape[0], qs.shape[1], qs.shape[2], 16), 3, qs)
+ qs = qs.squeeze(-1).to(dtype)
+ return (dl * qs).reshape(n_blocks, -1)
+
+
GGML_QUANT_SIZES = gguf.GGML_QUANT_SIZES
dequantize_functions = {
+ gguf.GGMLQuantizationType.IQ4_NL: dequantize_blocks_IQ4_NL,
+ gguf.GGMLQuantizationType.IQ4_XS: dequantize_blocks_IQ4_XS,
gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,
@@ -400,12 +532,26 @@ def __new__(cls, data, requires_grad=False, quant_type=None):
data = data if data is not None else torch.empty(0)
self = torch.Tensor._make_subclass(cls, data, requires_grad)
self.quant_type = quant_type
+ block_size, type_size = GGML_QUANT_SIZES[quant_type]
+ self.quant_shape = _quant_shape_from_byte_shape(self.shape, type_size, block_size)
return self
def as_tensor(self):
return torch.Tensor._make_subclass(torch.Tensor, self, self.requires_grad)
+ @staticmethod
+ def _extract_quant_type(args):
+ # When converting from original format checkpoints we often use splits, cats etc on tensors
+ # this method ensures that the returned tensor type from those operations remains GGUFParameter
+ # so that we preserve quant_type information
+ for arg in args:
+ if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
+ return arg[0].quant_type
+ if isinstance(arg, GGUFParameter):
+ return arg.quant_type
+ return None
+
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
@@ -413,22 +559,13 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
result = super().__torch_function__(func, types, args, kwargs)
- # When converting from original format checkpoints we often use splits, cats etc on tensors
- # this method ensures that the returned tensor type from those operations remains GGUFParameter
- # so that we preserve quant_type information
- quant_type = None
- for arg in args:
- if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
- quant_type = arg[0].quant_type
- break
- if isinstance(arg, GGUFParameter):
- quant_type = arg.quant_type
- break
if isinstance(result, torch.Tensor):
+ quant_type = cls._extract_quant_type(args)
return cls(result, quant_type=quant_type)
# Handle tuples and lists
- elif isinstance(result, (tuple, list)):
+ elif type(result) in (list, tuple):
# Preserve the original type (tuple or list)
+ quant_type = cls._extract_quant_type(args)
wrapped = [cls(x, quant_type=quant_type) if isinstance(x, torch.Tensor) else x for x in result]
return type(result)(wrapped)
else:
@@ -446,11 +583,24 @@ def __init__(
) -> None:
super().__init__(in_features, out_features, bias, device)
self.compute_dtype = compute_dtype
+ self.device = device
+
+ def forward(self, inputs: torch.Tensor):
+ if ops is not None and self.weight.is_cuda and inputs.is_cuda:
+ return self.forward_cuda(inputs)
+ return self.forward_native(inputs)
- def forward(self, inputs):
+ def forward_native(self, inputs: torch.Tensor):
weight = dequantize_gguf_tensor(self.weight)
weight = weight.to(self.compute_dtype)
bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
output = torch.nn.functional.linear(inputs, weight, bias)
return output
+
+ def forward_cuda(self, inputs: torch.Tensor):
+ quant_type = self.weight.quant_type
+ output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
+ if self.bias is not None:
+ output += self.bias.to(self.compute_dtype)
+ return output
diff --git a/src/diffusers/quantizers/modelopt/__init__.py b/src/diffusers/quantizers/modelopt/__init__.py
new file mode 100644
index 000000000000..ae0951cb30d1
--- /dev/null
+++ b/src/diffusers/quantizers/modelopt/__init__.py
@@ -0,0 +1 @@
+from .modelopt_quantizer import NVIDIAModelOptQuantizer
diff --git a/src/diffusers/quantizers/modelopt/modelopt_quantizer.py b/src/diffusers/quantizers/modelopt/modelopt_quantizer.py
new file mode 100644
index 000000000000..7312036f52d0
--- /dev/null
+++ b/src/diffusers/quantizers/modelopt/modelopt_quantizer.py
@@ -0,0 +1,190 @@
+from typing import TYPE_CHECKING, Any, Dict, List, Union
+
+from ...utils import (
+ get_module_from_name,
+ is_accelerate_available,
+ is_nvidia_modelopt_available,
+ is_torch_available,
+ logging,
+)
+from ..base import DiffusersQuantizer
+
+
+if TYPE_CHECKING:
+ from ...models.modeling_utils import ModelMixin
+
+
+if is_torch_available():
+ import torch
+ import torch.nn as nn
+
+if is_accelerate_available():
+ from accelerate.utils import set_module_tensor_to_device
+
+
+logger = logging.get_logger(__name__)
+
+
+class NVIDIAModelOptQuantizer(DiffusersQuantizer):
+ r"""
+ Diffusers Quantizer for Nvidia-Model Optimizer
+ """
+
+ use_keep_in_fp32_modules = True
+ requires_calibration = False
+ required_packages = ["nvidia_modelopt"]
+
+ def __init__(self, quantization_config, **kwargs):
+ super().__init__(quantization_config, **kwargs)
+
+ def validate_environment(self, *args, **kwargs):
+ if not is_nvidia_modelopt_available():
+ raise ImportError(
+ "Loading an nvidia-modelopt quantized model requires nvidia-modelopt library (`pip install nvidia-modelopt`)"
+ )
+
+ self.offload = False
+
+ device_map = kwargs.get("device_map", None)
+ if isinstance(device_map, dict):
+ if "cpu" in device_map.values() or "disk" in device_map.values():
+ if self.pre_quantized:
+ raise ValueError(
+ "You are attempting to perform cpu/disk offload with a pre-quantized modelopt model "
+ "This is not supported yet. Please remove the CPU or disk device from the `device_map` argument."
+ )
+ else:
+ self.offload = True
+
+ def check_if_quantized_param(
+ self,
+ model: "ModelMixin",
+ param_value: "torch.Tensor",
+ param_name: str,
+ state_dict: Dict[str, Any],
+ **kwargs,
+ ):
+ # ModelOpt imports diffusers internally. This is here to prevent circular imports
+ from modelopt.torch.quantization.utils import is_quantized
+
+ module, tensor_name = get_module_from_name(model, param_name)
+ if self.pre_quantized:
+ return True
+ elif is_quantized(module) and "weight" in tensor_name:
+ return True
+ return False
+
+ def create_quantized_param(
+ self,
+ model: "ModelMixin",
+ param_value: "torch.Tensor",
+ param_name: str,
+ target_device: "torch.device",
+ *args,
+ **kwargs,
+ ):
+ """
+ Create the quantized parameter by calling .calibrate() after setting it to the module.
+ """
+ # ModelOpt imports diffusers internally. This is here to prevent circular imports
+ import modelopt.torch.quantization as mtq
+
+ dtype = kwargs.get("dtype", torch.float32)
+ module, tensor_name = get_module_from_name(model, param_name)
+ if self.pre_quantized:
+ module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
+ else:
+ set_module_tensor_to_device(model, param_name, target_device, param_value, dtype)
+ mtq.calibrate(
+ module, self.quantization_config.modelopt_config["algorithm"], self.quantization_config.forward_loop
+ )
+ mtq.compress(module)
+ module.weight.requires_grad = False
+
+ def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
+ max_memory = {key: val * 0.90 for key, val in max_memory.items()}
+ return max_memory
+
+ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+ if self.quantization_config.quant_type == "FP8":
+ target_dtype = torch.float8_e4m3fn
+ return target_dtype
+
+ def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype":
+ if torch_dtype is None:
+ logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.")
+ torch_dtype = torch.float32
+ return torch_dtype
+
+ def get_conv_param_names(self, model: "ModelMixin") -> List[str]:
+ """
+ Get parameter names for all convolutional layers in a HuggingFace ModelMixin. Includes Conv1d/2d/3d and
+ ConvTranspose1d/2d/3d.
+ """
+ conv_types = (
+ nn.Conv1d,
+ nn.Conv2d,
+ nn.Conv3d,
+ nn.ConvTranspose1d,
+ nn.ConvTranspose2d,
+ nn.ConvTranspose3d,
+ )
+
+ conv_param_names = []
+ for name, module in model.named_modules():
+ if isinstance(module, conv_types):
+ for param_name, _ in module.named_parameters(recurse=False):
+ conv_param_names.append(f"{name}.{param_name}")
+
+ return conv_param_names
+
+ def _process_model_before_weight_loading(
+ self,
+ model: "ModelMixin",
+ device_map,
+ keep_in_fp32_modules: List[str] = [],
+ **kwargs,
+ ):
+ # ModelOpt imports diffusers internally. This is here to prevent circular imports
+ import modelopt.torch.opt as mto
+
+ if self.pre_quantized:
+ return
+
+ modules_to_not_convert = self.quantization_config.modules_to_not_convert
+
+ if modules_to_not_convert is None:
+ modules_to_not_convert = []
+ if isinstance(modules_to_not_convert, str):
+ modules_to_not_convert = [modules_to_not_convert]
+ modules_to_not_convert.extend(keep_in_fp32_modules)
+ if self.quantization_config.disable_conv_quantization:
+ modules_to_not_convert.extend(self.get_conv_param_names(model))
+
+ for module in modules_to_not_convert:
+ self.quantization_config.modelopt_config["quant_cfg"]["*" + module + "*"] = {"enable": False}
+ self.quantization_config.modules_to_not_convert = modules_to_not_convert
+ mto.apply_mode(model, mode=[("quantize", self.quantization_config.modelopt_config)])
+ model.config.quantization_config = self.quantization_config
+
+ def _process_model_after_weight_loading(self, model, **kwargs):
+ # ModelOpt imports diffusers internally. This is here to prevent circular imports
+ from modelopt.torch.opt import ModeloptStateManager
+
+ if self.pre_quantized:
+ return model
+
+ for _, m in model.named_modules():
+ if hasattr(m, ModeloptStateManager._state_key) and m is not model:
+ ModeloptStateManager.remove_state(m)
+
+ return model
+
+ @property
+ def is_trainable(self):
+ return True
+
+ @property
+ def is_serializable(self):
+ self.quantization_config.check_model_patching(operation="saving")
+ return True
diff --git a/src/diffusers/quantizers/pipe_quant_config.py b/src/diffusers/quantizers/pipe_quant_config.py
new file mode 100644
index 000000000000..f75a337341a9
--- /dev/null
+++ b/src/diffusers/quantizers/pipe_quant_config.py
@@ -0,0 +1,205 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Dict, List, Optional, Union
+
+from ..utils import is_transformers_available, logging
+from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin
+
+
+try:
+ from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin
+except ImportError:
+
+ class TransformersQuantConfigMixin:
+ pass
+
+
+logger = logging.get_logger(__name__)
+
+
+class PipelineQuantizationConfig:
+ """
+ Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`].
+
+ Args:
+ quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend
+ is available to both `diffusers` and `transformers`.
+ quant_kwargs (`dict`): Params to initialize the quantization backend class.
+ components_to_quantize (`list`): Components of a pipeline to be quantized.
+ quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline
+ components. When using this argument, users are not expected to provide `quant_backend`, `quant_kawargs`,
+ and `components_to_quantize`.
+ """
+
+ def __init__(
+ self,
+ quant_backend: str = None,
+ quant_kwargs: Dict[str, Union[str, float, int, dict]] = None,
+ components_to_quantize: Optional[Union[List[str], str]] = None,
+ quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None,
+ ):
+ self.quant_backend = quant_backend
+ # Initialize kwargs to be {} to set to the defaults.
+ self.quant_kwargs = quant_kwargs or {}
+ if components_to_quantize:
+ if isinstance(components_to_quantize, str):
+ components_to_quantize = [components_to_quantize]
+ self.components_to_quantize = components_to_quantize
+ self.quant_mapping = quant_mapping
+ self.config_mapping = {} # book-keeping Example: `{module_name: quant_config}`
+ self.post_init()
+
+ def post_init(self):
+ quant_mapping = self.quant_mapping
+ self.is_granular = True if quant_mapping is not None else False
+
+ self._validate_init_args()
+
+ def _validate_init_args(self):
+ if self.quant_backend and self.quant_mapping:
+ raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.")
+
+ if not self.quant_mapping and not self.quant_backend:
+ raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.")
+
+ if not self.quant_kwargs and not self.quant_mapping:
+ raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.")
+
+ if self.quant_backend is not None:
+ self._validate_init_kwargs_in_backends()
+
+ if self.quant_mapping is not None:
+ self._validate_quant_mapping_args()
+
+ def _validate_init_kwargs_in_backends(self):
+ quant_backend = self.quant_backend
+
+ self._check_backend_availability(quant_backend)
+
+ quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
+
+ if quant_config_mapping_transformers is not None:
+ init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__)
+ init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"}
+ else:
+ init_kwargs_transformers = None
+
+ init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__)
+ init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"}
+
+ if init_kwargs_transformers != init_kwargs_diffusers:
+ raise ValueError(
+ "The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. "
+ f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to [the docs](https://huggingface.co/docs/diffusers/main/en/quantization/overview#pipeline-level-quantization) to learn more about how "
+ "this mapping would look like."
+ )
+
+ def _validate_quant_mapping_args(self):
+ quant_mapping = self.quant_mapping
+ transformers_map, diffusers_map = self._get_quant_config_list()
+
+ available_transformers = list(transformers_map.values()) if transformers_map else None
+ available_diffusers = list(diffusers_map.values())
+
+ for module_name, config in quant_mapping.items():
+ if any(isinstance(config, cfg) for cfg in available_diffusers):
+ continue
+
+ if available_transformers and any(isinstance(config, cfg) for cfg in available_transformers):
+ continue
+
+ if available_transformers:
+ raise ValueError(
+ f"Provided config for module_name={module_name} could not be found. "
+ f"Available diffusers configs: {available_diffusers}; "
+ f"Available transformers configs: {available_transformers}."
+ )
+ else:
+ raise ValueError(
+ f"Provided config for module_name={module_name} could not be found. "
+ f"Available diffusers configs: {available_diffusers}."
+ )
+
+ def _check_backend_availability(self, quant_backend: str):
+ quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
+
+ available_backends_transformers = (
+ list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None
+ )
+ available_backends_diffusers = list(quant_config_mapping_diffusers.keys())
+
+ if (
+ available_backends_transformers and quant_backend not in available_backends_transformers
+ ) or quant_backend not in quant_config_mapping_diffusers:
+ error_message = f"Provided quant_backend={quant_backend} was not found."
+ if available_backends_transformers:
+ error_message += f"\nAvailable ones (transformers): {available_backends_transformers}."
+ error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}."
+ raise ValueError(error_message)
+
+ def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None):
+ quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
+
+ quant_mapping = self.quant_mapping
+ components_to_quantize = self.components_to_quantize
+
+ # Granular case
+ if self.is_granular and module_name in quant_mapping:
+ logger.debug(f"Initializing quantization config class for {module_name}.")
+ config = quant_mapping[module_name]
+ self.config_mapping.update({module_name: config})
+ return config
+
+ # Global config case
+ else:
+ should_quantize = False
+ # Only quantize the modules requested for.
+ if components_to_quantize and module_name in components_to_quantize:
+ should_quantize = True
+ # No specification for `components_to_quantize` means all modules should be quantized.
+ elif not self.is_granular and not components_to_quantize:
+ should_quantize = True
+
+ if should_quantize:
+ logger.debug(f"Initializing quantization config class for {module_name}.")
+ mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers
+ quant_config_cls = mapping_to_use[self.quant_backend]
+ quant_kwargs = self.quant_kwargs
+ quant_obj = quant_config_cls(**quant_kwargs)
+ self.config_mapping.update({module_name: quant_obj})
+ return quant_obj
+
+ # Fallback: no applicable configuration found.
+ return None
+
+ def _get_quant_config_list(self):
+ if is_transformers_available():
+ from transformers.quantizers.auto import (
+ AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers,
+ )
+ else:
+ quant_config_mapping_transformers = None
+
+ from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers
+
+ return quant_config_mapping_transformers, quant_config_mapping_diffusers
+
+ def __repr__(self):
+ out = ""
+ config_mapping = dict(sorted(self.config_mapping.copy().items()))
+ for module_name, config in config_mapping.items():
+ out += f"{module_name} {config}"
+ return out
diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py
index 0bc433be0ff3..5dd8f56717df 100644
--- a/src/diffusers/quantizers/quantization_config.py
+++ b/src/diffusers/quantizers/quantization_config.py
@@ -21,18 +21,20 @@
"""
import copy
+import dataclasses
import importlib.metadata
import inspect
import json
import os
-from dataclasses import dataclass
+import warnings
+from dataclasses import dataclass, is_dataclass
from enum import Enum
from functools import partial
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
from packaging import version
-from ..utils import is_torch_available, is_torchao_available, logging
+from ..utils import is_torch_available, is_torchao_available, is_torchao_version, logging
if is_torch_available():
@@ -46,6 +48,7 @@ class QuantizationMethod(str, Enum):
GGUF = "gguf"
TORCHAO = "torchao"
QUANTO = "quanto"
+ MODELOPT = "modelopt"
if is_torchao_available():
@@ -75,7 +78,7 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
Args:
config_dict (`Dict[str, Any]`):
Dictionary that will be used to instantiate the configuration object.
- return_unused_kwargs (`bool`,*optional*, defaults to `False`):
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in
`PreTrainedModel`.
kwargs (`Dict[str, Any]`):
@@ -179,7 +182,7 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
This is a wrapper class about all possible attributes and features that you can play with a model that has been
loaded using `bitsandbytes`.
- This replaces `load_in_8bit` or `load_in_4bit`therefore both options are mutually exclusive.
+ This replaces `load_in_8bit` or `load_in_4bit` therefore both options are mutually exclusive.
Currently only supports `LLM.int8()`, `FP4`, and `NF4` quantization. If more methods are added to `bitsandbytes`,
then more arguments will be added to this class.
@@ -192,10 +195,10 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
`bitsandbytes`.
llm_int8_threshold (`float`, *optional*, defaults to 6.0):
This corresponds to the outlier threshold for outlier detection as described in `LLM.int8() : 8-bit Matrix
- Multiplication for Transformers at Scale` paper: https://arxiv.org/abs/2208.07339 Any hidden states value
- that is above this threshold will be considered an outlier and the operation on those values will be done
- in fp16. Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but
- there are some exceptional systematic outliers that are very differently distributed for large models.
+ Multiplication for Transformers at Scale` paper: https://huggingface.co/papers/2208.07339 Any hidden states
+ value that is above this threshold will be considered an outlier and the operation on those values will be
+ done in fp16. Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5],
+ but there are some exceptional systematic outliers that are very differently distributed for large models.
These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of
magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6,
but a lower threshold might be needed for more unstable models (small models, fine-tuning).
@@ -268,7 +271,14 @@ def __init__(
if bnb_4bit_quant_storage is None:
self.bnb_4bit_quant_storage = torch.uint8
elif isinstance(bnb_4bit_quant_storage, str):
- if bnb_4bit_quant_storage not in ["float16", "float32", "int8", "uint8", "float64", "bfloat16"]:
+ if bnb_4bit_quant_storage not in [
+ "float16",
+ "float32",
+ "int8",
+ "uint8",
+ "float64",
+ "bfloat16",
+ ]:
raise ValueError(
"`bnb_4bit_quant_storage` must be a valid string (one of 'float16', 'float32', 'int8', 'uint8', 'float64', 'bfloat16') "
)
@@ -434,7 +444,7 @@ class TorchAoConfig(QuantizationConfigMixin):
"""This is a config class for torchao quantization/sparsity techniques.
Args:
- quant_type (`str`):
+ quant_type (Union[`str`, AOBaseConfig]):
The type of quantization we want to use, currently supporting:
- **Integer quantization:**
- Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`,
@@ -456,6 +466,7 @@ class TorchAoConfig(QuantizationConfigMixin):
- **Unsigned Integer quantization:**
- Full function names: `uintx_weight_only`
- Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo`
+ - An AOBaseConfig instance: for more advanced configuration options.
modules_to_not_convert (`List[str]`, *optional*, default to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require to have some
modules left in their original precision.
@@ -469,6 +480,12 @@ class TorchAoConfig(QuantizationConfigMixin):
```python
from diffusers import FluxTransformer2DModel, TorchAoConfig
+ # AOBaseConfig-based configuration
+ from torchao.quantization import Int8WeightOnlyConfig
+
+ quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
+
+ # String-based config
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
@@ -479,7 +496,12 @@ class TorchAoConfig(QuantizationConfigMixin):
```
"""
- def __init__(self, quant_type: str, modules_to_not_convert: Optional[List[str]] = None, **kwargs) -> None:
+ def __init__(
+ self,
+ quant_type: Union[str, "AOBaseConfig"], # noqa: F821
+ modules_to_not_convert: Optional[List[str]] = None,
+ **kwargs,
+ ) -> None:
self.quant_method = QuantizationMethod.TORCHAO
self.quant_type = quant_type
self.modules_to_not_convert = modules_to_not_convert
@@ -490,34 +512,103 @@ def __init__(self, quant_type: str, modules_to_not_convert: Optional[List[str]]
else:
self.quant_type_kwargs = kwargs
- TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
- if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
- is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
- if is_floating_quant_type and not self._is_cuda_capability_atleast_8_9():
+ self.post_init()
+
+ def post_init(self):
+ if not isinstance(self.quant_type, str):
+ if is_torchao_version("<=", "0.9.0"):
raise ValueError(
- f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
- f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
+ f"torchao <= 0.9.0 only supports string quant_type, got {type(self.quant_type).__name__}. "
+ f"Upgrade to torchao > 0.9.0 to use AOBaseConfig."
)
- raise ValueError(
- f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
- f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
- )
+ from torchao.quantization.quant_api import AOBaseConfig
- method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
- signature = inspect.signature(method)
- all_kwargs = {
- param.name
- for param in signature.parameters.values()
- if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
- }
- unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)
+ if not isinstance(self.quant_type, AOBaseConfig):
+ raise TypeError(f"quant_type must be a AOBaseConfig instance, got {type(self.quant_type).__name__}")
- if len(unsupported_kwargs) > 0:
- raise ValueError(
- f'The quantization method "{quant_type}" does not support the following keyword arguments: '
- f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
- )
+ elif isinstance(self.quant_type, str):
+ TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
+
+ if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
+ is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
+ if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
+ raise ValueError(
+ f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
+ f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
+ )
+
+ raise ValueError(
+ f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
+ f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
+ )
+
+ method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
+ signature = inspect.signature(method)
+ all_kwargs = {
+ param.name
+ for param in signature.parameters.values()
+ if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
+ }
+ unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)
+
+ if len(unsupported_kwargs) > 0:
+ raise ValueError(
+ f'The quantization method "{self.quant_type}" does not support the following keyword arguments: '
+ f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
+ )
+
+ def to_dict(self):
+ """Convert configuration to a dictionary."""
+ d = super().to_dict()
+
+ if isinstance(self.quant_type, str):
+ # Handle layout serialization if present
+ if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
+ if is_dataclass(d["quant_type_kwargs"]["layout"]):
+ d["quant_type_kwargs"]["layout"] = [
+ d["quant_type_kwargs"]["layout"].__class__.__name__,
+ dataclasses.asdict(d["quant_type_kwargs"]["layout"]),
+ ]
+ if isinstance(d["quant_type_kwargs"]["layout"], list):
+ assert len(d["quant_type_kwargs"]["layout"]) == 2, "layout saves layout name and layout kwargs"
+ assert isinstance(d["quant_type_kwargs"]["layout"][0], str), "layout name must be a string"
+ assert isinstance(d["quant_type_kwargs"]["layout"][1], dict), "layout kwargs must be a dict"
+ else:
+ raise ValueError("layout must be a list")
+ else:
+ # Handle AOBaseConfig serialization
+ from torchao.core.config import config_to_dict
+
+ # For now we assume there is 1 config per Transformer, however in the future
+ # We may want to support a config per fqn.
+ d["quant_type"] = {"default": config_to_dict(self.quant_type)}
+
+ return d
+
+ @classmethod
+ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
+ """Create configuration from a dictionary."""
+ if not is_torchao_version(">", "0.9.0"):
+ raise NotImplementedError("TorchAoConfig requires torchao > 0.9.0 for construction from dict")
+ config_dict = config_dict.copy()
+ quant_type = config_dict.pop("quant_type")
+
+ if isinstance(quant_type, str):
+ return cls(quant_type=quant_type, **config_dict)
+ # Check if we only have one key which is "default"
+ # In the future we may update this
+ assert len(quant_type) == 1 and "default" in quant_type, (
+ "Expected only one key 'default' in quant_type dictionary"
+ )
+ quant_type = quant_type["default"]
+
+ # Deserialize quant_type if needed
+ from torchao.core.config import config_from_dict
+
+ quant_type = config_from_dict(quant_type)
+
+ return cls(quant_type=quant_type, **config_dict)
@classmethod
def _get_torchao_quant_type_to_method(cls):
@@ -645,7 +736,7 @@ def generate_fpx_quantization_types(bits: int):
QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES)
QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES)
- if cls._is_cuda_capability_atleast_8_9():
+ if cls._is_xpu_or_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES)
return QUANTIZATION_TYPES
@@ -655,18 +746,50 @@ def generate_fpx_quantization_types(bits: int):
)
@staticmethod
- def _is_cuda_capability_atleast_8_9() -> bool:
- if not torch.cuda.is_available():
- raise RuntimeError("TorchAO requires a CUDA compatible GPU and installation of PyTorch.")
-
- major, minor = torch.cuda.get_device_capability()
- if major == 8:
- return minor >= 9
- return major >= 9
+ def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
+ if torch.cuda.is_available():
+ major, minor = torch.cuda.get_device_capability()
+ if major == 8:
+ return minor >= 9
+ return major >= 9
+ elif torch.xpu.is_available():
+ return True
+ else:
+ raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.")
def get_apply_tensor_subclass(self):
- TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
- return TORCHAO_QUANT_TYPE_METHODS[self.quant_type](**self.quant_type_kwargs)
+ """Create the appropriate quantization method based on configuration."""
+ if not isinstance(self.quant_type, str):
+ return self.quant_type
+ else:
+ methods = self._get_torchao_quant_type_to_method()
+ quant_type_kwargs = self.quant_type_kwargs.copy()
+ if (
+ not torch.cuda.is_available()
+ and is_torchao_available()
+ and self.quant_type == "int4_weight_only"
+ and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
+ and quant_type_kwargs.get("layout", None) is None
+ ):
+ if torch.xpu.is_available():
+ if version.parse(importlib.metadata.version("torchao")) >= version.parse(
+ "0.11.0"
+ ) and version.parse(importlib.metadata.version("torch")) > version.parse("2.7.9"):
+ from torchao.dtypes import Int4XPULayout
+ from torchao.quantization.quant_primitives import ZeroPointDomain
+
+ quant_type_kwargs["layout"] = Int4XPULayout()
+ quant_type_kwargs["zero_point_domain"] = ZeroPointDomain.INT
+ else:
+ raise ValueError(
+ "TorchAoConfig requires torchao >= 0.11.0 and torch >= 2.8.0 for XPU support. Please upgrade the version or use run on CPU with the cpu version pytorch."
+ )
+ else:
+ from torchao.dtypes import Int4CPULayout
+
+ quant_type_kwargs["layout"] = Int4CPULayout()
+
+ return methods[self.quant_type](**quant_type_kwargs)
def __repr__(self):
r"""
@@ -722,3 +845,194 @@ def post_init(self):
accepted_weights = ["float8", "int8", "int4", "int2"]
if self.weights_dtype not in accepted_weights:
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}")
+
+
+@dataclass
+class NVIDIAModelOptConfig(QuantizationConfigMixin):
+ """This is a config class to use nvidia modelopt for quantization.
+
+ Args:
+ quant_type (`str`):
+ The type of quantization we want to use, following is how to use:
+ **weightquant_activationquant ==> FP8_FP8** In the above example we have use FP8 for both weight and
+ activation quantization. Following are the all the options:
+ - FP8
+ - INT8
+ - INT4
+ - NF4
+ - NVFP4
+ modules_to_not_convert (`List[str]`, *optional*, default to `None`):
+ The list of modules to not quantize, useful for quantizing models that explicitly require to have some
+ weight_only (`bool`, *optional*, default to `False`):
+ If set to `True`, the quantization will be applied only to the weights of the model.
+ channel_quantize (`int`, *optional*, default to `None`):
+ The channel quantization axis, useful for quantizing models across different axes.
+ block_quantize (`int`, *optional*, default to `None`):
+ The block size, useful to further quantize each channel/axes into blocks.
+ scale_channel_quantize (`int`, *optional*, default to `None`):
+ The scale channel quantization axis, useful for quantizing calculated scale across different axes.
+ scale_block_quantize (`int`, *optional*, default to `None`):
+ The scale block size, useful for quantizing each scale channel/axes into blocks.
+ algorithm (`str`, *optional*, default to `"max"`):
+ The algorithm to use for quantization, currently only supports `"max"`.
+ forward_loop (`Callable`, *optional*, default to `None`):
+ The forward loop function to use for calibration during quantization.
+ modelopt_config (`dict`, *optional*, default to `None`):
+ The modelopt config, useful for passing custom configs to modelopt.
+ disable_conv_quantization (`bool`, *optional*, default to `False`):
+ If set to `True`, the quantization will be disabled for convolutional layers.
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional parameters which are to be used for calibration.
+ """
+
+ quanttype_to_numbits = {
+ "FP8": (4, 3),
+ "INT8": 8,
+ "INT4": 4,
+ "NF4": 4,
+ "NVFP4": (2, 1),
+ }
+ quanttype_to_scalingbits = {
+ "NF4": 8,
+ "NVFP4": (4, 3),
+ }
+
+ def __init__(
+ self,
+ quant_type: str,
+ modules_to_not_convert: Optional[List[str]] = None,
+ weight_only: bool = True,
+ channel_quantize: Optional[int] = None,
+ block_quantize: Optional[int] = None,
+ scale_channel_quantize: Optional[int] = None,
+ scale_block_quantize: Optional[int] = None,
+ algorithm: str = "max",
+ forward_loop: Optional[Callable] = None,
+ modelopt_config: Optional[dict] = None,
+ disable_conv_quantization: bool = False,
+ **kwargs,
+ ) -> None:
+ self.quant_method = QuantizationMethod.MODELOPT
+ self._normalize_quant_type(quant_type)
+ self.modules_to_not_convert = modules_to_not_convert
+ self.weight_only = weight_only
+ self.channel_quantize = channel_quantize
+ self.block_quantize = block_quantize
+ self.calib_cfg = {
+ "method": algorithm,
+ # add more options here if needed
+ }
+ self.forward_loop = forward_loop
+ self.scale_channel_quantize = scale_channel_quantize
+ self.scale_block_quantize = scale_block_quantize
+ self.modelopt_config = self.get_config_from_quant_type() if not modelopt_config else modelopt_config
+ self.disable_conv_quantization = disable_conv_quantization
+
+ def check_model_patching(self, operation: str = "loading"):
+ # ModelOpt imports diffusers internally. This is here to prevent circular imports
+ from modelopt.torch.opt.plugins.huggingface import _PATCHED_CLASSES
+
+ if len(_PATCHED_CLASSES) == 0:
+ warning_msg = (
+ f"Not {operation} weights in modelopt format. This might cause unreliable behavior."
+ "Please make sure to run the following code before loading/saving model weights:\n\n"
+ " from modelopt.torch.opt import enable_huggingface_checkpointing\n"
+ " enable_huggingface_checkpointing()\n"
+ )
+ warnings.warn(warning_msg)
+
+ def _normalize_quant_type(self, quant_type: str) -> str:
+ """
+ Validates and normalizes the quantization type string.
+
+ Splits the quant_type into weight and activation components, verifies them against supported types, and
+ replaces unsupported values with safe defaults.
+
+ Args:
+ quant_type (str): The input quantization type string (e.g., 'FP8_INT8').
+
+ Returns:
+ str: A valid quantization type string (e.g., 'FP8_INT8' or 'FP8').
+ """
+ parts = quant_type.split("_")
+ w_type = parts[0]
+ act_type = parts[1] if len(parts) > 1 else None
+ if len(parts) > 2:
+ logger.warning(f"Quantization type {quant_type} is not supported. Picking FP8_INT8 as default")
+ w_type = "FP8"
+ act_type = None
+ else:
+ if w_type not in NVIDIAModelOptConfig.quanttype_to_numbits:
+ logger.warning(f"Weight Quantization type {w_type} is not supported. Picking FP8 as default")
+ w_type = "FP8"
+ if act_type is not None and act_type not in NVIDIAModelOptConfig.quanttype_to_numbits:
+ logger.warning(f"Activation Quantization type {act_type} is not supported. Picking INT8 as default")
+ act_type = None
+ self.quant_type = w_type + ("_" + act_type if act_type is not None else "")
+
+ def get_config_from_quant_type(self) -> Dict[str, Any]:
+ """
+ Get the config from the quantization type.
+ """
+ import modelopt.torch.quantization as mtq
+
+ BASE_CONFIG = {
+ "quant_cfg": {
+ "*weight_quantizer": {"fake_quant": False},
+ "*input_quantizer": {},
+ "*output_quantizer": {"enable": False},
+ "*q_bmm_quantizer": {},
+ "*k_bmm_quantizer": {},
+ "*v_bmm_quantizer": {},
+ "*softmax_quantizer": {},
+ **mtq.config._default_disabled_quantizer_cfg,
+ },
+ "algorithm": self.calib_cfg,
+ }
+
+ quant_cfg = BASE_CONFIG["quant_cfg"]
+ if self.weight_only:
+ for k in quant_cfg:
+ if "*weight_quantizer" not in k and not quant_cfg[k]:
+ quant_cfg[k]["enable"] = False
+
+ parts = self.quant_type.split("_")
+ w_type = parts[0]
+ act_type = parts[1].replace("A", "") if len(parts) > 1 else None
+ for k in quant_cfg:
+ if k not in mtq.config._default_disabled_quantizer_cfg and "enable" not in quant_cfg[k]:
+ if k == "*input_quantizer":
+ if act_type is not None:
+ quant_cfg[k]["num_bits"] = NVIDIAModelOptConfig.quanttype_to_numbits[act_type]
+ continue
+ quant_cfg[k]["num_bits"] = NVIDIAModelOptConfig.quanttype_to_numbits[w_type]
+
+ if self.block_quantize is not None and self.channel_quantize is not None:
+ quant_cfg["*weight_quantizer"]["block_sizes"] = {self.channel_quantize: self.block_quantize}
+ quant_cfg["*input_quantizer"]["block_sizes"] = {
+ self.channel_quantize: self.block_quantize,
+ "type": "dynamic",
+ }
+ elif self.channel_quantize is not None:
+ quant_cfg["*weight_quantizer"]["axis"] = self.channel_quantize
+ quant_cfg["*input_quantizer"]["axis"] = self.channel_quantize
+ quant_cfg["*input_quantizer"]["type"] = "dynamic"
+
+ # Only fixed scaling sizes are supported for now in modelopt
+ if self.scale_channel_quantize is not None and self.scale_block_quantize is not None:
+ if w_type in NVIDIAModelOptConfig.quanttype_to_scalingbits:
+ quant_cfg["*weight_quantizer"]["block_sizes"].update(
+ {
+ "scale_bits": NVIDIAModelOptConfig.quanttype_to_scalingbits[w_type],
+ "scale_block_sizes": {self.scale_channel_quantize: self.scale_block_quantize},
+ }
+ )
+ if act_type and act_type in NVIDIAModelOptConfig.quanttype_to_scalingbits:
+ quant_cfg["*input_quantizer"]["block_sizes"].update(
+ {
+ "scale_bits": NVIDIAModelOptConfig.quanttype_to_scalingbits[act_type],
+ "scale_block_sizes": {self.scale_channel_quantize: self.scale_block_quantize},
+ }
+ )
+
+ return BASE_CONFIG
diff --git a/src/diffusers/quantizers/quanto/quanto_quantizer.py b/src/diffusers/quantizers/quanto/quanto_quantizer.py
index 0120163804c9..c5f71f816fc3 100644
--- a/src/diffusers/quantizers/quanto/quanto_quantizer.py
+++ b/src/diffusers/quantizers/quanto/quanto_quantizer.py
@@ -175,3 +175,7 @@ def is_trainable(self):
@property
def is_serializable(self):
return True
+
+ @property
+ def is_compileable(self) -> bool:
+ return True
diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py
index f9fb217ed6bd..2334c7af8630 100644
--- a/src/diffusers/quantizers/torchao/torchao_quantizer.py
+++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py
@@ -18,8 +18,10 @@
"""
import importlib
+import re
import types
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from fnmatch import fnmatch
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from packaging import version
@@ -106,6 +108,21 @@ def _update_torch_safe_globals():
_update_torch_safe_globals()
+def fuzzy_match_size(config_name: str) -> Optional[str]:
+ """
+ Extract the size digit from strings like "4weight", "8weight". Returns the digit as an integer if found, otherwise
+ None.
+ """
+ config_name = config_name.lower()
+
+ str_match = re.search(r"(\d)weight", config_name)
+
+ if str_match:
+ return str_match.group(1)
+
+ return None
+
+
logger = logging.get_logger(__name__)
@@ -175,8 +192,7 @@ def validate_environment(self, *args, **kwargs):
def update_torch_dtype(self, torch_dtype):
quant_type = self.quantization_config.quant_type
-
- if quant_type.startswith("int") or quant_type.startswith("uint"):
+ if isinstance(quant_type, str) and (quant_type.startswith("int") or quant_type.startswith("uint")):
if torch_dtype is not None and torch_dtype != torch.bfloat16:
logger.warning(
f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
@@ -196,24 +212,44 @@ def update_torch_dtype(self, torch_dtype):
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
quant_type = self.quantization_config.quant_type
-
- if quant_type.startswith("int8") or quant_type.startswith("int4"):
- # Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
- return torch.int8
- elif quant_type == "uintx_weight_only":
- return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
- elif quant_type.startswith("uint"):
- return {
- 1: torch.uint1,
- 2: torch.uint2,
- 3: torch.uint3,
- 4: torch.uint4,
- 5: torch.uint5,
- 6: torch.uint6,
- 7: torch.uint7,
- }[int(quant_type[4])]
- elif quant_type.startswith("float") or quant_type.startswith("fp"):
- return torch.bfloat16
+ from accelerate.utils import CustomDtype
+
+ if isinstance(quant_type, str):
+ if quant_type.startswith("int8"):
+ # Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
+ return torch.int8
+ elif quant_type.startswith("int4"):
+ return CustomDtype.INT4
+ elif quant_type == "uintx_weight_only":
+ return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
+ elif quant_type.startswith("uint"):
+ return {
+ 1: torch.uint1,
+ 2: torch.uint2,
+ 3: torch.uint3,
+ 4: torch.uint4,
+ 5: torch.uint5,
+ 6: torch.uint6,
+ 7: torch.uint7,
+ }[int(quant_type[4])]
+ elif quant_type.startswith("float") or quant_type.startswith("fp"):
+ return torch.bfloat16
+
+ elif is_torchao_version(">", "0.9.0"):
+ from torchao.core.config import AOBaseConfig
+
+ quant_type = self.quantization_config.quant_type
+ if isinstance(quant_type, AOBaseConfig):
+ # Extract size digit using fuzzy match on the class name
+ config_name = quant_type.__class__.__name__
+ size_digit = fuzzy_match_size(config_name)
+
+ # Map the extracted digit to appropriate dtype
+ if size_digit == "4":
+ return CustomDtype.INT4
+ else:
+ # Default to int8
+ return torch.int8
if isinstance(target_dtype, SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION):
return target_dtype
@@ -262,7 +298,7 @@ def create_quantized_param(
**kwargs,
):
r"""
- Each nn.Linear layer that needs to be quantized is processsed here. First, we set the value the weight tensor,
+ Each nn.Linear layer that needs to be quantized is processed here. First, we set the value the weight tensor,
then we move it to the target device. Finally, we quantize the module.
"""
module, tensor_name = get_module_from_name(model, param_name)
@@ -278,6 +314,46 @@ def create_quantized_param(
module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
quantize_(module, self.quantization_config.get_apply_tensor_subclass())
+ def get_cuda_warm_up_factor(self):
+ """
+ This factor is used in caching_allocator_warmup to determine how many bytes to pre-allocate for CUDA warmup.
+ - A factor of 2 means we pre-allocate the full memory footprint of the model.
+ - A factor of 4 means we pre-allocate half of that, and so on
+
+ However, when using TorchAO, calculating memory usage with param.numel() * param.element_size() doesn't give
+ the correct size for quantized weights (like int4 or int8) That's because TorchAO internally represents
+ quantized tensors using subtensors and metadata, and the reported element_size() still corresponds to the
+ torch_dtype not the actual bit-width of the quantized data.
+
+ To correct for this:
+ - Use a division factor of 8 for int4 weights
+ - Use a division factor of 4 for int8 weights
+ """
+ # Original mapping for non-AOBaseConfig types
+ # For the uint types, this is a best guess. Once these types become more used
+ # we can look into their nuances.
+ if is_torchao_version(">", "0.9.0"):
+ from torchao.core.config import AOBaseConfig
+
+ quant_type = self.quantization_config.quant_type
+ # For autoquant case, it will be treated in the string implementation below in map_to_target_dtype
+ if isinstance(quant_type, AOBaseConfig):
+ # Extract size digit using fuzzy match on the class name
+ config_name = quant_type.__class__.__name__
+ size_digit = fuzzy_match_size(config_name)
+
+ if size_digit == "4":
+ return 8
+ else:
+ return 4
+
+ map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4}
+ quant_type = self.quantization_config.quant_type
+ for pattern, target_dtype in map_to_target_dtype.items():
+ if fnmatch(quant_type, pattern):
+ return target_dtype
+ raise ValueError(f"Unsupported quant_type: {quant_type!r}")
+
def _process_model_before_weight_loading(
self,
model: "ModelMixin",
@@ -335,3 +411,7 @@ def is_serializable(self, safe_serialization=None):
@property
def is_trainable(self):
return self.quantization_config.quant_type.startswith("int8")
+
+ @property
+ def is_compileable(self) -> bool:
+ return True
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index 05cd21cd0034..29052c1ba0cb 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -60,6 +60,7 @@
_import_structure["scheduling_euler_discrete"] = ["EulerDiscreteScheduler"]
_import_structure["scheduling_flow_match_euler_discrete"] = ["FlowMatchEulerDiscreteScheduler"]
_import_structure["scheduling_flow_match_heun_discrete"] = ["FlowMatchHeunDiscreteScheduler"]
+ _import_structure["scheduling_flow_match_lcm"] = ["FlowMatchLCMScheduler"]
_import_structure["scheduling_heun_discrete"] = ["HeunDiscreteScheduler"]
_import_structure["scheduling_ipndm"] = ["IPNDMScheduler"]
_import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"]
@@ -161,6 +162,7 @@
from .scheduling_euler_discrete import EulerDiscreteScheduler
from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from .scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler
+ from .scheduling_flow_match_lcm import FlowMatchLCMScheduler
from .scheduling_heun_discrete import HeunDiscreteScheduler
from .scheduling_ipndm import IPNDMScheduler
from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
diff --git a/src/diffusers/schedulers/deprecated/scheduling_karras_ve.py b/src/diffusers/schedulers/deprecated/scheduling_karras_ve.py
index f5f9bd256c2e..9206ee80a6b6 100644
--- a/src/diffusers/schedulers/deprecated/scheduling_karras_ve.py
+++ b/src/diffusers/schedulers/deprecated/scheduling_karras_ve.py
@@ -1,4 +1,4 @@
-# Copyright 2024 NVIDIA and The HuggingFace Team. All rights reserved.
+# Copyright 2025 NVIDIA and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -53,12 +53,9 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
-
-
- For more details on the parameters, see [Appendix E](https://arxiv.org/abs/2206.00364). The grid search values used
- to find the optimal `{s_noise, s_churn, s_min, s_max}` for a specific model are described in Table 5 of the paper.
-
-
+ > [!TIP] > For more details on the parameters, see [Appendix E](https://huggingface.co/papers/2206.00364). The grid
+ search > values used to find the optimal `{s_noise, s_churn, s_min, s_max}` for a specific model are described in
+ Table 5 of > the paper.
Args:
sigma_min (`float`, defaults to 0.02):
diff --git a/src/diffusers/schedulers/deprecated/scheduling_sde_vp.py b/src/diffusers/schedulers/deprecated/scheduling_sde_vp.py
index 09b02cadc400..5088bdb49761 100644
--- a/src/diffusers/schedulers/deprecated/scheduling_sde_vp.py
+++ b/src/diffusers/schedulers/deprecated/scheduling_sde_vp.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Google Brain and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Google Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/schedulers/scheduling_amused.py b/src/diffusers/schedulers/scheduling_amused.py
index 238b8d869171..a0b8fbc862b0 100644
--- a/src/diffusers/schedulers/scheduling_amused.py
+++ b/src/diffusers/schedulers/scheduling_amused.py
@@ -1,6 +1,6 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import torch
@@ -9,13 +9,48 @@
from .scheduling_utils import SchedulerMixin
-def gumbel_noise(t, generator=None):
+def gumbel_noise(t: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
+ """
+ Generate Gumbel noise for sampling.
+
+ Args:
+ t (`torch.Tensor`):
+ Input tensor to match the shape and dtype of the output noise.
+ generator (`torch.Generator`, *optional*):
+ A random number generator for reproducible sampling.
+
+ Returns:
+ `torch.Tensor`:
+ Gumbel-distributed noise with the same shape, dtype, and device as the input tensor.
+ """
device = generator.device if generator is not None else t.device
noise = torch.zeros_like(t, device=device).uniform_(0, 1, generator=generator).to(t.device)
return -torch.log((-torch.log(noise.clamp(1e-20))).clamp(1e-20))
-def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
+def mask_by_random_topk(
+ mask_len: torch.Tensor,
+ probs: torch.Tensor,
+ temperature: float = 1.0,
+ generator: Optional[torch.Generator] = None,
+) -> torch.Tensor:
+ """
+ Mask tokens by selecting the top-k lowest confidence scores with temperature-based randomness.
+
+ Args:
+ mask_len (`torch.Tensor`):
+ Number of tokens to mask per sample in the batch.
+ probs (`torch.Tensor`):
+ Probability scores for each token.
+ temperature (`float`, *optional*, defaults to 1.0):
+ Temperature parameter for controlling randomness in the masking process.
+ generator (`torch.Generator`, *optional*):
+ A random number generator for reproducible sampling.
+
+ Returns:
+ `torch.Tensor`:
+ Boolean mask indicating which tokens should be masked.
+ """
confidence = torch.log(probs.clamp(1e-20)) + temperature * gumbel_noise(probs, generator=generator)
sorted_confidence = torch.sort(confidence, dim=-1).values
cut_off = torch.gather(sorted_confidence, 1, mask_len.long())
@@ -29,28 +64,46 @@ class AmusedSchedulerOutput(BaseOutput):
Output class for the scheduler's `step` function output.
Args:
- prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
- Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
- denoising loop.
- pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
- The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
- `pred_original_sample` can be used to preview progress or for guidance.
+ prev_sample (`torch.LongTensor` of shape `(batch_size, height, width)` or `(batch_size, sequence_length)`):
+ Computed sample `(x_{t-1})` of previous timestep with token IDs. `prev_sample` should be used as next model
+ input in the denoising loop.
+ pred_original_sample (`torch.LongTensor` of shape `(batch_size, height, width)` or `(batch_size, sequence_length)`, *optional*):
+ The predicted fully denoised sample `(x_{0})` with token IDs based on the model output from the current
+ timestep. `pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.Tensor
- pred_original_sample: torch.Tensor = None
+ pred_original_sample: Optional[torch.Tensor] = None
class AmusedScheduler(SchedulerMixin, ConfigMixin):
+ """
+ A scheduler for masked token generation as used in [`AmusedPipeline`].
+
+ This scheduler iteratively unmasks tokens based on their confidence scores, following either a cosine or linear
+ schedule. Unlike traditional diffusion schedulers that work with continuous pixel values, this scheduler operates
+ on discrete token IDs, making it suitable for autoregressive and non-autoregressive masked token generation models.
+
+ This scheduler inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the
+ generic methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ mask_token_id (`int`):
+ The token ID used to represent masked tokens in the sequence.
+ masking_schedule (`Literal["cosine", "linear"]`, *optional*, defaults to `"cosine"`):
+ The schedule type for determining the mask ratio at each timestep. Can be either `"cosine"` or `"linear"`.
+ """
+
order = 1
- temperatures: torch.Tensor
+ temperatures: Optional[torch.Tensor]
+ timesteps: Optional[torch.Tensor]
@register_to_config
def __init__(
self,
mask_token_id: int,
- masking_schedule: str = "cosine",
+ masking_schedule: Literal["cosine", "linear"] = "cosine",
):
self.temperatures = None
self.timesteps = None
@@ -58,9 +111,23 @@ def __init__(
def set_timesteps(
self,
num_inference_steps: int,
- temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
- device: Union[str, torch.device] = None,
- ):
+ temperature: Union[float, Tuple[float, float], List[float]] = (2, 0),
+ device: Optional[Union[str, torch.device]] = None,
+ ) -> None:
+ """
+ Set the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ temperature (`Union[float, Tuple[float, float], List[float]]`, *optional*, defaults to `(2, 0)`):
+ Temperature parameter(s) for controlling the randomness of sampling. If a tuple or list is provided,
+ temperatures will be linearly interpolated between the first and second values across all timesteps. If
+ a single value is provided, temperatures will be linearly interpolated from that value to 0.01.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps and temperatures should be moved to. If `None`, the timesteps are not
+ moved.
+ """
self.timesteps = torch.arange(num_inference_steps, device=device).flip(0)
if isinstance(temperature, (tuple, list)):
@@ -71,12 +138,38 @@ def set_timesteps(
def step(
self,
model_output: torch.Tensor,
- timestep: torch.long,
+ timestep: int,
sample: torch.LongTensor,
- starting_mask_ratio: int = 1,
+ starting_mask_ratio: float = 1.0,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
- ) -> Union[AmusedSchedulerOutput, Tuple]:
+ ) -> Union[AmusedSchedulerOutput, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Predict the sample at the previous timestep by masking tokens based on confidence scores.
+
+ Args:
+ model_output (`torch.Tensor`):
+ The direct output from the learned diffusion model. Typically of shape `(batch_size, num_tokens,
+ codebook_size)` or `(batch_size, codebook_size, height, width)` for 2D inputs.
+ timestep (`int`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.LongTensor`):
+ A current instance of a sample created by the diffusion process. Contains token IDs, with masked
+ positions indicated by `mask_token_id`.
+ starting_mask_ratio (`float`, *optional*, defaults to 1.0):
+ A multiplier applied to the mask ratio schedule. Values less than 1.0 will result in fewer tokens being
+ masked at each step.
+ generator (`torch.Generator`, *optional*):
+ A random number generator for reproducible sampling.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return an [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] or a plain tuple.
+
+ Returns:
+ [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~schedulers.scheduling_amused.AmusedSchedulerOutput`] is returned,
+ otherwise a tuple is returned where the first element is the sample tensor (`prev_sample`) and the
+ second element is the predicted original sample tensor (`pred_original_sample`).
+ """
two_dim_input = sample.ndim == 3 and model_output.ndim == 4
if two_dim_input:
@@ -137,7 +230,27 @@ def step(
return AmusedSchedulerOutput(prev_sample, pred_original_sample)
- def add_noise(self, sample, timesteps, generator=None):
+ def add_noise(
+ self,
+ sample: torch.LongTensor,
+ timesteps: int,
+ generator: Optional[torch.Generator] = None,
+ ) -> torch.LongTensor:
+ """
+ Add noise to a sample by randomly masking tokens according to the masking schedule.
+
+ Args:
+ sample (`torch.LongTensor`):
+ The input sample containing token IDs to be partially masked.
+ timesteps (`int`):
+ The timestep that determines how much masking to apply. Higher timesteps result in more masking.
+ generator (`torch.Generator`, *optional*):
+ A random number generator for reproducible masking.
+
+ Returns:
+ `torch.LongTensor`:
+ The sample with some tokens replaced by `mask_token_id` according to the masking schedule.
+ """
step_idx = (self.timesteps == timesteps).nonzero()
ratio = (step_idx + 1) / len(self.timesteps)
diff --git a/src/diffusers/schedulers/scheduling_consistency_decoder.py b/src/diffusers/schedulers/scheduling_consistency_decoder.py
index d7af018b284a..767fa9157f59 100644
--- a/src/diffusers/schedulers/scheduling_consistency_decoder.py
+++ b/src/diffusers/schedulers/scheduling_consistency_decoder.py
@@ -1,6 +1,6 @@
import math
from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Literal, Optional, Tuple, Union
import torch
@@ -12,10 +12,10 @@
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -23,16 +23,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py
index 653171638ccf..386a43db0f9c 100644
--- a/src/diffusers/schedulers/scheduling_consistency_models.py
+++ b/src/diffusers/schedulers/scheduling_consistency_models.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -121,7 +121,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -203,8 +203,7 @@ def set_timesteps(
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`:"
- f" {self.config.num_train_timesteps}."
+ f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
)
timesteps = np.array(timesteps, dtype=np.int64)
@@ -269,11 +268,7 @@ def get_scalings_for_boundary_condition(self, sigma):
Gets the scalings used in the consistency model parameterization (from Appendix C of the
[paper](https://huggingface.co/papers/2303.01469)) to enforce boundary condition.
-
-
- `epsilon` in the equations for `c_skip` and `c_out` is set to `sigma_min`.
-
-
+ > [!TIP] > `epsilon` in the equations for `c_skip` and `c_out` is set to `sigma_min`.
Args:
sigma (`torch.Tensor`):
@@ -292,7 +287,23 @@ def get_scalings_for_boundary_condition(self, sigma):
return c_skip, c_out
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -307,7 +318,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -415,6 +433,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise tensor to add to the original samples.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to add noise, determining the noise level from the schedule.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples with added noise scaled according to the timestep schedule.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
diff --git a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py
index ab56650dbac5..103cca81c6a5 100644
--- a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py
@@ -1,4 +1,4 @@
-# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -30,7 +30,7 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Implements a variant of `DPMSolverMultistepScheduler` with cosine schedule, proposed by Nichol and Dhariwal (2021).
This scheduler was used in Stable Audio Open [1].
- [1] Evans, Parker, et al. "Stable Audio Open" https://arxiv.org/abs/2407.14358
+ [1] Evans, Parker, et al. "Stable Audio Open" https://huggingface.co/papers/2407.14358
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
@@ -44,8 +44,8 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
The standard deviation of the data distribution. This is set to 1.0 in Stable Audio Open [1].
sigma_schedule (`str`, *optional*, defaults to `exponential`):
Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
- (https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was
- incorporated in this model: https://huggingface.co/stabilityai/cosxl.
+ (https://huggingface.co/papers/2206.00364). Other acceptable value is "exponential". The exponential
+ schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
solver_order (`int`, defaults to 2):
@@ -53,7 +53,7 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
prediction_type (`str`, defaults to `v_prediction`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
+ Video](https://huggingface.co/papers/2210.02303) paper).
solver_type (`str`, defaults to `midpoint`):
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
@@ -137,14 +137,14 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
def precondition_inputs(self, sample, sigma):
- c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
+ c_in = self._get_conditioning_c_in(sigma)
scaled_sample = sample * c_in
return scaled_sample
@@ -266,6 +266,19 @@ def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -304,12 +317,8 @@ def convert_model_output(
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
integral of the data prediction model.
-
-
- The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
- prediction and data prediction models.
-
-
+ > [!TIP] > The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both
+ noise > prediction and data prediction models.
Args:
model_output (`torch.Tensor`):
@@ -420,7 +429,22 @@ def multistep_dpm_solver_second_order_update(
return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index for a given timestep in the schedule.
+
+ Args:
+ timestep (`int` or `torch.Tensor`):
+ The timestep for which to find the index.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule.
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -443,6 +467,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
+
+ Args:
+ timestep (`int` or `torch.Tensor`):
+ The current timestep for which to initialize the step index.
"""
if self.begin_index is None:
@@ -541,6 +569,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise tensor to add to the original samples.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to add noise, determining the noise level from the schedule.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples with added noise scaled according to the timestep schedule.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -568,5 +611,10 @@ def add_noise(
noisy_samples = original_samples + noise * sigma
return noisy_samples
+ # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
+ def _get_conditioning_c_in(self, sigma):
+ c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
+ return c_in
+
def __len__(self):
return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index 13c9b3b4a5e9..d7fe29a72ac9 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -49,10 +49,10 @@ class DDIMSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -60,16 +60,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -92,17 +93,17 @@ def alpha_bar_fn(t):
return torch.tensor(betas, dtype=torch.float32)
-def rescale_zero_terminal_snr(betas):
+def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
- Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
-
+ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
- the betas that the scheduler is being initialized with.
+ The betas that the scheduler is being initialized with.
Returns:
- `torch.Tensor`: rescaled betas with zero terminal SNR
+ `torch.Tensor`:
+ Rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
@@ -143,9 +144,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
- beta_schedule (`str`, defaults to `"linear"`):
- The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
- `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ beta_schedule (`Literal["linear", "scaled_linear", "squaredcos_cap_v2"]`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Must be one
+ of `"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
clip_sample (`bool`, defaults to `True`):
@@ -158,10 +159,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
otherwise it uses the alpha value at step 0.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps, as required by some model families.
- prediction_type (`str`, defaults to `epsilon`, *optional*):
- Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
- `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
+ prediction_type (`Literal["epsilon", "sample", "v_prediction"]`, defaults to `"epsilon"`):
+ Prediction type of the scheduler function. Must be one of `"epsilon"` (predicts the noise of the diffusion
+ process), `"sample"` (directly predicts the noisy sample), or `"v_prediction"` (see section 2.4 of [Imagen
+ Video](https://huggingface.co/papers/2210.02303) paper).
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
@@ -169,9 +170,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
sample_max_value (`float`, defaults to 1.0):
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
- timestep_spacing (`str`, defaults to `"leading"`):
- The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ timestep_spacing (`Literal["leading", "trailing", "linspace"]`, defaults to `"leading"`):
+ The way the timesteps should be scaled. Must be one of `"leading"`, `"trailing"`, or `"linspace"`. Refer to
+ Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891) for more information.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
@@ -187,17 +189,17 @@ def __init__(
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
- beta_schedule: str = "linear",
+ beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
- prediction_type: str = "epsilon",
+ prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
- timestep_spacing: str = "leading",
+ timestep_spacing: Literal["leading", "trailing", "linspace"] = "leading",
rescale_betas_zero_snr: bool = False,
):
if trained_betas is not None:
@@ -250,7 +252,25 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None
"""
return sample
- def _get_variance(self, timestep, prev_timestep):
+ def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:
+ """
+ Computes the variance of the noise added at a given diffusion step.
+
+ For a given `timestep` and its previous step, this method calculates the variance as defined in DDIM/DDPM
+ literature:
+ var_t = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+ where alpha_prod and beta_prod are cumulative products of alphas and betas, respectively.
+
+ Args:
+ timestep (`int`):
+ The current timestep in the diffusion process.
+ prev_timestep (`int`):
+ The previous timestep in the diffusion process. If negative, uses `final_alpha_cumprod`.
+
+ Returns:
+ `torch.Tensor`:
+ The variance for the current timestep.
+ """
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
@@ -263,13 +283,23 @@ def _get_variance(self, timestep, prev_timestep):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
- https://arxiv.org/abs/2205.11487
+ https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -294,13 +324,18 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
return sample
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
+ device (`Union[str, torch.device]`, *optional*):
+ The device to use for the timesteps.
+
+ Raises:
+ ValueError: If `num_inference_steps` is larger than `self.config.num_train_timesteps`.
"""
if num_inference_steps > self.config.num_train_timesteps:
@@ -312,7 +347,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
self.num_inference_steps = num_inference_steps
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
@@ -346,7 +381,7 @@ def step(
sample: torch.Tensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
- generator=None,
+ generator: Optional[torch.Generator] = None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[DDIMSchedulerOutput, Tuple]:
@@ -357,20 +392,21 @@ def step(
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
- timestep (`float`):
+ timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
- eta (`float`):
- The weight of noise for added noise in diffusion step.
- use_clipped_model_output (`bool`, defaults to `False`):
+ eta (`float`, *optional*, defaults to 0.0):
+ The weight of noise for added noise in diffusion step. A value of 0 corresponds to DDIM (deterministic)
+ and 1 corresponds to DDPM (fully stochastic).
+ use_clipped_model_output (`bool`, *optional*, defaults to `False`):
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
`use_clipped_model_output` has no effect.
generator (`torch.Generator`, *optional*):
- A random number generator.
- variance_noise (`torch.Tensor`):
+ A random number generator for reproducible sampling.
+ variance_noise (`torch.Tensor`, *optional*):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`CycleDiffusion`].
return_dict (`bool`, *optional*, defaults to `True`):
@@ -387,7 +423,7 @@ def step(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
- # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # See formulas (12) and (16) of DDIM paper https://huggingface.co/papers/2010.02502
# Ideally, read DDIM paper in-detail understanding
# Notation ( ->
@@ -408,7 +444,7 @@ def step(
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
- # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
pred_epsilon = model_output
@@ -441,10 +477,10 @@ def step(
# the pred_epsilon is always re-derived from the clipped x_0 in Glide
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
- # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # 6. compute "direction pointing to x_t" of formula (12) from https://huggingface.co/papers/2010.02502
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
- # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # 7. compute x_t without "random noise" of formula (12) from https://huggingface.co/papers/2010.02502
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
if eta > 0:
@@ -477,6 +513,22 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+ diffusion process).
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps indicating the noise level for each sample.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
@@ -499,6 +551,21 @@ def add_noise(
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ """
+ Compute the velocity prediction from the sample and noise according to the velocity formula.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ noise (`torch.Tensor`):
+ The noise tensor.
+ timesteps (`torch.IntTensor`):
+ The timesteps for velocity computation.
+
+ Returns:
+ `torch.Tensor`:
+ The computed velocity.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
@@ -517,5 +584,5 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
- def __len__(self):
+ def __len__(self) -> int:
return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py
index 5c131752933c..f2683d1304ec 100644
--- a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py
+++ b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,7 +18,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -49,10 +49,10 @@ class DDIMSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -60,16 +60,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -94,7 +95,7 @@ def alpha_bar_fn(t):
def rescale_zero_terminal_snr(alphas_cumprod):
"""
- Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
@@ -156,7 +157,7 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
+ Video](https://huggingface.co/papers/2210.02303) paper).
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
@@ -275,7 +276,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
self.num_inference_steps = num_inference_steps
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
@@ -350,7 +351,7 @@ def step(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
- # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # See formulas (12) and (16) of DDIM paper https://huggingface.co/papers/2010.02502
# Ideally, read DDIM paper in-detail understanding
# Notation ( ->
@@ -371,7 +372,7 @@ def step(
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
- # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502
# To make style tests pass, commented out `pred_epsilon` as it is an unused variable
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
@@ -408,6 +409,22 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+ diffusion process).
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps indicating the noise level for each sample.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
@@ -430,6 +447,21 @@ def add_noise(
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ """
+ Compute the velocity prediction from the sample and noise according to the velocity formula.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ noise (`torch.Tensor`):
+ The noise tensor.
+ timesteps (`torch.IntTensor`):
+ The timesteps for velocity computation.
+
+ Returns:
+ `torch.Tensor`:
+ The computed velocity.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py
index 23c71a61452a..802d8f79779d 100644
--- a/src/diffusers/schedulers/scheduling_ddim_flax.py
+++ b/src/diffusers/schedulers/scheduling_ddim_flax.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -73,7 +73,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
- For more details, see the original paper: https://arxiv.org/abs/2010.02502
+ For more details, see the original paper: https://huggingface.co/papers/2010.02502
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -230,7 +230,7 @@ def step(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
- # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # See formulas (12) and (16) of DDIM paper https://huggingface.co/papers/2010.02502
# Ideally, read DDIM paper in-detail understanding
# Notation ( ->
@@ -254,7 +254,7 @@ def step(
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
- # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
pred_epsilon = model_output
@@ -281,10 +281,10 @@ def step(
variance = self._get_variance(state, timestep, prev_timestep)
std_dev_t = eta * variance ** (0.5)
- # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # 5. compute "direction pointing to x_t" of formula (12) from https://huggingface.co/papers/2010.02502
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
- # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # 6. compute x_t without "random noise" of formula (12) from https://huggingface.co/papers/2010.02502
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
if not return_dict:
diff --git a/src/diffusers/schedulers/scheduling_ddim_inverse.py b/src/diffusers/schedulers/scheduling_ddim_inverse.py
index d9d9ae683ad0..8ae13ad49d10 100644
--- a/src/diffusers/schedulers/scheduling_ddim_inverse.py
+++ b/src/diffusers/schedulers/scheduling_ddim_inverse.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@
# and https://github.com/hojonathanho/diffusion
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -47,10 +47,10 @@ class DDIMSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -58,16 +58,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -93,15 +94,15 @@ def alpha_bar_fn(t):
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
"""
- Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
-
+ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
- the betas that the scheduler is being initialized with.
+ The betas that the scheduler is being initialized with.
Returns:
- `torch.Tensor`: rescaled betas with zero terminal SNR
+ `torch.Tensor`:
+ Rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
@@ -159,7 +160,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
+ Video](https://huggingface.co/papers/2210.02303) paper).
timestep_spacing (`str`, defaults to `"leading"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -266,7 +267,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
self.num_inference_steps = num_inference_steps
- # "leading" and "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ # "leading" and "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
@@ -338,7 +339,7 @@ def step(
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
- # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
pred_epsilon = model_output
@@ -360,10 +361,10 @@ def step(
-self.config.clip_sample_range, self.config.clip_sample_range
)
- # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # 5. compute "direction pointing to x_t" of formula (12) from https://huggingface.co/papers/2010.02502
pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon
- # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # 6. compute x_t without "random noise" of formula (12) from https://huggingface.co/papers/2010.02502
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
if not return_dict:
diff --git a/src/diffusers/schedulers/scheduling_ddim_parallel.py b/src/diffusers/schedulers/scheduling_ddim_parallel.py
index 64412709ae90..10873a082fee 100644
--- a/src/diffusers/schedulers/scheduling_ddim_parallel.py
+++ b/src/diffusers/schedulers/scheduling_ddim_parallel.py
@@ -1,4 +1,4 @@
-# Copyright 2024 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -49,10 +49,10 @@ class DDIMParallelSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -60,16 +60,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -95,15 +96,15 @@ def alpha_bar_fn(t):
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
"""
- Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
-
+ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
- the betas that the scheduler is being initialized with.
+ The betas that the scheduler is being initialized with.
Returns:
- `torch.Tensor`: rescaled betas with zero terminal SNR
+ `torch.Tensor`:
+ Rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
@@ -139,7 +140,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
- For more details, see the original paper: https://arxiv.org/abs/2010.02502
+ For more details, see the original paper: https://huggingface.co/papers/2010.02502
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -163,23 +164,23 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
- https://imagen.research.google/video/paper.pdf)
+ https://huggingface.co/papers/2210.02303)
thresholding (`bool`, default `False`):
- whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
- Note that the thresholding method is unsuitable for latent-space diffusion models (such as
- stable-diffusion).
+ whether to use the "dynamic thresholding" method (introduced by Imagen,
+ https://huggingface.co/papers/2205.11487). Note that the thresholding method is unsuitable for latent-space
+ diffusion models (such as stable-diffusion).
dynamic_thresholding_ratio (`float`, default `0.995`):
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
- (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
+ (https://huggingface.co/papers/2205.11487). Valid only when `thresholding=True`.
sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
timestep_spacing (`str`, default `"leading"`):
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
- Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+ Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
rescale_betas_zero_snr (`bool`, default `False`):
- whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf).
- This can enable the model to generate very bright and dark samples instead of limiting it to samples with
- medium brightness. Loosely related to
+ whether to rescale the betas to have zero terminal SNR (proposed by
+ https://huggingface.co/papers/2305.08891). This can enable the model to generate very bright and dark
+ samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
"""
@@ -194,17 +195,17 @@ def __init__(
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
- beta_schedule: str = "linear",
+ beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
- prediction_type: str = "epsilon",
+ prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
- timestep_spacing: str = "leading",
+ timestep_spacing: Literal["leading", "trailing", "linspace"] = "leading",
rescale_betas_zero_snr: bool = False,
):
if trained_betas is not None:
@@ -285,13 +286,23 @@ def _batch_get_variance(self, t, prev_t):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
- https://arxiv.org/abs/2205.11487
+ https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -324,6 +335,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
+ device (`Union[str, torch.device]`, *optional*):
+ The device to use for the timesteps.
+
+ Raises:
+ ValueError: If `num_inference_steps` is larger than `self.config.num_train_timesteps`.
"""
if num_inference_steps > self.config.num_train_timesteps:
@@ -335,7 +351,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
self.num_inference_steps = num_inference_steps
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
@@ -390,7 +406,7 @@ def step(
generator: random number generator.
variance_noise (`torch.Tensor`): instead of generating noise for the variance using `generator`, we
can directly provide the noise for the variance itself. This is useful for methods such as
- CycleDiffusion. (https://arxiv.org/abs/2210.05559)
+ CycleDiffusion. (https://huggingface.co/papers/2210.05559)
return_dict (`bool`): option for returning tuple rather than DDIMParallelSchedulerOutput class
Returns:
@@ -404,7 +420,7 @@ def step(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
- # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # See formulas (12) and (16) of DDIM paper https://huggingface.co/papers/2010.02502
# Ideally, read DDIM paper in-detail understanding
# Notation ( ->
@@ -425,7 +441,7 @@ def step(
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
- # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
pred_epsilon = model_output
@@ -458,10 +474,10 @@ def step(
# the pred_epsilon is always re-derived from the clipped x_0 in Glide
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
- # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # 6. compute "direction pointing to x_t" of formula (12) from https://huggingface.co/papers/2010.02502
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
- # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # 7. compute x_t without "random noise" of formula (12) from https://huggingface.co/papers/2010.02502
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
if eta > 0:
@@ -526,7 +542,7 @@ def batch_step_no_noise(
assert eta == 0.0
- # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # See formulas (12) and (16) of DDIM paper https://huggingface.co/papers/2010.02502
# Ideally, read DDIM paper in-detail understanding
# Notation ( ->
@@ -554,7 +570,7 @@ def batch_step_no_noise(
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
- # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
pred_epsilon = model_output
@@ -587,10 +603,10 @@ def batch_step_no_noise(
# the pred_epsilon is always re-derived from the clipped x_0 in Glide
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
- # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # 6. compute "direction pointing to x_t" of formula (12) from https://huggingface.co/papers/2010.02502
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
- # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # 7. compute x_t without "random noise" of formula (12) from https://huggingface.co/papers/2010.02502
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
return prev_sample
@@ -602,6 +618,22 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+ diffusion process).
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps indicating the noise level for each sample.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
@@ -624,6 +656,21 @@ def add_noise(
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ """
+ Compute the velocity prediction from the sample and noise according to the velocity formula.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ noise (`torch.Tensor`):
+ The noise tensor.
+ timesteps (`torch.IntTensor`):
+ The timesteps for velocity computation.
+
+ Returns:
+ `torch.Tensor`:
+ The computed velocity.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index 624d5a5cd4f3..ded88b8e1e0a 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -1,4 +1,4 @@
-# Copyright 2024 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 UC Berkeley Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -46,10 +46,10 @@ class DDPMSchedulerOutput(BaseOutput):
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -57,16 +57,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -90,17 +91,17 @@ def alpha_bar_fn(t):
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
-def rescale_zero_terminal_snr(betas):
+def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
- Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
-
+ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
- the betas that the scheduler is being initialized with.
+ The betas that the scheduler is being initialized with.
Returns:
- `torch.Tensor`: rescaled betas with zero terminal SNR
+ `torch.Tensor`:
+ Rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
@@ -134,39 +135,37 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
methods the library implements for all schedulers such as loading and saving.
Args:
- num_train_timesteps (`int`, defaults to 1000):
+ num_train_timesteps (`int`, defaults to `1000`):
The number of diffusion steps to train the model.
- beta_start (`float`, defaults to 0.0001):
+ beta_start (`float`, defaults to `0.0001`):
The starting `beta` value of inference.
- beta_end (`float`, defaults to 0.02):
+ beta_end (`float`, defaults to `0.02`):
The final `beta` value.
- beta_schedule (`str`, defaults to `"linear"`):
- The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
- `linear`, `scaled_linear`, `squaredcos_cap_v2`, or `sigmoid`.
+ beta_schedule (`"linear"`, `"scaled_linear"`, `"squaredcos_cap_v2"`, or `"sigmoid"`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
trained_betas (`np.ndarray`, *optional*):
An array of betas to pass directly to the constructor without using `beta_start` and `beta_end`.
- variance_type (`str`, defaults to `"fixed_small"`):
- Clip the variance when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`,
- `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
+ variance_type (`"fixed_small"`, `"fixed_small_log"`, `"fixed_large"`, `"fixed_large_log"`, `"learned"`, or `"learned_range"`, defaults to `"fixed_small"`):
+ Clip the variance when adding noise to the denoised sample.
clip_sample (`bool`, defaults to `True`):
Clip the predicted sample for numerical stability.
- clip_sample_range (`float`, defaults to 1.0):
+ clip_sample_range (`float`, defaults to `1.0`):
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
- prediction_type (`str`, defaults to `epsilon`, *optional*):
+ prediction_type (`"epsilon"`, `"sample"`, or `"v_prediction"`, defaults to `"epsilon"`):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
+ Video](https://huggingface.co/papers/2210.02303) paper).
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
- dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ dynamic_thresholding_ratio (`float`, defaults to `0.995`):
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
- sample_max_value (`float`, defaults to 1.0):
+ sample_max_value (`float`, defaults to `1.0`):
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
- timestep_spacing (`str`, defaults to `"leading"`):
+ timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"leading"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
- steps_offset (`int`, defaults to 0):
+ steps_offset (`int`, defaults to `0`):
An offset added to the inference steps, as required by some model families.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
@@ -183,16 +182,18 @@ def __init__(
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
- beta_schedule: str = "linear",
+ beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2", "sigmoid"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
- variance_type: str = "fixed_small",
+ variance_type: Literal[
+ "fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range"
+ ] = "fixed_small",
clip_sample: bool = True,
- prediction_type: str = "epsilon",
+ prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
- timestep_spacing: str = "leading",
+ timestep_spacing: Literal["linspace", "leading", "trailing"] = "leading",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
):
@@ -279,8 +280,7 @@ def set_timesteps(
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`:"
- f" {self.config.num_train_timesteps}."
+ f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
)
timesteps = np.array(timesteps, dtype=np.int64)
@@ -296,7 +296,7 @@ def set_timesteps(
self.num_inference_steps = num_inference_steps
self.custom_timesteps = False
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
@@ -323,14 +323,38 @@ def set_timesteps(
self.timesteps = torch.from_numpy(timesteps).to(device)
- def _get_variance(self, t, predicted_variance=None, variance_type=None):
+ def _get_variance(
+ self,
+ t: int,
+ predicted_variance: Optional[torch.Tensor] = None,
+ variance_type: Optional[
+ Literal["fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range"]
+ ] = None,
+ ) -> torch.Tensor:
+ """
+ Compute the variance for a given timestep according to the specified variance type.
+
+ Args:
+ t (`int`):
+ The current timestep.
+ predicted_variance (`torch.Tensor`, *optional*):
+ The predicted variance from the model. Used only when `variance_type` is `"learned"` or
+ `"learned_range"`.
+ variance_type (`"fixed_small"`, `"fixed_small_log"`, `"fixed_large"`, `"fixed_large_log"`, `"learned"`, or `"learned_range"`, *optional*):
+ The type of variance to compute. If `None`, uses the variance type specified in the scheduler
+ configuration.
+
+ Returns:
+ `torch.Tensor`:
+ The computed variance.
+ """
prev_t = self.previous_timestep(t)
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
- # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
+ # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://huggingface.co/papers/2006.11239)
# and sample from it to get previous sample
# x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
@@ -344,7 +368,7 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None):
# hacks - were probably added for training stability
if variance_type == "fixed_small":
variance = variance
- # for rl-diffuser https://arxiv.org/abs/2205.09991
+ # for rl-diffuser https://huggingface.co/papers/2205.09991
elif variance_type == "fixed_small_log":
variance = torch.log(variance)
variance = torch.exp(0.5 * variance)
@@ -365,13 +389,23 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None):
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
- https://arxiv.org/abs/2205.11487
+ https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -401,7 +435,7 @@ def step(
model_output: torch.Tensor,
timestep: int,
sample: torch.Tensor,
- generator=None,
+ generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[DDPMSchedulerOutput, Tuple]:
"""
@@ -411,20 +445,19 @@ def step(
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
- timestep (`float`):
+ timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
- return_dict (`bool`, *optional*, defaults to `True`):
+ return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
-
"""
t = timestep
@@ -444,7 +477,7 @@ def step(
current_beta_t = 1 - current_alpha_t
# 2. compute predicted original sample from predicted noise also called
- # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ # "predicted x_0" of formula (15) from https://huggingface.co/papers/2006.11239
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.config.prediction_type == "sample":
@@ -466,12 +499,12 @@ def step(
)
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
- # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ # See formula (7) from https://huggingface.co/papers/2006.11239
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
# 5. Compute predicted previous sample µ_t
- # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ # See formula (7) from https://huggingface.co/papers/2006.11239
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
# 6. Add noise
@@ -505,6 +538,22 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+ diffusion process).
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps indicating the noise level for each sample.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
@@ -526,6 +575,21 @@ def add_noise(
return noisy_samples
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ """
+ Compute the velocity prediction from the sample and noise according to the velocity formula.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ noise (`torch.Tensor`):
+ The noise tensor.
+ timesteps (`torch.IntTensor`):
+ The timesteps for velocity computation.
+
+ Returns:
+ `torch.Tensor`:
+ The computed velocity.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
@@ -544,10 +608,21 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
- def __len__(self):
+ def __len__(self) -> int:
return self.config.num_train_timesteps
- def previous_timestep(self, timestep):
+ def previous_timestep(self, timestep: int) -> int:
+ """
+ Compute the previous timestep in the diffusion chain.
+
+ Args:
+ timestep (`int`):
+ The current timestep.
+
+ Returns:
+ `int`:
+ The previous timestep.
+ """
if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py
index d06a171159ee..a3264f54f572 100644
--- a/src/diffusers/schedulers/scheduling_ddpm_flax.py
+++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py
@@ -1,4 +1,4 @@
-# Copyright 2024 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 UC Berkeley Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -61,7 +61,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
- For more details, see the original paper: https://arxiv.org/abs/2006.11239
+ For more details, see the original paper: https://huggingface.co/papers/2006.11239
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -163,7 +163,7 @@ def _get_variance(self, state: DDPMSchedulerState, t, predicted_variance=None, v
alpha_prod_t = state.common.alphas_cumprod[t]
alpha_prod_t_prev = jnp.where(t > 0, state.common.alphas_cumprod[t - 1], jnp.array(1.0, dtype=self.dtype))
- # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
+ # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://huggingface.co/papers/2006.11239)
# and sample from it to get previous sample
# x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * state.common.betas[t]
@@ -174,7 +174,7 @@ def _get_variance(self, state: DDPMSchedulerState, t, predicted_variance=None, v
# hacks - were probably added for training stability
if variance_type == "fixed_small":
variance = jnp.clip(variance, a_min=1e-20)
- # for rl-diffuser https://arxiv.org/abs/2205.09991
+ # for rl-diffuser https://huggingface.co/papers/2205.09991
elif variance_type == "fixed_small_log":
variance = jnp.log(jnp.clip(variance, a_min=1e-20))
elif variance_type == "fixed_large":
@@ -240,7 +240,7 @@ def step(
beta_prod_t_prev = 1 - alpha_prod_t_prev
# 2. compute predicted original sample from predicted noise also called
- # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ # "predicted x_0" of formula (15) from https://huggingface.co/papers/2006.11239
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.config.prediction_type == "sample":
@@ -258,12 +258,12 @@ def step(
pred_original_sample = jnp.clip(pred_original_sample, -1, 1)
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
- # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ # See formula (7) from https://huggingface.co/papers/2006.11239
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * state.common.betas[t]) / beta_prod_t
current_sample_coeff = state.common.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
# 5. Compute predicted previous sample µ_t
- # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ # See formula (7) from https://huggingface.co/papers/2006.11239
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
# 6. Add noise
diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py
index 20ad7a4c927d..941fc16be080 100644
--- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py
+++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py
@@ -1,4 +1,4 @@
-# Copyright 2024 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -48,10 +48,10 @@ class DDPMParallelSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -59,16 +59,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -94,15 +95,15 @@ def alpha_bar_fn(t):
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
"""
- Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
-
+ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
- the betas that the scheduler is being initialized with.
+ The betas that the scheduler is being initialized with.
Returns:
- `torch.Tensor`: rescaled betas with zero terminal SNR
+ `torch.Tensor`:
+ Rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
@@ -138,7 +139,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
- For more details, see the original paper: https://arxiv.org/abs/2006.11239
+ For more details, see the original paper: https://huggingface.co/papers/2006.11239
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -159,19 +160,19 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
- https://imagen.research.google/video/paper.pdf)
+ https://huggingface.co/papers/2210.02303)
thresholding (`bool`, default `False`):
- whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
- Note that the thresholding method is unsuitable for latent-space diffusion models (such as
- stable-diffusion).
+ whether to use the "dynamic thresholding" method (introduced by Imagen,
+ https://huggingface.co/papers/2205.11487). Note that the thresholding method is unsuitable for latent-space
+ diffusion models (such as stable-diffusion).
dynamic_thresholding_ratio (`float`, default `0.995`):
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
- (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
+ (https://huggingface.co/papers/2205.11487). Valid only when `thresholding=True`.
sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
timestep_spacing (`str`, default `"leading"`):
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
- Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+ Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, default `0`):
An offset added to the inference steps, as required by some model families.
rescale_betas_zero_snr (`bool`, defaults to `False`):
@@ -191,16 +192,18 @@ def __init__(
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
- beta_schedule: str = "linear",
+ beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2", "sigmoid"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
- variance_type: str = "fixed_small",
+ variance_type: Literal[
+ "fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range"
+ ] = "fixed_small",
clip_sample: bool = True,
- prediction_type: str = "epsilon",
+ prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
- timestep_spacing: str = "leading",
+ timestep_spacing: Literal["linspace", "leading", "trailing"] = "leading",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
):
@@ -289,8 +292,7 @@ def set_timesteps(
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`:"
- f" {self.config.num_train_timesteps}."
+ f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
)
timesteps = np.array(timesteps, dtype=np.int64)
@@ -306,7 +308,7 @@ def set_timesteps(
self.num_inference_steps = num_inference_steps
self.custom_timesteps = False
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
@@ -334,14 +336,38 @@ def set_timesteps(
self.timesteps = torch.from_numpy(timesteps).to(device)
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._get_variance
- def _get_variance(self, t, predicted_variance=None, variance_type=None):
+ def _get_variance(
+ self,
+ t: int,
+ predicted_variance: Optional[torch.Tensor] = None,
+ variance_type: Optional[
+ Literal["fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range"]
+ ] = None,
+ ) -> torch.Tensor:
+ """
+ Compute the variance for a given timestep according to the specified variance type.
+
+ Args:
+ t (`int`):
+ The current timestep.
+ predicted_variance (`torch.Tensor`, *optional*):
+ The predicted variance from the model. Used only when `variance_type` is `"learned"` or
+ `"learned_range"`.
+ variance_type (`"fixed_small"`, `"fixed_small_log"`, `"fixed_large"`, `"fixed_large_log"`, `"learned"`, or `"learned_range"`, *optional*):
+ The type of variance to compute. If `None`, uses the variance type specified in the scheduler
+ configuration.
+
+ Returns:
+ `torch.Tensor`:
+ The computed variance.
+ """
prev_t = self.previous_timestep(t)
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
- # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
+ # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://huggingface.co/papers/2006.11239)
# and sample from it to get previous sample
# x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
@@ -355,7 +381,7 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None):
# hacks - were probably added for training stability
if variance_type == "fixed_small":
variance = variance
- # for rl-diffuser https://arxiv.org/abs/2205.09991
+ # for rl-diffuser https://huggingface.co/papers/2205.09991
elif variance_type == "fixed_small_log":
variance = torch.log(variance)
variance = torch.exp(0.5 * variance)
@@ -377,13 +403,23 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
- https://arxiv.org/abs/2205.11487
+ https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -452,7 +488,7 @@ def step(
current_beta_t = 1 - current_alpha_t
# 2. compute predicted original sample from predicted noise also called
- # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ # "predicted x_0" of formula (15) from https://huggingface.co/papers/2006.11239
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.config.prediction_type == "sample":
@@ -474,12 +510,12 @@ def step(
)
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
- # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ # See formula (7) from https://huggingface.co/papers/2006.11239
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
# 5. Compute predicted previous sample µ_t
- # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ # See formula (7) from https://huggingface.co/papers/2006.11239
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
# 6. Add noise
@@ -555,7 +591,7 @@ def batch_step_no_noise(
current_beta_t = 1 - current_alpha_t
# 2. compute predicted original sample from predicted noise also called
- # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ # "predicted x_0" of formula (15) from https://huggingface.co/papers/2006.11239
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.config.prediction_type == "sample":
@@ -577,12 +613,12 @@ def batch_step_no_noise(
)
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
- # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ # See formula (7) from https://huggingface.co/papers/2006.11239
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
# 5. Compute predicted previous sample µ_t
- # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ # See formula (7) from https://huggingface.co/papers/2006.11239
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
return pred_prev_sample
@@ -594,6 +630,22 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+ diffusion process).
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps indicating the noise level for each sample.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
@@ -616,6 +668,21 @@ def add_noise(
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ """
+ Compute the velocity prediction from the sample and noise according to the velocity formula.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ noise (`torch.Tensor`):
+ The noise tensor.
+ timesteps (`torch.IntTensor`):
+ The timesteps for velocity computation.
+
+ Returns:
+ `torch.Tensor`:
+ The computed velocity.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
@@ -639,6 +706,17 @@ def __len__(self):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
def previous_timestep(self, timestep):
+ """
+ Compute the previous timestep in the diffusion chain.
+
+ Args:
+ timestep (`int`):
+ The current timestep.
+
+ Returns:
+ `int`:
+ The previous timestep.
+ """
if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
diff --git a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py
index 71b5669b0528..71f08277ebd7 100644
--- a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py
+++ b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py
@@ -1,5 +1,5 @@
# Copyright (c) 2022 Pablo Pernías MIT License
-# Copyright 2024 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 UC Berkeley Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -95,7 +95,7 @@ class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin):
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
- For more details, see the original paper: https://arxiv.org/abs/2006.11239
+ For more details, see the original paper: https://huggingface.co/papers/2006.11239
Args:
scaler (`float`): ....
diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py
index 6a653f183bba..b7d64fc00bae 100644
--- a/src/diffusers/schedulers/scheduling_deis_multistep.py
+++ b/src/diffusers/schedulers/scheduling_deis_multistep.py
@@ -1,4 +1,4 @@
-# Copyright 2024 FLAIR Lab and The HuggingFace Team. All rights reserved.
+# Copyright 2025 FLAIR Lab and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# DISCLAIMER: check https://arxiv.org/abs/2204.13902 and https://github.com/qsh-zh/deis for more info
+# DISCLAIMER: check https://huggingface.co/papers/2204.13902 and https://github.com/qsh-zh/deis for more info
# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
import math
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -32,10 +32,10 @@
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -43,16 +43,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -83,33 +84,35 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
methods the library implements for all schedulers such as loading and saving.
Args:
- num_train_timesteps (`int`, defaults to 1000):
+ num_train_timesteps (`int`, defaults to `1000`):
The number of diffusion steps to train the model.
- beta_start (`float`, defaults to 0.0001):
+ beta_start (`float`, defaults to `0.0001`):
The starting `beta` value of inference.
- beta_end (`float`, defaults to 0.02):
+ beta_end (`float`, defaults to `0.02`):
The final `beta` value.
- beta_schedule (`str`, defaults to `"linear"`):
+ beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
- trained_betas (`np.ndarray`, *optional*):
+ trained_betas (`np.ndarray` or `List[float]`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
- solver_order (`int`, defaults to 2):
+ solver_order (`int`, defaults to `2`):
The DEIS order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling.
- prediction_type (`str`, defaults to `epsilon`):
+ prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
- `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
+ `sample` (directly predicts the noisy sample`), `v_prediction` (see section 2.4 of [Imagen
+ Video](https://huggingface.co/papers/2210.02303) paper), or `flow_prediction`.
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
- dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ dynamic_thresholding_ratio (`float`, defaults to `0.995`):
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
- sample_max_value (`float`, defaults to 1.0):
+ sample_max_value (`float`, defaults to `1.0`):
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
- algorithm_type (`str`, defaults to `deis`):
+ algorithm_type (`"deis"`, defaults to `"deis"`):
The algorithm type for the solver.
+ solver_type (`"logrho"`, defaults to `"logrho"`):
+ Solver type for DEIS.
lower_order_final (`bool`, defaults to `True`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
@@ -120,11 +123,19 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
- timestep_spacing (`str`, defaults to `"linspace"`):
+ use_flow_sigmas (`bool`, *optional*, defaults to `False`):
+ Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
+ flow_shift (`float`, *optional*, defaults to `1.0`):
+ The flow shift parameter for flow-based models.
+ timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
- steps_offset (`int`, defaults to 0):
+ steps_offset (`int`, defaults to `0`):
An offset added to the inference steps, as required by some model families.
+ use_dynamic_shifting (`bool`, defaults to `False`):
+ Whether to use dynamic shifting for the noise schedule.
+ time_shift_type (`"exponential"`, defaults to `"exponential"`):
+ The type of time shifting to apply.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -136,27 +147,38 @@ def __init__(
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
- beta_schedule: str = "linear",
- trained_betas: Optional[np.ndarray] = None,
+ beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
solver_order: int = 2,
- prediction_type: str = "epsilon",
+ prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
- algorithm_type: str = "deis",
- solver_type: str = "logrho",
+ algorithm_type: Literal["deis"] = "deis",
+ solver_type: Literal["logrho"] = "logrho",
lower_order_final: bool = True,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
use_flow_sigmas: Optional[bool] = False,
flow_shift: Optional[float] = 1.0,
- timestep_spacing: str = "linspace",
+ timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
steps_offset: int = 0,
- ):
+ use_dynamic_shifting: bool = False,
+ time_shift_type: Literal["exponential"] = "exponential",
+ ) -> None:
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
- if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+ if (
+ sum(
+ [
+ self.config.use_beta_sigmas,
+ self.config.use_exponential_sigmas,
+ self.config.use_karras_sigmas,
+ ]
+ )
+ > 1
+ ):
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
@@ -166,7 +188,15 @@ def __init__(
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
- self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ self.betas = (
+ torch.linspace(
+ beta_start**0.5,
+ beta_end**0.5,
+ num_train_timesteps,
+ dtype=torch.float32,
+ )
+ ** 2
+ )
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -208,31 +238,36 @@ def __init__(
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
- def step_index(self):
+ def step_index(self) -> Optional[int]:
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
- def begin_index(self):
+ def begin_index(self) -> Optional[int]:
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
- def set_begin_index(self, begin_index: int = 0):
+ def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ def set_timesteps(
+ self,
+ num_inference_steps: int,
+ device: Union[str, torch.device] = None,
+ mu: Optional[float] = None,
+ ) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -241,8 +276,14 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ mu (`float`, *optional*):
+ The mu parameter for dynamic shifting. Only used when `use_dynamic_shifting=True` and
+ `time_shift_type="exponential"`.
"""
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ if mu is not None:
+ assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+ self.config.flow_shift = np.exp(mu)
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
@@ -313,13 +354,23 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
- https://arxiv.org/abs/2205.11487
+ https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -345,7 +396,20 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
return sample
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
- def _sigma_to_t(self, sigma, log_sigmas):
+ def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -369,7 +433,18 @@ def _sigma_to_t(self, sigma, log_sigmas):
return t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
- def _sigma_to_alpha_sigma_t(self, sigma):
+ def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Convert sigma values to alpha_t and sigma_t values.
+
+ Args:
+ sigma (`torch.Tensor`):
+ The sigma value(s) to convert.
+
+ Returns:
+ `Tuple[torch.Tensor, torch.Tensor]`:
+ A tuple containing (alpha_t, sigma_t) values.
+ """
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
@@ -380,8 +455,21 @@ def _sigma_to_alpha_sigma_t(self, sigma):
return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
- def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -407,7 +495,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -431,7 +531,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -486,7 +603,7 @@ def convert_model_output(
if len(args) > 1:
sample = args[1]
else:
- raise ValueError("missing `sample` as a required keyward argument")
+ raise ValueError("missing `sample` as a required keyword argument")
if timestep is not None:
deprecate(
"timesteps",
@@ -549,7 +666,7 @@ def deis_first_order_update(
if len(args) > 2:
sample = args[2]
else:
- raise ValueError(" missing `sample` as a required keyward argument")
+ raise ValueError("missing `sample` as a required keyword argument")
if timestep is not None:
deprecate(
"timesteps",
@@ -564,7 +681,10 @@ def deis_first_order_update(
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
- sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
+ sigma_t, sigma_s = (
+ self.sigmas[self.step_index + 1],
+ self.sigmas[self.step_index],
+ )
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
@@ -603,7 +723,7 @@ def multistep_deis_second_order_update(
if len(args) > 2:
sample = args[2]
else:
- raise ValueError(" missing `sample` as a required keyward argument")
+ raise ValueError("missing `sample` as a required keyword argument")
if timestep_list is not None:
deprecate(
"timestep_list",
@@ -630,7 +750,11 @@ def multistep_deis_second_order_update(
m0, m1 = model_output_list[-1], model_output_list[-2]
- rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1
+ rho_t, rho_s0, rho_s1 = (
+ sigma_t / alpha_t,
+ sigma_s0 / alpha_s0,
+ sigma_s1 / alpha_s1,
+ )
if self.config.algorithm_type == "deis":
@@ -673,7 +797,7 @@ def multistep_deis_third_order_update(
if len(args) > 2:
sample = args[2]
else:
- raise ValueError(" missing`sample` as a required keyward argument")
+ raise ValueError("missing `sample` as a required keyword argument")
if timestep_list is not None:
deprecate(
"timestep_list",
@@ -735,7 +859,22 @@ def ind_fn(t, b, c, d):
raise NotImplementedError("only support log-rho multistep deis now")
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index for a given timestep in the schedule.
+
+ Args:
+ timestep (`int` or `torch.Tensor`):
+ The timestep for which to find the index.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule.
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -755,9 +894,13 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
"""
Initialize the step_index counter for the scheduler.
+
+ Args:
+ timestep (`int` or `torch.Tensor`):
+ The current timestep for which to initialize the step index.
"""
if self.begin_index is None:
@@ -781,18 +924,17 @@ def step(
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
- timestep (`int`):
+ timestep (`int` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
- return_dict (`bool`):
+ return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
-
"""
if self.num_inference_steps is None:
raise ValueError(
@@ -854,6 +996,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples without noise.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps at which to add noise to the samples.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -882,5 +1039,5 @@ def add_noise(
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples
- def __len__(self):
+ def __len__(self) -> int:
return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py
index 1a2c7be7115b..0a9082208cf4 100644
--- a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py
+++ b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,7 +18,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -50,10 +50,10 @@ class DDIMSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -61,16 +61,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -95,7 +96,7 @@ def alpha_bar_fn(t):
def rescale_zero_terminal_snr(alphas_cumprod):
"""
- Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
@@ -157,7 +158,7 @@ class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
+ Video](https://huggingface.co/papers/2210.02303) paper).
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
@@ -276,7 +277,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
self.num_inference_steps = num_inference_steps
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
@@ -377,7 +378,7 @@ def step(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
- # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # See formulas (12) and (16) of DDIM paper https://huggingface.co/papers/2010.02502
# Ideally, read DDIM paper in-detail understanding
# Notation ( ->
@@ -399,7 +400,7 @@ def step(
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
- # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502
# To make style tests pass, commented out `pred_epsilon` as it is an unused variable
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
@@ -445,6 +446,22 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+ diffusion process).
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps indicating the noise level for each sample.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
@@ -467,6 +484,21 @@ def add_noise(
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ """
+ Compute the velocity prediction from the sample and noise according to the velocity formula.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ noise (`torch.Tensor`):
+ The noise tensor.
+ timesteps (`torch.IntTensor`):
+ The timesteps for velocity computation.
+
+ Returns:
+ `torch.Tensor`:
+ The computed velocity.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index ed60dd4eaee1..e7ba0ba1f30e 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -1,4 +1,4 @@
-# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
import math
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -32,10 +32,10 @@
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -43,16 +43,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -78,15 +79,15 @@ def alpha_bar_fn(t):
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
"""
- Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
-
+ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
- the betas that the scheduler is being initialized with.
+ The betas that the scheduler is being initialized with.
Returns:
- `torch.Tensor`: rescaled betas with zero terminal SNR
+ `torch.Tensor`:
+ Rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
@@ -126,18 +127,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
- beta_schedule (`str`, defaults to `"linear"`):
- The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
- `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
solver_order (`int`, defaults to 2):
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling.
- prediction_type (`str`, defaults to `epsilon`, *optional*):
- Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
- `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`.
+ prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`):
+ Prediction type of the scheduler function. `epsilon` predicts the noise of the diffusion process, `sample`
+ directly predicts the noisy sample, `v_prediction` predicts the velocity (see section 2.4 of [Imagen
+ Video](https://huggingface.co/papers/2210.02303) paper), and `flow_prediction` predicts the flow.
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
@@ -146,15 +146,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
sample_max_value (`float`, defaults to 1.0):
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
`algorithm_type="dpmsolver++"`.
- algorithm_type (`str`, defaults to `dpmsolver++`):
- Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
- `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
- paper, and the `dpmsolver++` type implements the algorithms in the
- [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
- `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
- solver_type (`str`, defaults to `midpoint`):
- Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
- sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
+ algorithm_type (`"dpmsolver"`, `"dpmsolver++"`, `"sde-dpmsolver"`, or `"sde-dpmsolver++"`, defaults to `"dpmsolver++"`):
+ Algorithm type for the solver. The `dpmsolver` type implements the algorithms in the
+ [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type implements the
+ algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use
+ `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
+ solver_type (`"midpoint"` or `"heun"`, defaults to `"midpoint"`):
+ Solver type for the second-order solver. The solver type slightly affects the sample quality, especially
+ for a small number of steps. It is recommended to use `midpoint` solvers.
lower_order_final (`bool`, defaults to `True`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
@@ -178,16 +177,16 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
flow_shift (`float`, *optional*, defaults to 1.0):
The shift value for the timestep schedule for flow matching.
- final_sigmas_type (`str`, defaults to `"zero"`):
+ final_sigmas_type (`"zero"` or `"sigma_min"`, *optional*, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
- sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+ sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
lambda_min_clipped (`float`, defaults to `-inf`):
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
cosine (`squaredcos_cap_v2`) noise schedule.
- variance_type (`str`, *optional*):
- Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
- contains the predicted Gaussian variance.
- timestep_spacing (`str`, defaults to `"linspace"`):
+ variance_type (`"learned"` or `"learned_range"`, *optional*):
+ Set to `"learned"` or `"learned_range"` for diffusion models that predict variance. If set, the model's
+ output contains the predicted Gaussian variance.
+ timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0):
@@ -196,6 +195,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+ use_dynamic_shifting (`bool`, defaults to `False`):
+ Whether to use dynamic shifting for the timestep schedule.
+ time_shift_type (`"exponential"`, defaults to `"exponential"`):
+ The type of time shift to apply when using dynamic shifting.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -207,15 +210,15 @@ def __init__(
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
- beta_schedule: str = "linear",
+ beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
solver_order: int = 2,
- prediction_type: str = "epsilon",
+ prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
- algorithm_type: str = "dpmsolver++",
- solver_type: str = "midpoint",
+ algorithm_type: Literal["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"] = "dpmsolver++",
+ solver_type: Literal["midpoint", "heun"] = "midpoint",
lower_order_final: bool = True,
euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
@@ -224,12 +227,14 @@ def __init__(
use_lu_lambdas: Optional[bool] = False,
use_flow_sigmas: Optional[bool] = False,
flow_shift: Optional[float] = 1.0,
- final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
+ final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero",
lambda_min_clipped: float = -float("inf"),
- variance_type: Optional[str] = None,
- timestep_spacing: str = "linspace",
+ variance_type: Optional[Literal["learned", "learned_range"]] = None,
+ timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
+ use_dynamic_shifting: bool = False,
+ time_shift_type: Literal["exponential"] = "exponential",
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -321,30 +326,37 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def set_timesteps(
self,
- num_inference_steps: int = None,
- device: Union[str, torch.device] = None,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ mu: Optional[float] = None,
timesteps: Optional[List[int]] = None,
- ):
+ ) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
- num_inference_steps (`int`):
+ num_inference_steps (`int`, *optional*):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ mu (`float`, *optional*):
+ The mu parameter for dynamic shifting. If provided, requires `use_dynamic_shifting=True` and
+ `time_shift_type="exponential"`.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
must be `None`, and `timestep_spacing` attribute will be ignored.
"""
+ if mu is not None:
+ assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+ self.config.flow_shift = np.exp(mu)
if num_inference_steps is None and timesteps is None:
raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
if num_inference_steps is not None and timesteps is not None:
@@ -366,7 +378,7 @@ def set_timesteps(
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, last_timestep - 1, num_inference_steps + 1)
@@ -454,13 +466,23 @@ def set_timesteps(
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
- https://arxiv.org/abs/2205.11487
+ https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -486,7 +508,20 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
return sample
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
- def _sigma_to_t(self, sigma, log_sigmas):
+ def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -509,7 +544,18 @@ def _sigma_to_t(self, sigma, log_sigmas):
t = t.reshape(sigma.shape)
return t
- def _sigma_to_alpha_sigma_t(self, sigma):
+ def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Convert sigma values to alpha_t and sigma_t values.
+
+ Args:
+ sigma (`torch.Tensor`):
+ The sigma value(s) to convert.
+
+ Returns:
+ `Tuple[torch.Tensor, torch.Tensor]`:
+ A tuple containing (alpha_t, sigma_t) values.
+ """
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
@@ -521,7 +567,20 @@ def _sigma_to_alpha_sigma_t(self, sigma):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -545,8 +604,21 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
- def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Lu et al. (2022)."""
+ def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+ """
+ Construct the noise schedule as proposed in [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model
+ Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) by Lu et al. (2022).
+
+ Args:
+ in_lambdas (`torch.Tensor`):
+ The input lambda values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted lambda values following the Lu noise schedule.
+ """
lambda_min: float = in_lambdas[-1].item()
lambda_max: float = in_lambdas[0].item()
@@ -560,7 +632,19 @@ def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -584,7 +668,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -624,12 +725,8 @@ def convert_model_output(
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
integral of the data prediction model.
-
-
- The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
- prediction and data prediction models.
-
-
+ > [!TIP] > The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both
+ noise > prediction and data prediction models.
Args:
model_output (`torch.Tensor`):
@@ -646,7 +743,7 @@ def convert_model_output(
if len(args) > 1:
sample = args[1]
else:
- raise ValueError("missing `sample` as a required keyward argument")
+ raise ValueError("missing `sample` as a required keyword argument")
if timestep is not None:
deprecate(
"timesteps",
@@ -741,7 +838,7 @@ def dpm_solver_first_order_update(
if len(args) > 2:
sample = args[2]
else:
- raise ValueError(" missing `sample` as a required keyward argument")
+ raise ValueError("missing `sample` as a required keyword argument")
if timestep is not None:
deprecate(
"timesteps",
@@ -810,7 +907,7 @@ def multistep_dpm_solver_second_order_update(
if len(args) > 2:
sample = args[2]
else:
- raise ValueError(" missing `sample` as a required keyward argument")
+ raise ValueError("missing `sample` as a required keyword argument")
if timestep_list is not None:
deprecate(
"timestep_list",
@@ -845,7 +942,7 @@ def multistep_dpm_solver_second_order_update(
r0 = h_0 / h
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
if self.config.algorithm_type == "dpmsolver++":
- # See https://arxiv.org/abs/2211.01095 for detailed derivations
+ # See https://huggingface.co/papers/2211.01095 for detailed derivations
if self.config.solver_type == "midpoint":
x_t = (
(sigma_t / sigma_s0) * sample
@@ -859,7 +956,7 @@ def multistep_dpm_solver_second_order_update(
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
)
elif self.config.algorithm_type == "dpmsolver":
- # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ # See https://huggingface.co/papers/2206.00927 for detailed derivations
if self.config.solver_type == "midpoint":
x_t = (
(alpha_t / alpha_s0) * sample
@@ -934,7 +1031,7 @@ def multistep_dpm_solver_third_order_update(
if len(args) > 2:
sample = args[2]
else:
- raise ValueError(" missing`sample` as a required keyward argument")
+ raise ValueError("missing `sample` as a required keyword argument")
if timestep_list is not None:
deprecate(
"timestep_list",
@@ -975,7 +1072,7 @@ def multistep_dpm_solver_third_order_update(
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
if self.config.algorithm_type == "dpmsolver++":
- # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ # See https://huggingface.co/papers/2206.00927 for detailed derivations
x_t = (
(sigma_t / sigma_s0) * sample
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
@@ -983,7 +1080,7 @@ def multistep_dpm_solver_third_order_update(
- (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
)
elif self.config.algorithm_type == "dpmsolver":
- # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ # See https://huggingface.co/papers/2206.00927 for detailed derivations
x_t = (
(alpha_t / alpha_s0) * sample
- (sigma_t * (torch.exp(h) - 1.0)) * D0
@@ -1001,7 +1098,22 @@ def multistep_dpm_solver_third_order_update(
)
return x_t
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index for a given timestep in the schedule.
+
+ Args:
+ timestep (`int` or `torch.Tensor`):
+ The timestep for which to find the index.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule.
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -1020,9 +1132,13 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
"""
Initialize the step_index counter for the scheduler.
+
+ Args:
+ timestep (`int` or `torch.Tensor`):
+ The current timestep for which to initialize the step index.
"""
if self.begin_index is None:
@@ -1037,7 +1153,7 @@ def step(
model_output: torch.Tensor,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
- generator=None,
+ generator: Optional[torch.Generator] = None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
@@ -1047,22 +1163,22 @@ def step(
Args:
model_output (`torch.Tensor`):
- The direct output from learned diffusion model.
- timestep (`int`):
+ The direct output from the learned diffusion model.
+ timestep (`int` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
- variance_noise (`torch.Tensor`):
+ variance_noise (`torch.Tensor`, *optional*):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`LEdits++`].
- return_dict (`bool`):
+ return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
- If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
+ If `return_dict` is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
@@ -1142,6 +1258,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples without noise.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps at which to add noise to the samples.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
index 3f48066455fb..71b9960bf2ff 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py
@@ -1,4 +1,4 @@
-# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -80,14 +80,15 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality
samples, and it can generate quite good samples even in only 10 steps.
- For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
+ For more details, see the original paper: https://huggingface.co/papers/2206.00927 and
+ https://huggingface.co/papers/2211.01095
Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We
recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling.
- We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space
- diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic
- thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
+ We also support the "dynamic thresholding" method in Imagen (https://huggingface.co/papers/2205.11487). For
+ pixel-space diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the
+ dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as
stable-diffusion).
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
@@ -95,7 +96,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
- For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095
+ For more details, see the original paper: https://huggingface.co/papers/2206.00927 and
+ https://huggingface.co/papers/2211.01095
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -113,21 +115,21 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`,
or `v-prediction`.
thresholding (`bool`, default `False`):
- whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
- For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
- use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion
- models (such as stable-diffusion).
+ whether to use the "dynamic thresholding" method (introduced by Imagen,
+ https://huggingface.co/papers/2205.11487). For pixel-space diffusion models, you can set both
+ `algorithm_type=dpmsolver++` and `thresholding=True` to use the dynamic thresholding. Note that the
+ thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion).
dynamic_thresholding_ratio (`float`, default `0.995`):
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
- (https://arxiv.org/abs/2205.11487).
+ (https://huggingface.co/papers/2205.11487).
sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid only when `thresholding=True` and
`algorithm_type="dpmsolver++`.
algorithm_type (`str`, default `dpmsolver++`):
the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
- algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in
- https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided
- sampling (e.g. stable-diffusion).
+ algorithms in https://huggingface.co/papers/2206.00927, and the `dpmsolver++` type implements the
+ algorithms in https://huggingface.co/papers/2211.01095. We recommend to use `dpmsolver++` with
+ `solver_order=2` for guided sampling (e.g. stable-diffusion).
solver_type (`str`, default `midpoint`):
the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects
the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are
@@ -297,7 +299,7 @@ def convert_model_output(
)
if self.config.thresholding:
- # Dynamic thresholding in https://arxiv.org/abs/2205.11487
+ # Dynamic thresholding in https://huggingface.co/papers/2205.11487
dynamic_max_val = jnp.percentile(
jnp.abs(x0_pred), self.config.dynamic_thresholding_ratio, axis=tuple(range(1, x0_pred.ndim))
)
@@ -335,7 +337,7 @@ def dpm_solver_first_order_update(
"""
One step for the first-order DPM-Solver (equivalent to DDIM).
- See https://arxiv.org/abs/2206.00927 for the detailed derivation.
+ See https://huggingface.co/papers/2206.00927 for the detailed derivation.
Args:
model_output (`jnp.ndarray`): direct output from learned diffusion model.
@@ -390,7 +392,7 @@ def multistep_dpm_solver_second_order_update(
r0 = h_0 / h
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
if self.config.algorithm_type == "dpmsolver++":
- # See https://arxiv.org/abs/2211.01095 for detailed derivations
+ # See https://huggingface.co/papers/2211.01095 for detailed derivations
if self.config.solver_type == "midpoint":
x_t = (
(sigma_t / sigma_s0) * sample
@@ -404,7 +406,7 @@ def multistep_dpm_solver_second_order_update(
+ (alpha_t * ((jnp.exp(-h) - 1.0) / h + 1.0)) * D1
)
elif self.config.algorithm_type == "dpmsolver":
- # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ # See https://huggingface.co/papers/2206.00927 for detailed derivations
if self.config.solver_type == "midpoint":
x_t = (
(alpha_t / alpha_s0) * sample
@@ -458,7 +460,7 @@ def multistep_dpm_solver_third_order_update(
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
if self.config.algorithm_type == "dpmsolver++":
- # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ # See https://huggingface.co/papers/2206.00927 for detailed derivations
x_t = (
(sigma_t / sigma_s0) * sample
- (alpha_t * (jnp.exp(-h) - 1.0)) * D0
@@ -466,7 +468,7 @@ def multistep_dpm_solver_third_order_update(
- (alpha_t * ((jnp.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
)
elif self.config.algorithm_type == "dpmsolver":
- # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ # See https://huggingface.co/papers/2206.00927 for detailed derivations
x_t = (
(alpha_t / alpha_s0) * sample
- (sigma_t * (jnp.exp(h) - 1.0)) * D0
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
index 971817f7b777..6696b0375f9f 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
@@ -1,4 +1,4 @@
-# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
import math
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -32,10 +32,10 @@
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -43,16 +43,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -100,7 +101,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
+ Video](https://huggingface.co/papers/2210.02303) paper).
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
@@ -257,7 +258,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped).item()
self.noisiest_timestep = self.config.num_train_timesteps - 1 - clipped_idx
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.noisiest_timestep, num_inference_steps + 1).round()[:-1].copy().astype(np.int64)
@@ -332,13 +333,23 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
- https://arxiv.org/abs/2205.11487
+ https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -365,6 +376,19 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -389,6 +413,17 @@ def _sigma_to_t(self, sigma, log_sigmas):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
+ """
+ Convert sigma values to alpha_t and sigma_t values.
+
+ Args:
+ sigma (`torch.Tensor`):
+ The sigma value(s) to convert.
+
+ Returns:
+ `Tuple[torch.Tensor, torch.Tensor]`:
+ A tuple containing (alpha_t, sigma_t) values.
+ """
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
@@ -400,7 +435,20 @@ def _sigma_to_alpha_sigma_t(self, sigma):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -426,7 +474,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -450,7 +510,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -491,12 +568,8 @@ def convert_model_output(
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
integral of the data prediction model.
-
-
- The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
- prediction and data prediction models.
-
-
+ > [!TIP] > The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both
+ noise > prediction and data prediction models.
Args:
model_output (`torch.Tensor`):
@@ -513,7 +586,7 @@ def convert_model_output(
if len(args) > 1:
sample = args[1]
else:
- raise ValueError("missing `sample` as a required keyward argument")
+ raise ValueError("missing `sample` as a required keyword argument")
if timestep is not None:
deprecate(
"timesteps",
@@ -609,7 +682,7 @@ def dpm_solver_first_order_update(
if len(args) > 2:
sample = args[2]
else:
- raise ValueError(" missing `sample` as a required keyward argument")
+ raise ValueError("missing `sample` as a required keyword argument")
if timestep is not None:
deprecate(
"timesteps",
@@ -679,7 +752,7 @@ def multistep_dpm_solver_second_order_update(
if len(args) > 2:
sample = args[2]
else:
- raise ValueError(" missing `sample` as a required keyward argument")
+ raise ValueError("missing `sample` as a required keyword argument")
if timestep_list is not None:
deprecate(
"timestep_list",
@@ -714,7 +787,7 @@ def multistep_dpm_solver_second_order_update(
r0 = h_0 / h
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
if self.config.algorithm_type == "dpmsolver++":
- # See https://arxiv.org/abs/2211.01095 for detailed derivations
+ # See https://huggingface.co/papers/2211.01095 for detailed derivations
if self.config.solver_type == "midpoint":
x_t = (
(sigma_t / sigma_s0) * sample
@@ -728,7 +801,7 @@ def multistep_dpm_solver_second_order_update(
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
)
elif self.config.algorithm_type == "dpmsolver":
- # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ # See https://huggingface.co/papers/2206.00927 for detailed derivations
if self.config.solver_type == "midpoint":
x_t = (
(alpha_t / alpha_s0) * sample
@@ -804,7 +877,7 @@ def multistep_dpm_solver_third_order_update(
if len(args) > 2:
sample = args[2]
else:
- raise ValueError(" missing`sample` as a required keyward argument")
+ raise ValueError("missing `sample` as a required keyword argument")
if timestep_list is not None:
deprecate(
"timestep_list",
@@ -845,7 +918,7 @@ def multistep_dpm_solver_third_order_update(
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
if self.config.algorithm_type == "dpmsolver++":
- # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ # See https://huggingface.co/papers/2206.00927 for detailed derivations
x_t = (
(sigma_t / sigma_s0) * sample
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
@@ -853,7 +926,7 @@ def multistep_dpm_solver_third_order_update(
- (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
)
elif self.config.algorithm_type == "dpmsolver":
- # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ # See https://huggingface.co/papers/2206.00927 for detailed derivations
x_t = (
(alpha_t / alpha_s0) * sample
- (sigma_t * (torch.exp(h) - 1.0)) * D0
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py
index 6c9cb975fe34..81c9e4134f57 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
+# Copyright 2025 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -115,10 +115,10 @@ def __call__(self, sigma, sigma_next):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -126,16 +126,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -181,7 +182,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
+ Video](https://huggingface.co/papers/2210.02303) paper).
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
@@ -250,7 +251,23 @@ def __init__(
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -265,7 +282,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -301,7 +325,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -352,7 +376,7 @@ def set_timesteps(
num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
elif self.config.timestep_spacing == "leading":
@@ -429,6 +453,19 @@ def t_fn(_sigma):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -451,9 +488,20 @@ def _sigma_to_t(self, sigma, log_sigmas):
t = t.reshape(sigma.shape)
return t
- # copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+ # Copied from diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
@@ -467,7 +515,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -491,7 +551,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -645,6 +722,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise tensor to add to the original samples.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to add noise, determining the noise level from the schedule.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples with added noise scaled according to the timestep schedule.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
index bf68d6c99bd6..4916e1abb549 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -1,4 +1,4 @@
-# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver
import math
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -34,10 +34,10 @@
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -45,16 +45,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -85,42 +86,42 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
methods the library implements for all schedulers such as loading and saving.
Args:
- num_train_timesteps (`int`, defaults to 1000):
+ num_train_timesteps (`int`, defaults to `1000`):
The number of diffusion steps to train the model.
- beta_start (`float`, defaults to 0.0001):
+ beta_start (`float`, defaults to `0.0001`):
The starting `beta` value of inference.
- beta_end (`float`, defaults to 0.02):
+ beta_end (`float`, defaults to `0.02`):
The final `beta` value.
- beta_schedule (`str`, defaults to `"linear"`):
+ beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
- trained_betas (`np.ndarray`, *optional*):
+ trained_betas (`np.ndarray` or `List[float]`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
- solver_order (`int`, defaults to 2):
+ solver_order (`int`, defaults to `2`):
The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling.
- prediction_type (`str`, defaults to `epsilon`, *optional*):
+ prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
- `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
+ `sample` (directly predicts the noisy sample`), `v_prediction` (see section 2.4 of [Imagen
+ Video](https://huggingface.co/papers/2210.02303) paper), or `flow_prediction`.
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
- dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ dynamic_thresholding_ratio (`float`, defaults to `0.995`):
The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
- sample_max_value (`float`, defaults to 1.0):
+ sample_max_value (`float`, defaults to `1.0`):
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
`algorithm_type="dpmsolver++"`.
- algorithm_type (`str`, defaults to `dpmsolver++`):
- Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver`
+ algorithm_type (`"dpmsolver"`, `"dpmsolver++"`, or `"sde-dpmsolver++"`, defaults to `"dpmsolver++"`):
+ Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, or `sde-dpmsolver++`. The `dpmsolver`
type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the
`dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095)
paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided
sampling like in Stable Diffusion.
- solver_type (`str`, defaults to `midpoint`):
+ solver_type (`"midpoint"` or `"heun"`, defaults to `"midpoint"`):
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
- lower_order_final (`bool`, defaults to `True`):
+ lower_order_final (`bool`, defaults to `False`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
@@ -131,15 +132,23 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
- final_sigmas_type (`str`, *optional*, defaults to `"zero"`):
+ use_flow_sigmas (`bool`, *optional*, defaults to `False`):
+ Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
+ flow_shift (`float`, *optional*, defaults to `1.0`):
+ The flow shift parameter for flow-based models.
+ final_sigmas_type (`"zero"` or `"sigma_min"`, *optional*, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
- sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+ sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
lambda_min_clipped (`float`, defaults to `-inf`):
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
cosine (`squaredcos_cap_v2`) noise schedule.
- variance_type (`str`, *optional*):
- Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
- contains the predicted Gaussian variance.
+ variance_type (`"learned"` or `"learned_range"`, *optional*):
+ Set to `"learned"` or `"learned_range"` for diffusion models that predict variance. If set, the model's
+ output contains the predicted Gaussian variance.
+ use_dynamic_shifting (`bool`, defaults to `False`):
+ Whether to use dynamic shifting for the noise schedule.
+ time_shift_type (`"exponential"`, defaults to `"exponential"`):
+ The type of time shifting to apply.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -151,25 +160,27 @@ def __init__(
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
- beta_schedule: str = "linear",
- trained_betas: Optional[np.ndarray] = None,
+ beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
solver_order: int = 2,
- prediction_type: str = "epsilon",
+ prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
- algorithm_type: str = "dpmsolver++",
- solver_type: str = "midpoint",
+ algorithm_type: Literal["dpmsolver", "dpmsolver++", "sde-dpmsolver++"] = "dpmsolver++",
+ solver_type: Literal["midpoint", "heun"] = "midpoint",
lower_order_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
use_flow_sigmas: Optional[bool] = False,
flow_shift: Optional[float] = 1.0,
- final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
+ final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero",
lambda_min_clipped: float = -float("inf"),
- variance_type: Optional[str] = None,
- ):
+ variance_type: Optional[Literal["learned", "learned_range"]] = None,
+ use_dynamic_shifting: bool = False,
+ time_shift_type: Literal["exponential"] = "exponential",
+ ) -> None:
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
@@ -218,7 +229,7 @@ def __init__(
if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero":
raise ValueError(
- f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please chooose `sigma_min` instead."
+ f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead."
)
# setable values
@@ -239,6 +250,10 @@ def get_order_list(self, num_inference_steps: int) -> List[int]:
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
+
+ Returns:
+ `List[int]`:
+ The list of solver orders for each timestep.
"""
steps = num_inference_steps
order = self.config.solver_order
@@ -273,49 +288,63 @@ def get_order_list(self, num_inference_steps: int) -> List[int]:
return orders
@property
- def step_index(self):
+ def step_index(self) -> Optional[int]:
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
+
+ Returns:
+ `int` or `None`:
+ The current step index.
"""
return self._step_index
@property
- def begin_index(self):
+ def begin_index(self) -> Optional[int]:
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+
+ Returns:
+ `int` or `None`:
+ The begin index.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
- def set_begin_index(self, begin_index: int = 0):
+ def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def set_timesteps(
self,
- num_inference_steps: int = None,
- device: Union[str, torch.device] = None,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ mu: Optional[float] = None,
timesteps: Optional[List[int]] = None,
- ):
+ ) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
- num_inference_steps (`int`):
+ num_inference_steps (`int`, *optional*):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ mu (`float`, *optional*):
+ The mu parameter for dynamic shifting.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is
passed, `num_inference_steps` must be `None`.
"""
+ if mu is not None:
+ assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+ self.config.flow_shift = np.exp(mu)
if num_inference_steps is None and timesteps is None:
raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
if num_inference_steps is not None and timesteps is not None:
@@ -404,13 +433,23 @@ def set_timesteps(
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
- https://arxiv.org/abs/2205.11487
+ https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -436,7 +475,20 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
return sample
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
- def _sigma_to_t(self, sigma, log_sigmas):
+ def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -460,7 +512,18 @@ def _sigma_to_t(self, sigma, log_sigmas):
return t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
- def _sigma_to_alpha_sigma_t(self, sigma):
+ def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Convert sigma values to alpha_t and sigma_t values.
+
+ Args:
+ sigma (`torch.Tensor`):
+ The sigma value(s) to convert.
+
+ Returns:
+ `Tuple[torch.Tensor, torch.Tensor]`:
+ A tuple containing (alpha_t, sigma_t) values.
+ """
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
@@ -471,8 +534,21 @@ def _sigma_to_alpha_sigma_t(self, sigma):
return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
- def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -498,7 +574,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -522,7 +610,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -554,7 +659,7 @@ def convert_model_output(
self,
model_output: torch.Tensor,
*args,
- sample: torch.Tensor = None,
+ sample: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
@@ -562,12 +667,8 @@ def convert_model_output(
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
integral of the data prediction model.
-
-
- The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
- prediction and data prediction models.
-
-
+ > [!TIP] > The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both
+ noise > prediction and data prediction models.
Args:
model_output (`torch.Tensor`):
@@ -584,7 +685,7 @@ def convert_model_output(
if len(args) > 1:
sample = args[1]
else:
- raise ValueError("missing `sample` as a required keyward argument")
+ raise ValueError("missing `sample` as a required keyword argument")
if timestep is not None:
deprecate(
"timesteps",
@@ -654,7 +755,7 @@ def dpm_solver_first_order_update(
self,
model_output: torch.Tensor,
*args,
- sample: torch.Tensor = None,
+ sample: Optional[torch.Tensor] = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
@@ -681,7 +782,7 @@ def dpm_solver_first_order_update(
if len(args) > 2:
sample = args[2]
else:
- raise ValueError(" missing `sample` as a required keyward argument")
+ raise ValueError("missing `sample` as a required keyword argument")
if timestep is not None:
deprecate(
"timesteps",
@@ -718,7 +819,7 @@ def singlestep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.Tensor],
*args,
- sample: torch.Tensor = None,
+ sample: Optional[torch.Tensor] = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
@@ -746,7 +847,7 @@ def singlestep_dpm_solver_second_order_update(
if len(args) > 2:
sample = args[2]
else:
- raise ValueError(" missing `sample` as a required keyward argument")
+ raise ValueError("missing `sample` as a required keyword argument")
if timestep_list is not None:
deprecate(
"timestep_list",
@@ -780,7 +881,7 @@ def singlestep_dpm_solver_second_order_update(
r0 = h_0 / h
D0, D1 = m1, (1.0 / r0) * (m0 - m1)
if self.config.algorithm_type == "dpmsolver++":
- # See https://arxiv.org/abs/2211.01095 for detailed derivations
+ # See https://huggingface.co/papers/2211.01095 for detailed derivations
if self.config.solver_type == "midpoint":
x_t = (
(sigma_t / sigma_s1) * sample
@@ -794,7 +895,7 @@ def singlestep_dpm_solver_second_order_update(
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
)
elif self.config.algorithm_type == "dpmsolver":
- # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ # See https://huggingface.co/papers/2206.00927 for detailed derivations
if self.config.solver_type == "midpoint":
x_t = (
(alpha_t / alpha_s1) * sample
@@ -829,7 +930,7 @@ def singlestep_dpm_solver_third_order_update(
self,
model_output_list: List[torch.Tensor],
*args,
- sample: torch.Tensor = None,
+ sample: Optional[torch.Tensor] = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
@@ -858,7 +959,7 @@ def singlestep_dpm_solver_third_order_update(
if len(args) > 2:
sample = args[2]
else:
- raise ValueError(" missing`sample` as a required keyward argument")
+ raise ValueError("missing `sample` as a required keyword argument")
if timestep_list is not None:
deprecate(
"timestep_list",
@@ -899,7 +1000,7 @@ def singlestep_dpm_solver_third_order_update(
D1 = (r0 * D1_0 - r1 * D1_1) / (r0 - r1)
D2 = 2.0 * (D1_1 - D1_0) / (r0 - r1)
if self.config.algorithm_type == "dpmsolver++":
- # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ # See https://huggingface.co/papers/2206.00927 for detailed derivations
if self.config.solver_type == "midpoint":
x_t = (
(sigma_t / sigma_s2) * sample
@@ -914,7 +1015,7 @@ def singlestep_dpm_solver_third_order_update(
- (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2
)
elif self.config.algorithm_type == "dpmsolver":
- # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ # See https://huggingface.co/papers/2206.00927 for detailed derivations
if self.config.solver_type == "midpoint":
x_t = (
(alpha_t / alpha_s2) * sample
@@ -951,8 +1052,8 @@ def singlestep_dpm_solver_update(
self,
model_output_list: List[torch.Tensor],
*args,
- sample: torch.Tensor = None,
- order: int = None,
+ sample: Optional[torch.Tensor] = None,
+ order: Optional[int] = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
@@ -981,12 +1082,12 @@ def singlestep_dpm_solver_update(
if len(args) > 2:
sample = args[2]
else:
- raise ValueError(" missing`sample` as a required keyward argument")
+ raise ValueError("missing `sample` as a required keyword argument")
if order is None:
if len(args) > 3:
order = args[3]
else:
- raise ValueError(" missing `order` as a required keyward argument")
+ raise ValueError("missing `order` as a required keyword argument")
if timestep_list is not None:
deprecate(
"timestep_list",
@@ -1011,7 +1112,22 @@ def singlestep_dpm_solver_update(
raise ValueError(f"Order must be 1, 2, 3, got {order}")
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index for a given timestep in the schedule.
+
+ Args:
+ timestep (`int` or `torch.Tensor`):
+ The timestep for which to find the index.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule.
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -1031,9 +1147,13 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
"""
Initialize the step_index counter for the scheduler.
+
+ Args:
+ timestep (`int` or `torch.Tensor`):
+ The current timestep for which to initialize the step index.
"""
if self.begin_index is None:
@@ -1048,7 +1168,7 @@ def step(
model_output: torch.Tensor,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
- generator=None,
+ generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
@@ -1058,11 +1178,13 @@ def step(
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
- timestep (`int`):
+ timestep (`int` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
- return_dict (`bool`):
+ generator (`torch.Generator`, *optional*):
+ A random number generator for stochastic sampling.
+ return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns:
@@ -1136,6 +1258,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples without noise.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps at which to add noise to the samples.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -1164,5 +1301,5 @@ def add_noise(
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples
- def __len__(self):
+ def __len__(self) -> int:
return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
index c49e8e9a191a..d4e8ca5e8b18 100644
--- a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
@@ -1,4 +1,4 @@
-# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,7 +31,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
`EDMDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs.
[1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
- https://arxiv.org/abs/2206.00364
+ https://huggingface.co/papers/2206.00364
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
@@ -47,8 +47,8 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
sigma_schedule (`str`, *optional*, defaults to `karras`):
Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
- (https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was
- incorporated in this model: https://huggingface.co/stabilityai/cosxl.
+ (https://huggingface.co/papers/2206.00364). Other acceptable value is "exponential". The exponential
+ schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
solver_order (`int`, defaults to 2):
@@ -57,7 +57,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
+ Video](https://huggingface.co/papers/2210.02303) paper).
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
@@ -169,14 +169,14 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
def precondition_inputs(self, sample, sigma):
- c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
+ c_in = self._get_conditioning_c_in(sigma)
scaled_sample = sample * c_in
return scaled_sample
@@ -299,13 +299,23 @@ def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> t
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
- https://arxiv.org/abs/2205.11487
+ https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -332,6 +342,19 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -370,12 +393,8 @@ def convert_model_output(
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
integral of the data prediction model.
-
-
- The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
- prediction and data prediction models.
-
-
+ > [!TIP] > The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both
+ noise > prediction and data prediction models.
Args:
model_output (`torch.Tensor`):
@@ -472,7 +491,7 @@ def multistep_dpm_solver_second_order_update(
r0 = h_0 / h
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
if self.config.algorithm_type == "dpmsolver++":
- # See https://arxiv.org/abs/2211.01095 for detailed derivations
+ # See https://huggingface.co/papers/2211.01095 for detailed derivations
if self.config.solver_type == "midpoint":
x_t = (
(sigma_t / sigma_s0) * sample
@@ -548,7 +567,7 @@ def multistep_dpm_solver_third_order_update(
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
if self.config.algorithm_type == "dpmsolver++":
- # See https://arxiv.org/abs/2206.00927 for detailed derivations
+ # See https://huggingface.co/papers/2206.00927 for detailed derivations
x_t = (
(sigma_t / sigma_s0) * sample
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
@@ -559,7 +578,22 @@ def multistep_dpm_solver_third_order_update(
return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index for a given timestep in the schedule.
+
+ Args:
+ timestep (`int` or `torch.Tensor`):
+ The timestep for which to find the index.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule.
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -582,6 +616,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
+
+ Args:
+ timestep (`int` or `torch.Tensor`):
+ The current timestep for which to initialize the step index.
"""
if self.begin_index is None:
@@ -676,6 +714,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise tensor to add to the original samples.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to add noise, determining the noise level from the schedule.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples with added noise scaled according to the timestep schedule.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -703,5 +756,10 @@ def add_noise(
noisy_samples = original_samples + noise * sigma
return noisy_samples
+ # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
+ def _get_conditioning_c_in(self, sigma):
+ c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
+ return c_in
+
def __len__(self):
return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_edm_euler.py b/src/diffusers/schedulers/scheduling_edm_euler.py
index 0617cc44d75a..2ed05d396514 100644
--- a/src/diffusers/schedulers/scheduling_edm_euler.py
+++ b/src/diffusers/schedulers/scheduling_edm_euler.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Katherine Crowson and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -51,7 +51,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
Implements the Euler scheduler in EDM formulation as presented in Karras et al. 2022 [1].
[1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
- https://arxiv.org/abs/2206.00364
+ https://huggingface.co/papers/2206.00364
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
@@ -67,14 +67,14 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
sigma_schedule (`str`, *optional*, defaults to `karras`):
Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
- (https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was
- incorporated in this model: https://huggingface.co/stabilityai/cosxl.
+ (https://huggingface.co/papers/2206.00364). Other acceptable value is "exponential". The exponential
+ schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
+ Video](https://huggingface.co/papers/2210.02303) paper).
rho (`float`, *optional*, defaults to 7.0):
The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
final_sigmas_type (`str`, defaults to `"zero"`):
@@ -103,11 +103,13 @@ def __init__(
# setable values
self.num_inference_steps = None
- sigmas = torch.arange(num_train_timesteps + 1) / num_train_timesteps
+ sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+ sigmas = torch.arange(num_train_timesteps + 1, dtype=sigmas_dtype) / num_train_timesteps
if sigma_schedule == "karras":
sigmas = self._compute_karras_sigmas(sigmas)
elif sigma_schedule == "exponential":
sigmas = self._compute_exponential_sigmas(sigmas)
+ sigmas = sigmas.to(torch.float32)
self.timesteps = self.precondition_noise(sigmas)
@@ -153,13 +155,13 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def precondition_inputs(self, sample, sigma):
- c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
+ c_in = self._get_conditioning_c_in(sigma)
scaled_sample = sample * c_in
return scaled_sample
@@ -230,18 +232,19 @@ def set_timesteps(
"""
self.num_inference_steps = num_inference_steps
+ sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
if sigmas is None:
- sigmas = torch.linspace(0, 1, self.num_inference_steps)
+ sigmas = torch.linspace(0, 1, self.num_inference_steps, dtype=sigmas_dtype)
elif isinstance(sigmas, float):
- sigmas = torch.tensor(sigmas, dtype=torch.float32)
+ sigmas = torch.tensor(sigmas, dtype=sigmas_dtype)
else:
- sigmas = sigmas
+ sigmas = sigmas.to(sigmas_dtype)
if self.config.sigma_schedule == "karras":
sigmas = self._compute_karras_sigmas(sigmas)
elif self.config.sigma_schedule == "exponential":
sigmas = self._compute_exponential_sigmas(sigmas)
-
sigmas = sigmas.to(dtype=torch.float32, device=device)
+
self.timesteps = self.precondition_noise(sigmas)
if self.config.final_sigmas_type == "sigma_min":
@@ -281,7 +284,23 @@ def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> t
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -296,7 +315,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -315,6 +341,7 @@ def step(
s_noise: float = 1.0,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
+ pred_original_sample: Optional[torch.Tensor] = None,
) -> Union[EDMEulerSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
@@ -378,7 +405,8 @@ def step(
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
- pred_original_sample = self.precondition_outputs(sample, model_output, sigma_hat)
+ if pred_original_sample is None:
+ pred_original_sample = self.precondition_outputs(sample, model_output, sigma_hat)
# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma_hat
@@ -408,6 +436,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise tensor to add to the original samples.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to add noise, determining the noise level from the schedule.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples with added noise scaled according to the timestep schedule.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -435,5 +478,9 @@ def add_noise(
noisy_samples = original_samples + noise * sigma
return noisy_samples
+ def _get_conditioning_c_in(self, sigma):
+ c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
+ return c_in
+
def __len__(self):
return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
index 4df43a160ce1..97fd84db5621 100644
--- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Katherine Crowson and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -49,10 +49,10 @@ class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -60,16 +60,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -93,17 +94,17 @@ def alpha_bar_fn(t):
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
-def rescale_zero_terminal_snr(betas):
+def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
- Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
-
+ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
- the betas that the scheduler is being initialized with.
+ The betas that the scheduler is being initialized with.
Returns:
- `torch.Tensor`: rescaled betas with zero terminal SNR
+ `torch.Tensor`:
+ Rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
@@ -143,16 +144,16 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
- beta_schedule (`str`, defaults to `"linear"`):
+ beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
- `linear` or `scaled_linear`.
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
- prediction_type (`str`, defaults to `epsilon`, *optional*):
+ prediction_type (`"epsilon"`, `"sample"`, or `"v_prediction"`, defaults to `"epsilon"`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
- timestep_spacing (`str`, defaults to `"linspace"`):
+ Video](https://huggingface.co/papers/2210.02303) paper).
+ timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0):
@@ -172,13 +173,13 @@ def __init__(
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
- beta_schedule: str = "linear",
+ beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
- prediction_type: str = "epsilon",
- timestep_spacing: str = "linspace",
+ prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
+ timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
- ):
+ ) -> None:
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
@@ -218,7 +219,7 @@ def __init__(
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
- def init_noise_sigma(self):
+ def init_noise_sigma(self) -> torch.Tensor:
# standard deviation of the initial noise distribution
if self.config.timestep_spacing in ["linspace", "trailing"]:
return self.sigmas.max()
@@ -226,26 +227,26 @@ def init_noise_sigma(self):
return (self.sigmas.max() ** 2 + 1) ** 0.5
@property
- def step_index(self):
+ def step_index(self) -> Optional[int]:
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
- def begin_index(self):
+ def begin_index(self) -> Optional[int]:
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
- def set_begin_index(self, begin_index: int = 0):
+ def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -258,7 +259,7 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
Args:
sample (`torch.Tensor`):
The input sample.
- timestep (`int`, *optional*):
+ timestep (`float` or `torch.Tensor`):
The current timestep in the diffusion chain.
Returns:
@@ -274,7 +275,7 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
self.is_scale_input_called = True
return sample
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -286,7 +287,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
"""
self.num_inference_steps = num_inference_steps
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
::-1
@@ -319,7 +320,23 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -334,7 +351,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -357,13 +381,13 @@ def step(
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
- timestep (`float`):
+ timestep (`float` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
- return_dict (`bool`):
+ return_dict (`bool`, defaults to `True`):
Whether or not to return a
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
@@ -451,6 +475,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise tensor to add to the original samples.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to add noise, determining the noise level from the schedule.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples with added noise scaled according to the timestep schedule.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -478,5 +517,5 @@ def add_noise(
noisy_samples = original_samples + noise * sigma
return noisy_samples
- def __len__(self):
+ def __len__(self) -> int:
return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py
index 56757f3ca197..a55a76626cec 100644
--- a/src/diffusers/schedulers/scheduling_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Katherine Crowson and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -52,10 +52,10 @@ class EulerDiscreteSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -63,16 +63,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -96,17 +97,17 @@ def alpha_bar_fn(t):
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
-def rescale_zero_terminal_snr(betas):
+def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
- Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
-
+ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
- the betas that the scheduler is being initialized with.
+ The betas that the scheduler is being initialized with.
Returns:
- `torch.Tensor`: rescaled betas with zero terminal SNR
+ `torch.Tensor`:
+ Rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
@@ -146,17 +147,17 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
- beta_schedule (`str`, defaults to `"linear"`):
+ beta_schedule (`Literal["linear", "scaled_linear", "squaredcos_cap_v2"]`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
- `linear` or `scaled_linear`.
+ `"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
- prediction_type (`str`, defaults to `epsilon`, *optional*):
- Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
- `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
- interpolation_type(`str`, defaults to `"linear"`, *optional*):
- The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be on of
+ prediction_type (`Literal["epsilon", "sample", "v_prediction"]`, defaults to `"epsilon"`, *optional*):
+ Prediction type of the scheduler function; can be `"epsilon"` (predicts the noise of the diffusion
+ process), `"sample"` (directly predicts the noisy sample`) or `"v_prediction"` (see section 2.4 of [Imagen
+ Video](https://huggingface.co/papers/2210.02303) paper).
+ interpolation_type (`Literal["linear", "log_linear"]`, defaults to `"linear"`, *optional*):
+ The interpolation type to compute intermediate sigmas for the scheduler denoising steps. Should be one of
`"linear"` or `"log_linear"`.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
@@ -166,18 +167,26 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
- timestep_spacing (`str`, defaults to `"linspace"`):
+ sigma_min (`float`, *optional*):
+ The minimum sigma value for the noise schedule. If not provided, defaults to the last sigma in the
+ schedule.
+ sigma_max (`float`, *optional*):
+ The maximum sigma value for the noise schedule. If not provided, defaults to the first sigma in the
+ schedule.
+ timestep_spacing (`Literal["linspace", "leading", "trailing"]`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ timestep_type (`Literal["discrete", "continuous"]`, defaults to `"discrete"`):
+ The type of timesteps to use. Can be `"discrete"` or `"continuous"`.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps, as required by some model families.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
- final_sigmas_type (`str`, defaults to `"zero"`):
+ final_sigmas_type (`Literal["zero", "sigma_min"]`, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
- sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+ sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -189,20 +198,20 @@ def __init__(
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
- beta_schedule: str = "linear",
+ beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
- prediction_type: str = "epsilon",
- interpolation_type: str = "linear",
+ prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
+ interpolation_type: Literal["linear", "log_linear"] = "linear",
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
- timestep_spacing: str = "linspace",
- timestep_type: str = "discrete", # can be "discrete" or "continuous"
+ timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
+ timestep_type: Literal["discrete", "continuous"] = "discrete",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
- final_sigmas_type: str = "zero", # can be "zero" or "sigma_min"
+ final_sigmas_type: Literal["zero", "sigma_min"] = "zero",
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -259,8 +268,15 @@ def __init__(
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
- def init_noise_sigma(self):
- # standard deviation of the initial noise distribution
+ def init_noise_sigma(self) -> Union[float, torch.Tensor]:
+ """
+ The standard deviation of the initial noise distribution.
+
+ Returns:
+ `float` or `torch.Tensor`:
+ The standard deviation of the initial noise distribution, computed based on the maximum sigma value and
+ the timestep spacing configuration.
+ """
max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max()
if self.config.timestep_spacing in ["linspace", "trailing"]:
return max_sigma
@@ -268,26 +284,34 @@ def init_noise_sigma(self):
return (max_sigma**2 + 1) ** 0.5
@property
- def step_index(self):
+ def step_index(self) -> Optional[int]:
"""
- The index counter for current timestep. It will increase 1 after each scheduler step.
+ The index counter for current timestep. It will increase by 1 after each scheduler step.
+
+ Returns:
+ `int` or `None`:
+ The current step index, or `None` if not initialized.
"""
return self._step_index
@property
- def begin_index(self):
+ def begin_index(self) -> Optional[int]:
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+
+ Returns:
+ `int` or `None`:
+ The begin index for the scheduler, or `None` if not set.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
- def set_begin_index(self, begin_index: int = 0):
+ def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -299,13 +323,13 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
Args:
sample (`torch.Tensor`):
- The input sample.
- timestep (`int`, *optional*):
+ The input sample to be scaled.
+ timestep (`float` or `torch.Tensor`):
The current timestep in the diffusion chain.
Returns:
`torch.Tensor`:
- A scaled input sample.
+ A scaled input sample, divided by `(sigma**2 + 1) ** 0.5`.
"""
if self.step_index is None:
self._init_step_index(timestep)
@@ -318,17 +342,18 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
def set_timesteps(
self,
- num_inference_steps: int = None,
- device: Union[str, torch.device] = None,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
- ):
+ ) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
- num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model.
+ num_inference_steps (`int`, *optional*):
+ The number of diffusion steps used when generating samples with a pre-trained model. If `None`,
+ `timesteps` or `sigmas` must be provided.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
@@ -336,10 +361,9 @@ def set_timesteps(
based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
must be `None`, and `timestep_spacing` attribute will be ignored.
sigmas (`List[float]`, *optional*):
- Custom sigmas used to support arbitrary timesteps schedule schedule. If `None`, timesteps and sigmas
- will be generated based on the relevant scheduler attributes. If `sigmas` is passed,
- `num_inference_steps` and `timesteps` must be `None`, and the timesteps will be generated based on the
- custom sigmas schedule.
+ Custom sigmas used to support arbitrary timesteps schedule. If `None`, timesteps and sigmas will be
+ generated based on the relevant scheduler attributes. If `sigmas` is passed, `num_inference_steps` and
+ `timesteps` must be `None`, and the timesteps will be generated based on the custom sigmas schedule.
"""
if timesteps is not None and sigmas is not None:
@@ -376,7 +400,7 @@ def set_timesteps(
if timesteps is not None:
timesteps = np.array(timesteps).astype(np.float32)
else:
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(
0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32
@@ -449,7 +473,20 @@ def set_timesteps(
self._begin_index = None
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
- def _sigma_to_t(self, sigma, log_sigmas):
+ def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -473,8 +510,21 @@ def _sigma_to_t(self, sigma, log_sigmas):
return t
# Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
- def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -500,7 +550,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
# Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L26
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -523,7 +585,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -551,7 +630,23 @@ def _convert_to_beta(
)
return sigmas
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -565,7 +660,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return indices[pos].item()
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -591,26 +693,33 @@ def step(
Args:
model_output (`torch.Tensor`):
- The direct output from learned diffusion model.
- timestep (`float`):
+ The direct output from the learned diffusion model.
+ timestep (`float` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
- s_churn (`float`):
- s_tmin (`float`):
- s_tmax (`float`):
- s_noise (`float`, defaults to 1.0):
+ s_churn (`float`, *optional*, defaults to `0.0`):
+ Stochasticity parameter that controls the amount of noise added during sampling. Higher values increase
+ randomness.
+ s_tmin (`float`, *optional*, defaults to `0.0`):
+ Minimum timestep threshold for applying stochasticity. Only timesteps above this value will have noise
+ added.
+ s_tmax (`float`, *optional*, defaults to `inf`):
+ Maximum timestep threshold for applying stochasticity. Only timesteps below this value will have noise
+ added.
+ s_noise (`float`, *optional*, defaults to `1.0`):
Scaling factor for noise added to the sample.
generator (`torch.Generator`, *optional*):
- A random number generator.
- return_dict (`bool`):
+ A random number generator for reproducible sampling.
+ return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
tuple.
Returns:
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
- If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
- returned, otherwise a tuple is returned where the first element is the sample tensor.
+ If `return_dict` is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
+ returned, otherwise a tuple is returned where the first element is the sample tensor and the second
+ element is the predicted original sample.
"""
if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
@@ -689,6 +798,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise tensor to add to the original samples.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to add noise, determining the noise level from the schedule.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples with added noise scaled according to the timestep schedule.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -717,6 +841,24 @@ def add_noise(
return noisy_samples
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the velocity prediction for the given sample and noise at the specified timesteps.
+
+ This method implements the velocity prediction used in v-prediction models, which predicts a linear combination
+ of the sample and noise.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample for which to compute the velocity.
+ noise (`torch.Tensor`):
+ The noise tensor corresponding to the sample.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to compute the velocity.
+
+ Returns:
+ `torch.Tensor`:
+ The velocity prediction computed as `sqrt(alpha_prod) * noise - sqrt(1 - alpha_prod) * sample`.
+ """
if (
isinstance(timesteps, int)
or isinstance(timesteps, torch.IntTensor)
@@ -753,5 +895,5 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
- def __len__(self):
+ def __len__(self) -> int:
return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_euler_discrete_flax.py b/src/diffusers/schedulers/scheduling_euler_discrete_flax.py
index 55b0c2460a81..09341c909d2e 100644
--- a/src/diffusers/schedulers/scheduling_euler_discrete_flax.py
+++ b/src/diffusers/schedulers/scheduling_euler_discrete_flax.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Katherine Crowson and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -52,8 +52,8 @@ class FlaxEulerDiscreteSchedulerOutput(FlaxSchedulerOutput):
class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
- Euler scheduler (Algorithm 2) from Karras et al. (2022) https://arxiv.org/abs/2206.00364. . Based on the original
- k-diffusion implementation by Katherine Crowson:
+ Euler scheduler (Algorithm 2) from Karras et al. (2022) https://huggingface.co/papers/2206.00364. . Based on the
+ original k-diffusion implementation by Katherine Crowson:
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51
@@ -74,7 +74,7 @@ class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
- https://imagen.research.google/video/paper.pdf)
+ https://huggingface.co/papers/2210.02303)
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
the `dtype` used for params and computation.
"""
diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
index cbb27e5fad63..9fd61d9e18d1 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -80,6 +80,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
Whether to use beta sigmas for step sizes in the noise schedule during sampling.
time_shift_type (`str`, defaults to "exponential"):
The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
+ stochastic_sampling (`bool`, defaults to False):
+ Whether to use stochastic sampling.
"""
_compatibles = []
@@ -101,6 +103,7 @@ def __init__(
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
time_shift_type: str = "exponential",
+ stochastic_sampling: bool = False,
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -157,7 +160,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -437,13 +440,25 @@ def step(
lower_mask = sigmas < per_token_sigmas[None] - 1e-6
lower_sigmas = lower_mask * sigmas
lower_sigmas, _ = lower_sigmas.max(dim=0)
- dt = (per_token_sigmas - lower_sigmas)[..., None]
+
+ current_sigma = per_token_sigmas[..., None]
+ next_sigma = lower_sigmas[..., None]
+ dt = current_sigma - next_sigma
else:
- sigma = self.sigmas[self.step_index]
- sigma_next = self.sigmas[self.step_index + 1]
+ sigma_idx = self.step_index
+ sigma = self.sigmas[sigma_idx]
+ sigma_next = self.sigmas[sigma_idx + 1]
+
+ current_sigma = sigma
+ next_sigma = sigma_next
dt = sigma_next - sigma
- prev_sample = sample + dt * model_output
+ if self.config.stochastic_sampling:
+ x0 = sample - current_sigma * model_output
+ noise = torch.randn_like(sample)
+ prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise
+ else:
+ prev_sample = sample + dt * model_output
# upon completion increase step index by one
self._step_index += 1
@@ -458,7 +473,20 @@ def step(
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -484,7 +512,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -508,7 +548,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
diff --git a/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py
index 2addc5f3eeec..6febee444c5a 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -102,7 +102,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
diff --git a/src/diffusers/schedulers/scheduling_flow_match_lcm.py b/src/diffusers/schedulers/scheduling_flow_match_lcm.py
new file mode 100644
index 000000000000..25186d1fe969
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_flow_match_lcm.py
@@ -0,0 +1,603 @@
+# Copyright 2025 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, is_scipy_available, logging
+from ..utils.torch_utils import randn_tensor
+from .scheduling_utils import SchedulerMixin
+
+
+if is_scipy_available():
+ import scipy.stats
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class FlowMatchLCMSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ """
+
+ prev_sample: torch.FloatTensor
+
+
+class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ LCM scheduler for Flow Matching.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ shift (`float`, defaults to 1.0):
+ The shift value for the timestep schedule.
+ use_dynamic_shifting (`bool`, defaults to False):
+ Whether to apply timestep shifting on-the-fly based on the image resolution.
+ base_shift (`float`, defaults to 0.5):
+ Value to stabilize image generation. Increasing `base_shift` reduces variation and image is more consistent
+ with desired output.
+ max_shift (`float`, defaults to 1.15):
+ Value change allowed to latent vectors. Increasing `max_shift` encourages more variation and image may be
+ more exaggerated or stylized.
+ base_image_seq_len (`int`, defaults to 256):
+ The base image sequence length.
+ max_image_seq_len (`int`, defaults to 4096):
+ The maximum image sequence length.
+ invert_sigmas (`bool`, defaults to False):
+ Whether to invert the sigmas.
+ shift_terminal (`float`, defaults to None):
+ The end value of the shifted timestep schedule.
+ use_karras_sigmas (`bool`, defaults to False):
+ Whether to use Karras sigmas for step sizes in the noise schedule during sampling.
+ use_exponential_sigmas (`bool`, defaults to False):
+ Whether to use exponential sigmas for step sizes in the noise schedule during sampling.
+ use_beta_sigmas (`bool`, defaults to False):
+ Whether to use beta sigmas for step sizes in the noise schedule during sampling.
+ time_shift_type (`str`, defaults to "exponential"):
+ The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
+ scale_factors ('list', defaults to None)
+ It defines how to scale the latents at which predictions are made.
+ upscale_mode ('str', defaults to 'bicubic')
+ Upscaling method, applied if scale-wise generation is considered
+ """
+
+ _compatibles = []
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ shift: float = 1.0,
+ use_dynamic_shifting: bool = False,
+ base_shift: Optional[float] = 0.5,
+ max_shift: Optional[float] = 1.15,
+ base_image_seq_len: Optional[int] = 256,
+ max_image_seq_len: Optional[int] = 4096,
+ invert_sigmas: bool = False,
+ shift_terminal: Optional[float] = None,
+ use_karras_sigmas: Optional[bool] = False,
+ use_exponential_sigmas: Optional[bool] = False,
+ use_beta_sigmas: Optional[bool] = False,
+ time_shift_type: str = "exponential",
+ scale_factors: Optional[List[float]] = None,
+ upscale_mode: Optional[str] = "bicubic",
+ ):
+ if self.config.use_beta_sigmas and not is_scipy_available():
+ raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+ if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+ raise ValueError(
+ "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+ )
+ if time_shift_type not in {"exponential", "linear"}:
+ raise ValueError("`time_shift_type` must either be 'exponential' or 'linear'.")
+
+ timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
+
+ sigmas = timesteps / num_train_timesteps
+ if not use_dynamic_shifting:
+ # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
+
+ self.timesteps = sigmas * num_train_timesteps
+
+ self._step_index = None
+ self._begin_index = None
+
+ self._shift = shift
+
+ self._init_size = None
+ self._scale_factors = scale_factors
+ self._upscale_mode = upscale_mode
+
+ self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
+ self.sigma_min = self.sigmas[-1].item()
+ self.sigma_max = self.sigmas[0].item()
+
+ @property
+ def shift(self):
+ """
+ The value used for shifting.
+ """
+ return self._shift
+
+ @property
+ def step_index(self):
+ """
+ The index counter for current timestep. It will increase 1 after each scheduler step.
+ """
+ return self._step_index
+
+ @property
+ def begin_index(self):
+ """
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+ """
+ return self._begin_index
+
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+ def set_begin_index(self, begin_index: int = 0):
+ """
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+ Args:
+ begin_index (`int`, defaults to `0`):
+ The begin index for the scheduler.
+ """
+ self._begin_index = begin_index
+
+ def set_shift(self, shift: float):
+ self._shift = shift
+
+ def set_scale_factors(self, scale_factors: list, upscale_mode):
+ """
+ Sets scale factors for a scale-wise generation regime.
+
+ Args:
+ scale_factors (`list`):
+ The scale factors for each step
+ upscale_mode (`str`):
+ Upscaling method
+ """
+ self._scale_factors = scale_factors
+ self._upscale_mode = upscale_mode
+
+ def scale_noise(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ noise: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ """
+ Forward process in flow-matching
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+
+ Returns:
+ `torch.FloatTensor`:
+ A scaled input sample.
+ """
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
+
+ if sample.device.type == "mps" and torch.is_floating_point(timestep):
+ # mps does not support float64
+ schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
+ timestep = timestep.to(sample.device, dtype=torch.float32)
+ else:
+ schedule_timesteps = self.timesteps.to(sample.device)
+ timestep = timestep.to(sample.device)
+
+ # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
+ if self.begin_index is None:
+ step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
+ elif self.step_index is not None:
+ # add_noise is called after first denoising step (for inpainting)
+ step_indices = [self.step_index] * timestep.shape[0]
+ else:
+ # add noise is called before first denoising step to create initial latent(img2img)
+ step_indices = [self.begin_index] * timestep.shape[0]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < len(sample.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ sample = sigma * noise + (1.0 - sigma) * sample
+
+ return sample
+
+ def _sigma_to_t(self, sigma):
+ return sigma * self.config.num_train_timesteps
+
+ def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
+ if self.config.time_shift_type == "exponential":
+ return self._time_shift_exponential(mu, sigma, t)
+ elif self.config.time_shift_type == "linear":
+ return self._time_shift_linear(mu, sigma, t)
+
+ def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
+ r"""
+ Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
+ value.
+
+ Reference:
+ https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
+
+ Args:
+ t (`torch.Tensor`):
+ A tensor of timesteps to be stretched and shifted.
+
+ Returns:
+ `torch.Tensor`:
+ A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
+ """
+ one_minus_z = 1 - t
+ scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
+ stretched_t = 1 - (one_minus_z / scale_factor)
+ return stretched_t
+
+ def set_timesteps(
+ self,
+ num_inference_steps: Optional[int] = None,
+ device: Union[str, torch.device] = None,
+ sigmas: Optional[List[float]] = None,
+ mu: Optional[float] = None,
+ timesteps: Optional[List[float]] = None,
+ ):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`, *optional*):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ sigmas (`List[float]`, *optional*):
+ Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
+ automatically.
+ mu (`float`, *optional*):
+ Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep
+ shifting.
+ timesteps (`List[float]`, *optional*):
+ Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed
+ automatically.
+ """
+ if self.config.use_dynamic_shifting and mu is None:
+ raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`")
+
+ if sigmas is not None and timesteps is not None:
+ if len(sigmas) != len(timesteps):
+ raise ValueError("`sigmas` and `timesteps` should have the same length")
+
+ if num_inference_steps is not None:
+ if (sigmas is not None and len(sigmas) != num_inference_steps) or (
+ timesteps is not None and len(timesteps) != num_inference_steps
+ ):
+ raise ValueError(
+ "`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided"
+ )
+ else:
+ num_inference_steps = len(sigmas) if sigmas is not None else len(timesteps)
+
+ self.num_inference_steps = num_inference_steps
+
+ # 1. Prepare default sigmas
+ is_timesteps_provided = timesteps is not None
+
+ if is_timesteps_provided:
+ timesteps = np.array(timesteps).astype(np.float32)
+
+ if sigmas is None:
+ if timesteps is None:
+ timesteps = np.linspace(
+ self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
+ )
+ sigmas = timesteps / self.config.num_train_timesteps
+ else:
+ sigmas = np.array(sigmas).astype(np.float32)
+ num_inference_steps = len(sigmas)
+
+ # 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of
+ # "exponential" or "linear" type is applied
+ if self.config.use_dynamic_shifting:
+ sigmas = self.time_shift(mu, 1.0, sigmas)
+ else:
+ sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
+
+ # 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value
+ if self.config.shift_terminal:
+ sigmas = self.stretch_shift_to_terminal(sigmas)
+
+ # 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules
+ if self.config.use_karras_sigmas:
+ sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+ elif self.config.use_exponential_sigmas:
+ sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+ elif self.config.use_beta_sigmas:
+ sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
+
+ # 5. Convert sigmas and timesteps to tensors and move to specified device
+ sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
+ if not is_timesteps_provided:
+ timesteps = sigmas * self.config.num_train_timesteps
+ else:
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
+
+ # 6. Append the terminal sigma value.
+ # If a model requires inverted sigma schedule for denoising but timesteps without inversion, the
+ # `invert_sigmas` flag can be set to `True`. This case is only required in Mochi
+ if self.config.invert_sigmas:
+ sigmas = 1.0 - sigmas
+ timesteps = sigmas * self.config.num_train_timesteps
+ sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
+ else:
+ sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+
+ self.timesteps = timesteps
+ self.sigmas = sigmas
+ self._step_index = None
+ self._begin_index = None
+
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
+ if schedule_timesteps is None:
+ schedule_timesteps = self.timesteps
+
+ indices = (schedule_timesteps == timestep).nonzero()
+
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ pos = 1 if len(indices) > 1 else 0
+
+ return indices[pos].item()
+
+ def _init_step_index(self, timestep):
+ if self.begin_index is None:
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ self._step_index = self.index_for_timestep(timestep)
+ else:
+ self._step_index = self._begin_index
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ sample: torch.FloatTensor,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[FlowMatchLCMSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`):
+ The direct output from learned diffusion model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ A current instance of a sample created by the diffusion process.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ return_dict (`bool`):
+ Whether or not to return a [`~schedulers.scheduling_flow_match_lcm.FlowMatchLCMSchedulerOutput`] or
+ tuple.
+
+ Returns:
+ [`~schedulers.scheduling_flow_match_lcm.FlowMatchLCMSchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_flow_match_lcm.FlowMatchLCMSchedulerOutput`] is
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
+ """
+
+ if (
+ isinstance(timestep, int)
+ or isinstance(timestep, torch.IntTensor)
+ or isinstance(timestep, torch.LongTensor)
+ ):
+ raise ValueError(
+ (
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+ " `FlowMatchLCMScheduler.step()` is not supported. Make sure to pass"
+ " one of the `scheduler.timesteps` as a timestep."
+ ),
+ )
+
+ if self._scale_factors and self._upscale_mode and len(self.timesteps) != len(self._scale_factors) + 1:
+ raise ValueError(
+ "`_scale_factors` should have the same length as `timesteps` - 1, if `_scale_factors` are set."
+ )
+
+ if self._init_size is None or self.step_index is None:
+ self._init_size = model_output.size()[2:]
+
+ if self.step_index is None:
+ self._init_step_index(timestep)
+
+ # Upcast to avoid precision issues when computing prev_sample
+ sample = sample.to(torch.float32)
+
+ sigma = self.sigmas[self.step_index]
+ sigma_next = self.sigmas[self.step_index + 1]
+ x0_pred = sample - sigma * model_output
+
+ if self._scale_factors and self._upscale_mode:
+ if self._step_index < len(self._scale_factors):
+ size = [round(self._scale_factors[self._step_index] * size) for size in self._init_size]
+ x0_pred = torch.nn.functional.interpolate(x0_pred, size=size, mode=self._upscale_mode)
+
+ noise = randn_tensor(x0_pred.shape, generator=generator, device=x0_pred.device, dtype=x0_pred.dtype)
+ prev_sample = (1 - sigma_next) * x0_pred + sigma_next * noise
+
+ # upon completion increase step index by one
+ self._step_index += 1
+ # Cast sample back to model compatible dtype
+ prev_sample = prev_sample.to(model_output.dtype)
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return FlowMatchLCMSchedulerOutput(prev_sample=prev_sample)
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
+
+ # Hack to make sure that other schedulers which copy this function don't break
+ # TODO: Add this logic to the other schedulers
+ if hasattr(self.config, "sigma_min"):
+ sigma_min = self.config.sigma_min
+ else:
+ sigma_min = None
+
+ if hasattr(self.config, "sigma_max"):
+ sigma_max = self.config.sigma_max
+ else:
+ sigma_max = None
+
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+ rho = 7.0 # 7.0 is the value used in the paper
+ ramp = np.linspace(0, 1, num_inference_steps)
+ min_inv_rho = sigma_min ** (1 / rho)
+ max_inv_rho = sigma_max ** (1 / rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ return sigmas
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
+ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
+
+ # Hack to make sure that other schedulers which copy this function don't break
+ # TODO: Add this logic to the other schedulers
+ if hasattr(self.config, "sigma_min"):
+ sigma_min = self.config.sigma_min
+ else:
+ sigma_min = None
+
+ if hasattr(self.config, "sigma_max"):
+ sigma_max = self.config.sigma_max
+ else:
+ sigma_max = None
+
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+ sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
+ return sigmas
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+ def _convert_to_beta(
+ self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+ ) -> torch.Tensor:
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
+
+ # Hack to make sure that other schedulers which copy this function don't break
+ # TODO: Add this logic to the other schedulers
+ if hasattr(self.config, "sigma_min"):
+ sigma_min = self.config.sigma_min
+ else:
+ sigma_min = None
+
+ if hasattr(self.config, "sigma_max"):
+ sigma_max = self.config.sigma_max
+ else:
+ sigma_max = None
+
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+ sigmas = np.array(
+ [
+ sigma_min + (ppf * (sigma_max - sigma_min))
+ for ppf in [
+ scipy.stats.beta.ppf(timestep, alpha, beta)
+ for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+ ]
+ ]
+ )
+ return sigmas
+
+ def _time_shift_exponential(self, mu, sigma, t):
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+ def _time_shift_linear(self, mu, sigma, t):
+ return mu / (mu + (1 / t - 1) ** sigma)
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py
index cb6cb9e79565..b113f9b49832 100644
--- a/src/diffusers/schedulers/scheduling_heun_discrete.py
+++ b/src/diffusers/schedulers/scheduling_heun_discrete.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
+# Copyright 2025 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -49,10 +49,10 @@ class HeunDiscreteSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -60,16 +60,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -106,15 +107,15 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
- beta_schedule (`str`, defaults to `"linear"`):
+ beta_schedule (`"linear"`, `"scaled_linear"`, `"squaredcos_cap_v2"`, or `"exp"`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
- `linear` or `scaled_linear`.
+ `linear`, `scaled_linear`, `squaredcos_cap_v2`, or `exp`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
- prediction_type (`str`, defaults to `epsilon`, *optional*):
+ prediction_type (`"epsilon"`, `"sample"`, or `"v_prediction"`, defaults to `"epsilon"`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
+ Video](https://huggingface.co/papers/2210.02303) paper).
clip_sample (`bool`, defaults to `True`):
Clip the predicted sample for numerical stability.
clip_sample_range (`float`, defaults to 1.0):
@@ -127,7 +128,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
- timestep_spacing (`str`, defaults to `"linspace"`):
+ timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0):
@@ -143,17 +144,17 @@ def __init__(
num_train_timesteps: int = 1000,
beta_start: float = 0.00085, # sensible defaults
beta_end: float = 0.012,
- beta_schedule: str = "linear",
+ beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2", "exp"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
- prediction_type: str = "epsilon",
+ prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
clip_sample: Optional[bool] = False,
clip_sample_range: float = 1.0,
- timestep_spacing: str = "linspace",
+ timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
steps_offset: int = 0,
- ):
+ ) -> None:
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
@@ -187,7 +188,23 @@ def __init__(
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -224,12 +241,12 @@ def begin_index(self):
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
- def set_begin_index(self, begin_index: int = 0):
+ def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -246,7 +263,7 @@ def scale_model_input(
Args:
sample (`torch.Tensor`):
The input sample.
- timestep (`int`, *optional*):
+ timestep (`float` or `torch.Tensor`):
The current timestep in the diffusion chain.
Returns:
@@ -266,19 +283,19 @@ def set_timesteps(
device: Union[str, torch.device] = None,
num_train_timesteps: Optional[int] = None,
timesteps: Optional[List[int]] = None,
- ):
+ ) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
- num_inference_steps (`int`):
+ num_inference_steps (`int`, *optional*, defaults to `None`):
The number of diffusion steps used when generating samples with a pre-trained model.
- device (`str` or `torch.device`, *optional*):
+ device (`str`, `torch.device`, *optional*, defaults to `None`):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
- num_train_timesteps (`int`, *optional*):
+ num_train_timesteps (`int`, *optional*, defaults to `None`):
The number of diffusion steps used when training the model. If `None`, the default
`num_train_timesteps` attribute is used.
- timesteps (`List[int]`, *optional*):
+ timesteps (`List[int]`, *optional*, defaults to `None`):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, timesteps will be
generated based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps`
must be `None`, and `timestep_spacing` attribute will be ignored.
@@ -301,7 +318,7 @@ def set_timesteps(
if timesteps is not None:
timesteps = np.array(timesteps, dtype=np.float32)
else:
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
elif self.config.timestep_spacing == "leading":
@@ -353,7 +370,20 @@ def set_timesteps(
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
- def _sigma_to_t(self, sigma, log_sigmas):
+ def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -377,8 +407,21 @@ def _sigma_to_t(self, sigma, log_sigmas):
return t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
- def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -404,7 +447,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -428,7 +483,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -461,7 +533,14 @@ def state_in_first_order(self):
return self.dt is None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -579,6 +658,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise tensor to add to the original samples.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to add noise, determining the noise level from the schedule.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples with added noise scaled according to the timestep schedule.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -606,5 +700,5 @@ def add_noise(
noisy_samples = original_samples + noise * sigma
return noisy_samples
- def __len__(self):
+ def __len__(self) -> int:
return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_ipndm.py b/src/diffusers/schedulers/scheduling_ipndm.py
index 28f349ae2114..da188fe8297c 100644
--- a/src/diffusers/schedulers/scheduling_ipndm.py
+++ b/src/diffusers/schedulers/scheduling_ipndm.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Zhejiang University Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Zhejiang University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -49,7 +49,7 @@ def __init__(
self.init_noise_sigma = 1.0
# For now we only support F-PNDM, i.e. the runge-kutta method
- # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
+ # For more information on the algorithm please take a look at the paper: https://huggingface.co/papers/2202.09778
# mainly at formula (9), (12), (13) and the Algorithm 2.
self.pndm_order = 4
@@ -78,7 +78,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -112,7 +112,23 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
self._begin_index = None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -127,7 +143,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
index 4b388b4d75b3..da40bed635e1 100644
--- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
+++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
+# Copyright 2025 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -50,10 +50,10 @@ class KDPM2AncestralDiscreteSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -61,16 +61,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -124,7 +125,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
+ Video](https://huggingface.co/papers/2210.02303) paper).
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -206,7 +207,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -260,7 +261,7 @@ def set_timesteps(
num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
elif self.config.timestep_spacing == "leading":
@@ -342,6 +343,19 @@ def set_timesteps(
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -366,7 +380,20 @@ def _sigma_to_t(self, sigma, log_sigmas):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -392,7 +419,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -416,7 +455,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -449,7 +505,23 @@ def state_in_first_order(self):
return self.sample is None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -464,7 +536,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -586,6 +665,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise tensor to add to the original samples.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to add noise, determining the noise level from the schedule.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples with added noise scaled according to the timestep schedule.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
index a2e564e70a0e..6dc08d4d0a86 100644
--- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
+++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
+# Copyright 2025 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -49,10 +49,10 @@ class KDPM2DiscreteSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -60,16 +60,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -123,7 +124,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
+ Video](https://huggingface.co/papers/2210.02303) paper).
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -206,7 +207,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -260,7 +261,7 @@ def set_timesteps(
num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
elif self.config.timestep_spacing == "leading":
@@ -330,7 +331,23 @@ def state_in_first_order(self):
return self.sample is None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -345,7 +362,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -355,6 +379,19 @@ def _init_step_index(self, timestep):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -379,7 +416,20 @@ def _sigma_to_t(self, sigma, log_sigmas):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -405,7 +455,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -429,7 +491,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -558,6 +637,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise tensor to add to the original samples.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to add noise, determining the noise level from the schedule.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples with added noise scaled according to the timestep schedule.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
diff --git a/src/diffusers/schedulers/scheduling_karras_ve_flax.py b/src/diffusers/schedulers/scheduling_karras_ve_flax.py
index 0d387b53ac3e..bacfbd61006d 100644
--- a/src/diffusers/schedulers/scheduling_karras_ve_flax.py
+++ b/src/diffusers/schedulers/scheduling_karras_ve_flax.py
@@ -1,4 +1,4 @@
-# Copyright 2024 NVIDIA and The HuggingFace Team. All rights reserved.
+# Copyright 2025 NVIDIA and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -63,8 +63,8 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
the VE column of Table 1 from [1] for reference.
[1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
- https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
- differential equations." https://arxiv.org/abs/2011.13456
+ https://huggingface.co/papers/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
+ differential equations." https://huggingface.co/papers/2011.13456
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
@@ -72,8 +72,8 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
[`~SchedulerMixin.from_pretrained`] functions.
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
- Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
- optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
+ Diffusion-Based Generative Models." https://huggingface.co/papers/2206.00364. The grid search values used to find
+ the optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
Args:
sigma_min (`float`): minimum noise magnitude
diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py
index 686b686f6870..0527f3533851 100644
--- a/src/diffusers/schedulers/scheduling_lcm.py
+++ b/src/diffusers/schedulers/scheduling_lcm.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -51,10 +51,10 @@ class LCMSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -62,16 +62,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -97,15 +98,15 @@ def alpha_bar_fn(t):
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
- Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
-
+ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
- the betas that the scheduler is being initialized with.
+ The betas that the scheduler is being initialized with.
Returns:
- `torch.Tensor`: rescaled betas with zero terminal SNR
+ `torch.Tensor`:
+ Rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
@@ -169,7 +170,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
+ Video](https://huggingface.co/papers/2210.02303) paper).
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
@@ -251,7 +252,23 @@ def __init__(
self._begin_index = None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -266,7 +283,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -291,7 +315,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -315,13 +339,23 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
- https://arxiv.org/abs/2205.11487
+ https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -413,8 +447,7 @@ def set_timesteps(
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`:"
- f" {self.config.num_train_timesteps}."
+ f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
)
# Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
@@ -598,6 +631,22 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+ diffusion process).
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps indicating the noise level for each sample.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
@@ -620,6 +669,21 @@ def add_noise(
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ """
+ Compute the velocity prediction from the sample and noise according to the velocity formula.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ noise (`torch.Tensor`):
+ The noise tensor.
+ timesteps (`torch.IntTensor`):
+ The timesteps for velocity computation.
+
+ Returns:
+ `torch.Tensor`:
+ The computed velocity.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
@@ -643,6 +707,17 @@ def __len__(self):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
def previous_timestep(self, timestep):
+ """
+ Compute the previous timestep in the diffusion chain.
+
+ Args:
+ timestep (`int`):
+ The current timestep.
+
+ Returns:
+ `int`:
+ The previous timestep.
+ """
if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py
index bcf9d9b59e11..276af6eeacb7 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Katherine Crowson and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
import math
import warnings
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import scipy.stats
@@ -47,10 +47,10 @@ class LMSDiscreteSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -58,16 +58,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -98,15 +99,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
methods the library implements for all schedulers such as loading and saving.
Args:
- num_train_timesteps (`int`, defaults to 1000):
+ num_train_timesteps (`int`, defaults to `1000`):
The number of diffusion steps to train the model.
- beta_start (`float`, defaults to 0.0001):
+ beta_start (`float`, defaults to `0.0001`):
The starting `beta` value of inference.
- beta_end (`float`, defaults to 0.02):
+ beta_end (`float`, defaults to `0.02`):
The final `beta` value.
- beta_schedule (`str`, defaults to `"linear"`):
- The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
- `linear` or `scaled_linear`.
+ beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
@@ -117,14 +117,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
- prediction_type (`str`, defaults to `epsilon`, *optional*):
+ prediction_type (`"epsilon"`, `"sample"`, or `"v_prediction"`, defaults to `"epsilon"`):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
- timestep_spacing (`str`, defaults to `"linspace"`):
+ Video](https://huggingface.co/papers/2210.02303) paper).
+ timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
- steps_offset (`int`, defaults to 0):
+ steps_offset (`int`, defaults to `0`):
An offset added to the inference steps, as required by some model families.
"""
@@ -137,13 +137,13 @@ def __init__(
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
- beta_schedule: str = "linear",
+ beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
- prediction_type: str = "epsilon",
- timestep_spacing: str = "linspace",
+ prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
+ timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
steps_offset: int = 0,
):
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
@@ -182,7 +182,15 @@ def __init__(
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
- def init_noise_sigma(self):
+ def init_noise_sigma(self) -> Union[float, torch.Tensor]:
+ """
+ The standard deviation of the initial noise distribution.
+
+ Returns:
+ `float` or `torch.Tensor`:
+ The standard deviation of the initial noise distribution, computed based on the maximum sigma value and
+ the timestep spacing configuration.
+ """
# standard deviation of the initial noise distribution
if self.config.timestep_spacing in ["linspace", "trailing"]:
return self.sigmas.max()
@@ -190,26 +198,34 @@ def init_noise_sigma(self):
return (self.sigmas.max() ** 2 + 1) ** 0.5
@property
- def step_index(self):
+ def step_index(self) -> Optional[int]:
"""
- The index counter for current timestep. It will increase 1 after each scheduler step.
+ The index counter for current timestep. It will increase by 1 after each scheduler step.
+
+ Returns:
+ `int` or `None`:
+ The current step index, or `None` if not initialized.
"""
return self._step_index
@property
- def begin_index(self):
+ def begin_index(self) -> Optional[int]:
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+
+ Returns:
+ `int` or `None`:
+ The begin index for the scheduler, or `None` if not set.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
- def set_begin_index(self, begin_index: int = 0):
+ def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -238,14 +254,21 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
self.is_scale_input_called = True
return sample
- def get_lms_coefficient(self, order, t, current_order):
+ def get_lms_coefficient(self, order: int, t: int, current_order: int) -> float:
"""
Compute the linear multistep coefficient.
Args:
- order ():
- t ():
- current_order ():
+ order (`int`):
+ The order of the linear multistep method.
+ t (`int`):
+ The current timestep index.
+ current_order (`int`):
+ The current order for which to compute the coefficient.
+
+ Returns:
+ `float`:
+ The computed linear multistep coefficient.
"""
def lms_derivative(tau):
@@ -260,7 +283,7 @@ def lms_derivative(tau):
return integrated_coeff
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -272,7 +295,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
"""
self.num_inference_steps = num_inference_steps
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[
::-1
@@ -319,7 +342,23 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
self.derivatives = []
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -334,7 +373,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -343,7 +389,20 @@ def _init_step_index(self, timestep):
self._step_index = self._begin_index
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
- def _sigma_to_t(self, sigma, log_sigmas):
+ def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -366,9 +425,19 @@ def _sigma_to_t(self, sigma, log_sigmas):
t = t.reshape(sigma.shape)
return t
- # copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
@@ -382,7 +451,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -406,7 +487,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -521,6 +619,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise tensor to add to the original samples.
+ timesteps (`torch.Tensor`):
+ The timesteps at which to add noise, determining the noise level from the schedule.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples with added noise scaled according to the timestep schedule.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -548,5 +661,5 @@ def add_noise(
noisy_samples = original_samples + noise * sigma
return noisy_samples
- def __len__(self):
+ def __len__(self) -> int:
return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
index f1169cc90a7b..3fd4dc8a5d61 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Katherine Crowson and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -77,7 +77,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
- https://imagen.research.google/video/paper.pdf)
+ https://huggingface.co/papers/2210.02303)
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
the `dtype` used for params and computation.
"""
diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py
index a05e71c3c225..651532b06ddb 100644
--- a/src/diffusers/schedulers/scheduling_pndm.py
+++ b/src/diffusers/schedulers/scheduling_pndm.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Zhejiang University Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Zhejiang University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -26,10 +26,10 @@
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -37,16 +37,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -78,15 +79,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
methods the library implements for all schedulers such as loading and saving.
Args:
- num_train_timesteps (`int`, defaults to 1000):
+ num_train_timesteps (`int`, defaults to `1000`):
The number of diffusion steps to train the model.
- beta_start (`float`, defaults to 0.0001):
+ beta_start (`float`, defaults to `0.0001`):
The starting `beta` value of inference.
- beta_end (`float`, defaults to 0.02):
+ beta_end (`float`, defaults to `0.02`):
The final `beta` value.
- beta_schedule (`str`, defaults to `"linear"`):
- The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
- `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
skip_prk_steps (`bool`, defaults to `False`):
@@ -96,14 +96,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the alpha value at step 0.
- prediction_type (`str`, defaults to `epsilon`, *optional*):
+ prediction_type (`"epsilon"` or `"v_prediction"`, defaults to `"epsilon"`):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process)
- or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf)
- paper).
- timestep_spacing (`str`, defaults to `"leading"`):
+ or `v_prediction` (see section 2.4 of [Imagen Video](https://huggingface.co/papers/2210.02303) paper).
+ timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"leading"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
- steps_offset (`int`, defaults to 0):
+ steps_offset (`int`, defaults to `0`):
An offset added to the inference steps, as required by some model families.
"""
@@ -116,12 +115,12 @@ def __init__(
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
- beta_schedule: str = "linear",
+ beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
skip_prk_steps: bool = False,
set_alpha_to_one: bool = False,
- prediction_type: str = "epsilon",
- timestep_spacing: str = "leading",
+ prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
+ timestep_spacing: Literal["linspace", "leading", "trailing"] = "leading",
steps_offset: int = 0,
):
if trained_betas is not None:
@@ -146,7 +145,7 @@ def __init__(
self.init_noise_sigma = 1.0
# For now we only support F-PNDM, i.e. the runge-kutta method
- # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
+ # For more information on the algorithm please take a look at the paper: https://huggingface.co/papers/2202.09778
# mainly at formula (9), (12), (13) and the Algorithm 2.
self.pndm_order = 4
@@ -163,7 +162,7 @@ def __init__(
self.plms_timesteps = None
self.timesteps = None
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -175,7 +174,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
"""
self.num_inference_steps = num_inference_steps
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if self.config.timestep_spacing == "linspace":
self._timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round().astype(np.int64)
@@ -242,7 +241,7 @@ def step(
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
- return_dict (`bool`):
+ return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns:
@@ -275,14 +274,13 @@ def step_prk(
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
- return_dict (`bool`):
+ return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
-
"""
if self.num_inference_steps is None:
raise ValueError(
@@ -334,14 +332,13 @@ def step_plms(
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
- return_dict (`bool`):
+ return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
-
"""
if self.num_inference_steps is None:
raise ValueError(
@@ -402,19 +399,27 @@ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tens
"""
return sample
- def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
- # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
- # this function computes x_(t−δ) using the formula of (9)
- # Note that x_t needs to be added to both sides of the equation
-
- # Notation ( ->
- # alpha_prod_t -> α_t
- # alpha_prod_t_prev -> α_(t−δ)
- # beta_prod_t -> (1 - α_t)
- # beta_prod_t_prev -> (1 - α_(t−δ))
- # sample -> x_t
- # model_output -> e_θ(x_t, t)
- # prev_sample -> x_(t−δ)
+ def _get_prev_sample(
+ self, sample: torch.Tensor, timestep: int, prev_timestep: int, model_output: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ Compute the previous sample x_(t-δ) from the current sample x_t using formula (9) from the [PNDM
+ paper](https://huggingface.co/papers/2202.09778).
+
+ Args:
+ sample (`torch.Tensor`):
+ The current sample x_t.
+ timestep (`int`):
+ The current timestep t.
+ prev_timestep (`int`):
+ The previous timestep (t-δ).
+ model_output (`torch.Tensor`):
+ The model output e_θ(x_t, t).
+
+ Returns:
+ `torch.Tensor`:
+ The previous sample x_(t-δ).
+ """
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
@@ -452,6 +457,22 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+ diffusion process).
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps indicating the noise level for each sample.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
@@ -472,5 +493,5 @@ def add_noise(
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
- def __len__(self):
+ def __len__(self) -> int:
return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py
index 3ac3ba5ca1ba..44bafccd5520 100644
--- a/src/diffusers/schedulers/scheduling_pndm_flax.py
+++ b/src/diffusers/schedulers/scheduling_pndm_flax.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Zhejiang University Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Zhejiang University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -80,7 +80,7 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
- For more details, see the original paper: https://arxiv.org/abs/2202.09778
+ For more details, see the original paper: https://huggingface.co/papers/2202.09778
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
@@ -103,7 +103,7 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
- https://imagen.research.google/video/paper.pdf)
+ https://huggingface.co/papers/2210.02303)
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
the `dtype` used for params and computation.
"""
@@ -134,7 +134,7 @@ def __init__(
self.dtype = dtype
# For now we only support F-PNDM, i.e. the runge-kutta method
- # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
+ # For more information on the algorithm please take a look at the paper: https://huggingface.co/papers/2202.09778
# mainly at formula (9), (12), (13) and the Algorithm 2.
self.pndm_order = 4
@@ -452,7 +452,7 @@ def step_plms(
return (prev_sample, state)
def _get_prev_sample(self, state: PNDMSchedulerState, sample, timestep, prev_timestep, model_output):
- # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
+ # See formula (9) of PNDM paper https://huggingface.co/papers/2202.09778
# this function computes x_(t−δ) using the formula of (9)
# Note that x_t needs to be added to both sides of the equation
diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py
index a14797b42f7a..a2eaf8eb3abd 100644
--- a/src/diffusers/schedulers/scheduling_repaint.py
+++ b/src/diffusers/schedulers/scheduling_repaint.py
@@ -1,4 +1,4 @@
-# Copyright 2024 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved.
+# Copyright 2025 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -45,10 +45,10 @@ class RePaintSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -56,16 +56,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -233,10 +234,10 @@ def _get_variance(self, t):
beta_prod_t_prev = 1 - alpha_prod_t_prev
# For t > 0, compute predicted variance βt (see formula (6) and (7) from
- # https://arxiv.org/pdf/2006.11239.pdf) and sample from it to get
+ # https://huggingface.co/papers/2006.11239) and sample from it to get
# previous sample x_{t-1} ~ N(pred_prev_sample, variance) == add
# variance to pred_sample
- # Is equivalent to formula (16) in https://arxiv.org/pdf/2010.02502.pdf
+ # Is equivalent to formula (16) in https://huggingface.co/papers/2010.02502
# without eta.
# variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
@@ -288,7 +289,7 @@ def step(
beta_prod_t = 1 - alpha_prod_t
# 2. compute predicted original sample from predicted noise also called
- # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ # "predicted x_0" of formula (15) from https://huggingface.co/papers/2006.11239
pred_original_sample = (sample - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
# 3. Clip "predicted x_0"
@@ -312,20 +313,20 @@ def step(
variance = std_dev_t * noise
# 6. compute "direction pointing to x_t" of formula (12)
- # from https://arxiv.org/pdf/2010.02502.pdf
+ # from https://huggingface.co/papers/2010.02502
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * model_output
- # 7. compute x_{t-1} of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # 7. compute x_{t-1} of formula (12) from https://huggingface.co/papers/2010.02502
prev_unknown_part = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + variance
- # 8. Algorithm 1 Line 5 https://arxiv.org/pdf/2201.09865.pdf
+ # 8. Algorithm 1 Line 5 https://huggingface.co/papers/2201.09865
# The computation reported in Algorithm 1 Line 5 is incorrect. Line 5 refers to formula (8a) of the same paper,
# which tells to sample from a Gaussian distribution with mean "(alpha_prod_t_prev**0.5) * original_image"
# and variance "(1 - alpha_prod_t_prev)". This means that the standard Gaussian distribution "noise" should be
# scaled by the square root of the variance (as it is done here), however Algorithm 1 Line 5 tells to scale by the variance.
prev_known_part = (alpha_prod_t_prev**0.5) * original_image + ((1 - alpha_prod_t_prev) ** 0.5) * noise
- # 9. Algorithm 1 Line 8 https://arxiv.org/pdf/2201.09865.pdf
+ # 9. Algorithm 1 Line 8 https://huggingface.co/papers/2201.09865
pred_prev_sample = mask * prev_known_part + (1.0 - mask) * prev_unknown_part
if not return_dict:
@@ -348,7 +349,7 @@ def undo_step(self, sample, timestep, generator=None):
else:
noise = randn_tensor(sample.shape, generator=generator, device=sample.device, dtype=sample.dtype)
- # 10. Algorithm 1 Line 10 https://arxiv.org/pdf/2201.09865.pdf
+ # 10. Algorithm 1 Line 10 https://huggingface.co/papers/2201.09865
sample = (1 - beta) ** 0.5 * sample + beta**0.5 * noise
return sample
diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py
index d45c93880bc5..5783e20de69d 100644
--- a/src/diffusers/schedulers/scheduling_sasolver.py
+++ b/src/diffusers/schedulers/scheduling_sasolver.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Shuchen Xue, etc. in University of Chinese Academy of Sciences Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Shuchen Xue, etc. in University of Chinese Academy of Sciences Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# DISCLAIMER: check https://arxiv.org/abs/2309.05019
+# DISCLAIMER: check https://huggingface.co/papers/2309.05019
# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
import math
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -33,10 +33,10 @@
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -44,16 +44,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -104,12 +105,12 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
+ Video](https://huggingface.co/papers/2210.02303) paper).
tau_func (`Callable`, *optional*):
Stochasticity during the sampling. Default in init is `lambda t: 1 if t >= 200 and t <= 800 else 0`.
SA-Solver will sample from vanilla diffusion ODE if tau_func is set to `lambda t: 0`. SA-Solver will sample
from vanilla diffusion SDE if tau_func is set to `lambda t: 1`. For more details, please check
- https://arxiv.org/abs/2309.05019
+ https://huggingface.co/papers/2309.05019
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
@@ -253,7 +254,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -273,7 +274,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item()
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64)
@@ -342,13 +343,23 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
- https://arxiv.org/abs/2205.11487
+ https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -375,6 +386,19 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -399,6 +423,17 @@ def _sigma_to_t(self, sigma, log_sigmas):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
+ """
+ Convert sigma values to alpha_t and sigma_t values.
+
+ Args:
+ sigma (`torch.Tensor`):
+ The sigma value(s) to convert.
+
+ Returns:
+ `Tuple[torch.Tensor, torch.Tensor]`:
+ A tuple containing (alpha_t, sigma_t) values.
+ """
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
@@ -410,7 +445,20 @@ def _sigma_to_alpha_sigma_t(self, sigma):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -436,7 +484,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -460,7 +520,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -500,12 +577,8 @@ def convert_model_output(
Noise_prediction is designed to discretize an integral of the noise prediction model, and data_prediction is
designed to discretize an integral of the data prediction model.
-
-
- The algorithm and model type are decoupled. You can use either data_prediction or noise_prediction for both
- noise prediction and data prediction models.
-
-
+ > [!TIP] > The algorithm and model type are decoupled. You can use either data_prediction or noise_prediction
+ for both > noise prediction and data prediction models.
Args:
model_output (`torch.Tensor`):
@@ -522,7 +595,7 @@ def convert_model_output(
if len(args) > 1:
sample = args[1]
else:
- raise ValueError("missing `sample` as a required keyward argument")
+ raise ValueError("missing `sample` as a required keyword argument")
if timestep is not None:
deprecate(
"timesteps",
@@ -812,22 +885,22 @@ def stochastic_adams_bashforth_update(
if len(args) > 1:
sample = args[1]
else:
- raise ValueError(" missing `sample` as a required keyward argument")
+ raise ValueError("missing `sample` as a required keyword argument")
if noise is None:
if len(args) > 2:
noise = args[2]
else:
- raise ValueError(" missing `noise` as a required keyward argument")
+ raise ValueError("missing `noise` as a required keyword argument")
if order is None:
if len(args) > 3:
order = args[3]
else:
- raise ValueError(" missing `order` as a required keyward argument")
+ raise ValueError("missing `order` as a required keyword argument")
if tau is None:
if len(args) > 4:
tau = args[4]
else:
- raise ValueError(" missing `tau` as a required keyward argument")
+ raise ValueError("missing `tau` as a required keyword argument")
if prev_timestep is not None:
deprecate(
"prev_timestep",
@@ -943,27 +1016,27 @@ def stochastic_adams_moulton_update(
if len(args) > 1:
last_sample = args[1]
else:
- raise ValueError(" missing`last_sample` as a required keyward argument")
+ raise ValueError("missing `last_sample` as a required keyword argument")
if last_noise is None:
if len(args) > 2:
last_noise = args[2]
else:
- raise ValueError(" missing`last_noise` as a required keyward argument")
+ raise ValueError("missing `last_noise` as a required keyword argument")
if this_sample is None:
if len(args) > 3:
this_sample = args[3]
else:
- raise ValueError(" missing`this_sample` as a required keyward argument")
+ raise ValueError("missing `this_sample` as a required keyword argument")
if order is None:
if len(args) > 4:
order = args[4]
else:
- raise ValueError(" missing`order` as a required keyward argument")
+ raise ValueError("missing `order` as a required keyword argument")
if tau is None:
if len(args) > 5:
tau = args[5]
else:
- raise ValueError(" missing`tau` as a required keyward argument")
+ raise ValueError("missing `tau` as a required keyword argument")
if this_timestep is not None:
deprecate(
"this_timestep",
@@ -1041,7 +1114,22 @@ def stochastic_adams_moulton_update(
return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index for a given timestep in the schedule.
+
+ Args:
+ timestep (`int` or `torch.Tensor`):
+ The timestep for which to find the index.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule.
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -1064,6 +1152,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
+
+ Args:
+ timestep (`int` or `torch.Tensor`):
+ The current timestep for which to initialize the step index.
"""
if self.begin_index is None:
@@ -1197,6 +1289,22 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+ diffusion process).
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps indicating the noise level for each sample.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
diff --git a/src/diffusers/schedulers/scheduling_scm.py b/src/diffusers/schedulers/scheduling_scm.py
index 23f47f42a302..7b01d886299c 100644
--- a/src/diffusers/schedulers/scheduling_scm.py
+++ b/src/diffusers/schedulers/scheduling_scm.py
@@ -1,4 +1,4 @@
-# # Copyright 2024 Sana-Sprint Authors and The HuggingFace Team. All rights reserved.
+# # Copyright 2025 Sana-Sprint Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -109,7 +109,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -168,13 +168,19 @@ def set_timesteps(
else:
# max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here
self.timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1, device=device).float()
- print(f"Set timesteps: {self.timesteps}")
self._step_index = None
self._begin_index = None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -183,7 +189,23 @@ def _init_step_index(self, timestep):
self._step_index = self._begin_index
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py
index cedfbf7d6ad5..1bfc08cce5e9 100644
--- a/src/diffusers/schedulers/scheduling_sde_ve.py
+++ b/src/diffusers/schedulers/scheduling_sde_ve.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Google Brain and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Google Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/src/diffusers/schedulers/scheduling_sde_ve_flax.py
index 0a8d45d4acbc..09cd081462b3 100644
--- a/src/diffusers/schedulers/scheduling_sde_ve_flax.py
+++ b/src/diffusers/schedulers/scheduling_sde_ve_flax.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Google Brain and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Google Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -61,7 +61,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
The variance exploding stochastic differential equation (SDE) scheduler.
- For more information, see the original paper: https://arxiv.org/abs/2011.13456
+ For more information, see the original paper: https://huggingface.co/papers/2011.13456
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
diff --git a/src/diffusers/schedulers/scheduling_tcd.py b/src/diffusers/schedulers/scheduling_tcd.py
index 5d60383142a4..7b4840ffdb19 100644
--- a/src/diffusers/schedulers/scheduling_tcd.py
+++ b/src/diffusers/schedulers/scheduling_tcd.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
import math
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -50,10 +50,10 @@ class TCDSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -61,16 +61,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -96,15 +97,15 @@ def alpha_bar_fn(t):
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
- Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
-
+ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
- the betas that the scheduler is being initialized with.
+ The betas that the scheduler is being initialized with.
Returns:
- `torch.Tensor`: rescaled betas with zero terminal SNR
+ `torch.Tensor`:
+ Rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
@@ -170,7 +171,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
+ Video](https://huggingface.co/papers/2210.02303) paper).
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
@@ -252,7 +253,23 @@ def __init__(
self._begin_index = None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[float, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index of a given timestep in the timestep schedule.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The timestep value to find in the schedule.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule. For the very first step, returns the second index if
+ multiple matches exist to avoid skipping a sigma when starting mid-schedule (e.g., for image-to-image).
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -267,7 +284,14 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return indices[pos].item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[float, torch.Tensor]) -> None:
+ """
+ Initialize the step index for the scheduler based on the given timestep.
+
+ Args:
+ timestep (`float` or `torch.Tensor`):
+ The current timestep to initialize the step index from.
+ """
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -292,7 +316,7 @@ def set_begin_index(self, begin_index: int = 0):
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
@@ -316,6 +340,24 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None
# Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler._get_variance
def _get_variance(self, timestep, prev_timestep):
+ """
+ Computes the variance of the noise added at a given diffusion step.
+
+ For a given `timestep` and its previous step, this method calculates the variance as defined in DDIM/DDPM
+ literature:
+ var_t = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+ where alpha_prod and beta_prod are cumulative products of alphas and betas, respectively.
+
+ Args:
+ timestep (`int`):
+ The current timestep in the diffusion process.
+ prev_timestep (`int`):
+ The previous timestep in the diffusion process. If negative, uses `final_alpha_cumprod`.
+
+ Returns:
+ `torch.Tensor`:
+ The variance for the current timestep.
+ """
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
@@ -328,13 +370,23 @@ def _get_variance(self, timestep, prev_timestep):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
- https://arxiv.org/abs/2205.11487
+ https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -431,8 +483,7 @@ def set_timesteps(
if timesteps[0] >= self.config.num_train_timesteps:
raise ValueError(
- f"`timesteps` must start before `self.config.train_timesteps`:"
- f" {self.config.num_train_timesteps}."
+ f"`timesteps` must start before `self.config.train_timesteps`: {self.config.num_train_timesteps}."
)
# Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
@@ -635,6 +686,22 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+ diffusion process).
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps indicating the noise level for each sample.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
@@ -657,6 +724,21 @@ def add_noise(
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
+ """
+ Compute the velocity prediction from the sample and noise according to the velocity formula.
+
+ Args:
+ sample (`torch.Tensor`):
+ The input sample.
+ noise (`torch.Tensor`):
+ The noise tensor.
+ timesteps (`torch.IntTensor`):
+ The timesteps for velocity computation.
+
+ Returns:
+ `torch.Tensor`:
+ The computed velocity.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as sample
self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
@@ -680,6 +762,17 @@ def __len__(self):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
def previous_timestep(self, timestep):
+ """
+ Compute the previous timestep in the diffusion chain.
+
+ Args:
+ timestep (`int`):
+ The current timestep.
+
+ Returns:
+ `int`:
+ The previous timestep.
+ """
if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
diff --git a/src/diffusers/schedulers/scheduling_unclip.py b/src/diffusers/schedulers/scheduling_unclip.py
index 22a53b0e73b6..5a978dec649b 100644
--- a/src/diffusers/schedulers/scheduling_unclip.py
+++ b/src/diffusers/schedulers/scheduling_unclip.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Kakao Brain and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Kakao Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -46,10 +46,10 @@ class UnCLIPSchedulerOutput(BaseOutput):
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -57,16 +57,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -191,7 +192,7 @@ def _get_variance(self, t, prev_timestep=None, predicted_variance=None, variance
else:
beta = 1 - alpha_prod_t / alpha_prod_t_prev
- # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
+ # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://huggingface.co/papers/2006.11239)
# and sample from it to get previous sample
# x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
variance = beta_prod_t_prev / beta_prod_t * beta
@@ -266,7 +267,7 @@ def step(
alpha = 1 - beta
# 2. compute predicted original sample from predicted noise also called
- # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ # "predicted x_0" of formula (15) from https://huggingface.co/papers/2006.11239
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.config.prediction_type == "sample":
@@ -284,12 +285,12 @@ def step(
)
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
- # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ # See formula (7) from https://huggingface.co/papers/2006.11239
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * beta) / beta_prod_t
current_sample_coeff = alpha ** (0.5) * beta_prod_t_prev / beta_prod_t
# 5. Compute predicted previous sample µ_t
- # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ # See formula (7) from https://huggingface.co/papers/2006.11239
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
# 6. Add noise
@@ -334,6 +335,22 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise magnitude at each timestep (this is the forward
+ diffusion process).
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples to which noise will be added.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps indicating the noise level for each sample.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
# for the subsequent add_noise calls
diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py
index 01500426305c..689c6a06350b 100644
--- a/src/diffusers/schedulers/scheduling_unipc_multistep.py
+++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -1,4 +1,4 @@
-# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# DISCLAIMER: check https://arxiv.org/abs/2302.04867 and https://github.com/wl-zhao/UniPC for more info
+# DISCLAIMER: check https://huggingface.co/papers/2302.04867 and https://github.com/wl-zhao/UniPC for more info
# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
import math
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -32,10 +32,10 @@
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
- num_diffusion_timesteps,
- max_beta=0.999,
- alpha_transform_type="cosine",
-):
+ num_diffusion_timesteps: int,
+ max_beta: float = 0.999,
+ alpha_transform_type: Literal["cosine", "exp"] = "cosine",
+) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
@@ -43,16 +43,17 @@ def betas_for_alpha_bar(
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
-
Args:
- num_diffusion_timesteps (`int`): the number of betas to produce.
- max_beta (`float`): the maximum beta to use; use values lower than 1 to
- prevent singularities.
- alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
- Choose from `cosine` or `exp`
+ num_diffusion_timesteps (`int`):
+ The number of betas to produce.
+ max_beta (`float`, defaults to `0.999`):
+ The maximum beta to use; use values lower than 1 to avoid numerical instability.
+ alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
+ The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
- betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ `torch.Tensor`:
+ The betas used by the scheduler to step the model outputs.
"""
if alpha_transform_type == "cosine":
@@ -76,17 +77,17 @@ def alpha_bar_fn(t):
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
-def rescale_zero_terminal_snr(betas):
+def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
- Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
-
+ Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
betas (`torch.Tensor`):
- the betas that the scheduler is being initialized with.
+ The betas that the scheduler is being initialized with.
Returns:
- `torch.Tensor`: rescaled betas with zero terminal SNR
+ `torch.Tensor`:
+ Rescaled betas with zero terminal SNR.
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
@@ -126,19 +127,19 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
- beta_schedule (`str`, defaults to `"linear"`):
+ beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
- solver_order (`int`, default `2`):
+ solver_order (`int`, defaults to `2`):
The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1`
due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for
unconditional sampling.
- prediction_type (`str`, defaults to `epsilon`, *optional*):
+ prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
- `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
- Video](https://imagen.research.google/video/paper.pdf) paper).
+ `sample` (directly predicts the noisy sample`), `v_prediction` (see section 2.4 of [Imagen
+ Video](https://huggingface.co/papers/2210.02303) paper), or `flow_prediction`.
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
@@ -148,7 +149,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`.
predict_x0 (`bool`, defaults to `True`):
Whether to use the updating algorithm on the predicted x0.
- solver_type (`str`, default `bh2`):
+ solver_type (`"bh1"` or `"bh2"`, defaults to `"bh2"`):
Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2`
otherwise.
lower_order_final (`bool`, default `True`):
@@ -168,12 +169,14 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
- timestep_spacing (`str`, defaults to `"linspace"`):
+ use_flow_sigmas (`bool`, *optional*, defaults to `False`):
+ Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
+ timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0):
An offset added to the inference steps, as required by some model families.
- final_sigmas_type (`str`, defaults to `"zero"`):
+ final_sigmas_type (`"zero"` or `"sigma_min"`, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
rescale_betas_zero_snr (`bool`, defaults to `False`):
@@ -191,28 +194,30 @@ def __init__(
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
- beta_schedule: str = "linear",
+ beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
solver_order: int = 2,
- prediction_type: str = "epsilon",
+ prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
predict_x0: bool = True,
- solver_type: str = "bh2",
+ solver_type: Literal["bh1", "bh2"] = "bh2",
lower_order_final: bool = True,
disable_corrector: List[int] = [],
- solver_p: SchedulerMixin = None,
+ solver_p: Optional[SchedulerMixin] = None,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
use_flow_sigmas: Optional[bool] = False,
flow_shift: Optional[float] = 1.0,
- timestep_spacing: str = "linspace",
+ timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
steps_offset: int = 0,
- final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
+ final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero",
rescale_betas_zero_snr: bool = False,
- ):
+ use_dynamic_shifting: bool = False,
+ time_shift_type: Literal["exponential"] = "exponential",
+ ) -> None:
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
@@ -274,31 +279,33 @@ def __init__(
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
- def step_index(self):
+ def step_index(self) -> Optional[int]:
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
- def begin_index(self):
+ def begin_index(self) -> Optional[int]:
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
- def set_begin_index(self, begin_index: int = 0):
+ def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
- begin_index (`int`):
+ begin_index (`int`, defaults to `0`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ def set_timesteps(
+ self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None, mu: Optional[float] = None
+ ) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -307,8 +314,13 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ mu (`float`, *optional*):
+ Optional mu parameter for dynamic shifting when using exponential time shift type.
"""
- # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
+ if mu is not None:
+ assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+ self.config.flow_shift = np.exp(mu)
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
@@ -423,13 +435,23 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""
+ Apply dynamic thresholding to the predicted sample.
+
"Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
pixels from saturation at each step. We find that dynamic thresholding results in significantly better
photorealism as well as better image-text alignment, especially when using very large guidance weights."
- https://arxiv.org/abs/2205.11487
+ https://huggingface.co/papers/2205.11487
+
+ Args:
+ sample (`torch.Tensor`):
+ The predicted sample to be thresholded.
+
+ Returns:
+ `torch.Tensor`:
+ The thresholded sample.
"""
dtype = sample.dtype
batch_size, channels, *remaining_dims = sample.shape
@@ -455,7 +477,20 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
return sample
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
- def _sigma_to_t(self, sigma, log_sigmas):
+ def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
+ """
+ Convert sigma values to corresponding timestep values through interpolation.
+
+ Args:
+ sigma (`np.ndarray`):
+ The sigma value(s) to convert to timestep(s).
+ log_sigmas (`np.ndarray`):
+ The logarithm of the sigma schedule used for interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The interpolated timestep value(s) corresponding to the input sigma(s).
+ """
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -479,7 +514,18 @@ def _sigma_to_t(self, sigma, log_sigmas):
return t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
- def _sigma_to_alpha_sigma_t(self, sigma):
+ def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Convert sigma values to alpha_t and sigma_t values.
+
+ Args:
+ sigma (`torch.Tensor`):
+ The sigma value(s) to convert.
+
+ Returns:
+ `Tuple[torch.Tensor, torch.Tensor]`:
+ A tuple containing (alpha_t, sigma_t) values.
+ """
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
@@ -490,8 +536,21 @@ def _sigma_to_alpha_sigma_t(self, sigma):
return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
- def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
- """Constructs the noise schedule of Karras et al. (2022)."""
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+ """
+ Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+ Models](https://huggingface.co/papers/2206.00364).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following the Karras noise schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -517,7 +576,19 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
- """Constructs an exponential noise schedule."""
+ """
+ Construct an exponential noise schedule.
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following an exponential schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -541,7 +612,24 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
- """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
+ """
+ Construct a beta noise schedule as proposed in [Beta Sampling is All You
+ Need](https://huggingface.co/papers/2407.12173).
+
+ Args:
+ in_sigmas (`torch.Tensor`):
+ The input sigma values to be converted.
+ num_inference_steps (`int`):
+ The number of inference steps to generate the noise schedule for.
+ alpha (`float`, *optional*, defaults to `0.6`):
+ The alpha parameter for the beta distribution.
+ beta (`float`, *optional*, defaults to `0.6`):
+ The beta parameter for the beta distribution.
+
+ Returns:
+ `torch.Tensor`:
+ The converted sigma values following a beta distribution schedule.
+ """
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
@@ -596,7 +684,7 @@ def convert_model_output(
if len(args) > 1:
sample = args[1]
else:
- raise ValueError("missing `sample` as a required keyward argument")
+ raise ValueError("missing `sample` as a required keyword argument")
if timestep is not None:
deprecate(
"timesteps",
@@ -672,12 +760,12 @@ def multistep_uni_p_bh_update(
if len(args) > 1:
sample = args[1]
else:
- raise ValueError(" missing `sample` as a required keyward argument")
+ raise ValueError("missing `sample` as a required keyword argument")
if order is None:
if len(args) > 2:
order = args[2]
else:
- raise ValueError(" missing `order` as a required keyward argument")
+ raise ValueError("missing `order` as a required keyword argument")
if prev_timestep is not None:
deprecate(
"prev_timestep",
@@ -804,17 +892,17 @@ def multistep_uni_c_bh_update(
if len(args) > 1:
last_sample = args[1]
else:
- raise ValueError(" missing`last_sample` as a required keyward argument")
+ raise ValueError("missing `last_sample` as a required keyword argument")
if this_sample is None:
if len(args) > 2:
this_sample = args[2]
else:
- raise ValueError(" missing`this_sample` as a required keyward argument")
+ raise ValueError("missing `this_sample` as a required keyword argument")
if order is None:
if len(args) > 3:
order = args[3]
else:
- raise ValueError(" missing`order` as a required keyward argument")
+ raise ValueError("missing `order` as a required keyword argument")
if this_timestep is not None:
deprecate(
"this_timestep",
@@ -909,7 +997,22 @@ def multistep_uni_c_bh_update(
return x_t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
- def index_for_timestep(self, timestep, schedule_timesteps=None):
+ def index_for_timestep(
+ self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+ ) -> int:
+ """
+ Find the index for a given timestep in the schedule.
+
+ Args:
+ timestep (`int` or `torch.Tensor`):
+ The timestep for which to find the index.
+ schedule_timesteps (`torch.Tensor`, *optional*):
+ The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+ Returns:
+ `int`:
+ The index of the timestep in the schedule.
+ """
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -929,9 +1032,13 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
return step_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
- def _init_step_index(self, timestep):
+ def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
"""
Initialize the step_index counter for the scheduler.
+
+ Args:
+ timestep (`int` or `torch.Tensor`):
+ The current timestep for which to initialize the step index.
"""
if self.begin_index is None:
@@ -955,11 +1062,11 @@ def step(
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
- timestep (`int`):
+ timestep (`int` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
- return_dict (`bool`):
+ return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns:
@@ -1044,6 +1151,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
+ """
+ Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+ Args:
+ original_samples (`torch.Tensor`):
+ The original samples without noise.
+ noise (`torch.Tensor`):
+ The noise to add to the samples.
+ timesteps (`torch.IntTensor`):
+ The timesteps at which to add noise to the samples.
+
+ Returns:
+ `torch.Tensor`:
+ The noisy samples.
+ """
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -1072,5 +1194,5 @@ def add_noise(
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples
- def __len__(self):
+ def __len__(self) -> int:
return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py
index 83f31b72c10b..a355c7bb1a51 100644
--- a/src/diffusers/schedulers/scheduling_utils.py
+++ b/src/diffusers/schedulers/scheduling_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -138,15 +138,11 @@ def from_pretrained(
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
-
-
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
- `huggingface-cli login`. You can also activate the special
- ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
+ > [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
+ with `hf > auth login`. You can also activate the special >
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a >
firewalled environment.
-
-
"""
config, kwargs, commit_hash = cls.load_config(
pretrained_model_name_or_path=pretrained_model_name_or_path,
diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py
index ae11baf9ea1b..0534e47d8a30 100644
--- a/src/diffusers/schedulers/scheduling_utils_flax.py
+++ b/src/diffusers/schedulers/scheduling_utils_flax.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,9 +22,11 @@
import jax.numpy as jnp
from huggingface_hub.utils import validate_hf_hub_args
-from ..utils import BaseOutput, PushToHubMixin
+from ..utils import BaseOutput, PushToHubMixin, logging
+logger = logging.get_logger(__name__)
+
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
@@ -118,21 +120,18 @@ def from_pretrained(
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
-
-
- It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
- models](https://huggingface.co/docs/hub/models-gated#gated-models).
-
-
+ > [!TIP] > It is required to be logged in (`hf auth login`) when you want to use private or [gated >
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
-
-
- Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
- use this method in a firewalled environment.
-
-
+ > [!TIP] > Activate the special
+ ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to > use this method in a
+ firewalled environment.
"""
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
config, kwargs = cls.load_config(
pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder,
diff --git a/src/diffusers/schedulers/scheduling_vq_diffusion.py b/src/diffusers/schedulers/scheduling_vq_diffusion.py
index bd8d255fa901..57306301d023 100644
--- a/src/diffusers/schedulers/scheduling_vq_diffusion.py
+++ b/src/diffusers/schedulers/scheduling_vq_diffusion.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Microsoft and The HuggingFace Team. All rights reserved.
+# Copyright 2025 Microsoft and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index c570bac733db..7a98fa3da14a 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -3,12 +3,16 @@
import gc
import math
import random
+import re
+import warnings
+from contextlib import contextmanager
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
from .models import UNet2DConditionModel
+from .pipelines import DiffusionPipeline
from .schedulers import SchedulerMixin
from .utils import (
convert_state_dict_to_diffusers,
@@ -149,9 +153,9 @@ def compute_dream_and_update_latents(
dream_detail_preservation: float = 1.0,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
- Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from http://arxiv.org/abs/2312.00210.
- DREAM helps align training with sampling to help training be more efficient and accurate at the cost of an extra
- forward step without gradients.
+ Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from
+ https://huggingface.co/papers/2312.00210. DREAM helps align training with sampling to help training be more
+ efficient and accurate at the cost of an extra forward step without gradients.
Args:
`unet`: The state unet to use to make a prediction.
@@ -241,12 +245,20 @@ def _set_state_dict_into_text_encoder(
"""
text_encoder_state_dict = {
- f'{k.replace(prefix, "")}': v for k, v in lora_state_dict.items() if k.startswith(prefix)
+ f"{k.replace(prefix, '')}": v for k, v in lora_state_dict.items() if k.startswith(prefix)
}
text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict))
set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
+def _collate_lora_metadata(modules_to_save: Dict[str, torch.nn.Module]) -> Dict[str, Any]:
+ metadatas = {}
+ for module_name, module in modules_to_save.items():
+ if module is not None:
+ metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
+ return metadatas
+
+
def compute_density_for_timestep_sampling(
weighting_scheme: str,
batch_size: int,
@@ -261,7 +273,7 @@ def compute_density_for_timestep_sampling(
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
- SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
+ SD3 paper reference: https://huggingface.co/papers/2403.03206v1.
"""
if weighting_scheme == "logit_normal":
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator)
@@ -280,7 +292,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
- SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
+ SD3 paper reference: https://huggingface.co/papers/2403.03206v1.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
@@ -308,6 +320,80 @@ def free_memory():
torch.xpu.empty_cache()
+@contextmanager
+def offload_models(
+ *modules: Union[torch.nn.Module, DiffusionPipeline], device: Union[str, torch.device], offload: bool = True
+):
+ """
+ Context manager that, if offload=True, moves each module to `device` on enter, then moves it back to its original
+ device on exit.
+
+ Args:
+ device (`str` or `torch.Device`): Device to move the `modules` to.
+ offload (`bool`): Flag to enable offloading.
+ """
+ if offload:
+ is_model = not any(isinstance(m, DiffusionPipeline) for m in modules)
+ # record where each module was
+ if is_model:
+ original_devices = [next(m.parameters()).device for m in modules]
+ else:
+ assert len(modules) == 1
+ # For DiffusionPipeline, wrap the device in a list to make it iterable
+ original_devices = [modules[0].device]
+ # move to target device
+ for m in modules:
+ m.to(device)
+
+ try:
+ yield
+ finally:
+ if offload:
+ # move back to original devices
+ for m, orig_dev in zip(modules, original_devices):
+ m.to(orig_dev)
+
+
+def parse_buckets_string(buckets_str):
+ """Parses a string defining buckets into a list of (height, width) tuples."""
+ if not buckets_str:
+ raise ValueError("Bucket string cannot be empty.")
+
+ bucket_pairs = buckets_str.strip().split(";")
+ parsed_buckets = []
+ for pair_str in bucket_pairs:
+ match = re.match(r"^\s*(\d+)\s*,\s*(\d+)\s*$", pair_str)
+ if not match:
+ raise ValueError(f"Invalid bucket format: '{pair_str}'. Expected 'height,width'.")
+ try:
+ height = int(match.group(1))
+ width = int(match.group(2))
+ if height <= 0 or width <= 0:
+ raise ValueError("Bucket dimensions must be positive integers.")
+ if height % 8 != 0 or width % 8 != 0:
+ warnings.warn(f"Bucket dimension ({height},{width}) not divisible by 8. This might cause issues.")
+ parsed_buckets.append((height, width))
+ except ValueError as e:
+ raise ValueError(f"Invalid integer in bucket pair '{pair_str}': {e}") from e
+
+ if not parsed_buckets:
+ raise ValueError("No valid buckets found in the provided string.")
+
+ return parsed_buckets
+
+
+def find_nearest_bucket(h, w, bucket_options):
+ """Finds the closes bucket to the given height and width."""
+ min_metric = float("inf")
+ best_bucket_idx = None
+ for bucket_idx, (bucket_h, bucket_w) in enumerate(bucket_options):
+ metric = abs(h * bucket_w - w * bucket_h)
+ if metric <= min_metric:
+ min_metric = metric
+ best_bucket_idx = bucket_idx
+ return best_bucket_idx
+
+
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
"""
@@ -583,7 +669,7 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
"""
if self.temp_stored_params is None:
- raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
+ raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
if self.foreach:
torch._foreach_copy_(
[param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params]
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 50a470772772..6884d3be9292 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -20,10 +20,12 @@
from .. import __version__
from .constants import (
CONFIG_NAME,
+ DEFAULT_HF_PARALLEL_LOADING_WORKERS,
DEPRECATED_REVISION_ARGS,
DIFFUSERS_DYNAMIC_MODULE_NAME,
FLAX_WEIGHTS_NAME,
GGUF_FILE_EXTENSION,
+ HF_ENABLE_PARALLEL_LOADING,
HF_MODULES_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
MIN_PEFT_VERSION,
@@ -36,7 +38,7 @@
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
)
-from .deprecation_utils import deprecate
+from .deprecation_utils import _maybe_remap_transformers_class, deprecate
from .doc_utils import replace_example_docstring
from .dynamic_modules_utils import get_class_from_dynamic_module
from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video
@@ -62,33 +64,51 @@
get_objects_from_module,
is_accelerate_available,
is_accelerate_version,
+ is_aiter_available,
+ is_aiter_version,
+ is_better_profanity_available,
is_bitsandbytes_available,
is_bitsandbytes_version,
is_bs4_available,
+ is_cosmos_guardrail_available,
+ is_flash_attn_3_available,
+ is_flash_attn_available,
+ is_flash_attn_version,
is_flax_available,
is_ftfy_available,
is_gguf_available,
is_gguf_version,
is_google_colab,
is_hf_hub_version,
+ is_hpu_available,
is_inflect_available,
is_invisible_watermark_available,
is_k_diffusion_available,
is_k_diffusion_version,
+ is_kernels_available,
+ is_kornia_available,
is_librosa_available,
is_matplotlib_available,
+ is_nltk_available,
is_note_seq_available,
+ is_nvidia_modelopt_available,
+ is_nvidia_modelopt_version,
is_onnx_available,
+ is_opencv_available,
is_optimum_quanto_available,
is_optimum_quanto_version,
is_peft_available,
is_peft_version,
+ is_pytorch_retinaface_available,
is_safetensors_available,
+ is_sageattention_available,
+ is_sageattention_version,
is_scipy_available,
is_sentencepiece_available,
is_tensorboard_available,
is_timm_available,
is_torch_available,
+ is_torch_mlu_available,
is_torch_npu_available,
is_torch_version,
is_torch_xla_available,
@@ -102,6 +122,7 @@
is_unidecode_available,
is_wandb_available,
is_xformers_available,
+ is_xformers_version,
requires_backends,
)
from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video
@@ -126,6 +147,7 @@
convert_state_dict_to_kohya,
convert_state_dict_to_peft,
convert_unet_state_dict_to_peft,
+ state_dict_all_zero,
)
from .typing_utils import _get_detailed_type, _is_valid_type
diff --git a/src/diffusers/utils/accelerate_utils.py b/src/diffusers/utils/accelerate_utils.py
index 99a8b3a47c25..af3b712b5a9e 100644
--- a/src/diffusers/utils/accelerate_utils.py
+++ b/src/diffusers/utils/accelerate_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index fa12318f4714..c46fa4363483 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -40,6 +40,12 @@
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
+DIFFUSERS_REQUEST_TIMEOUT = 60
+DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
+DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0").upper() in ENV_VARS_TRUE_VALUES
+DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
+HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
+DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
diff --git a/src/diffusers/utils/deprecation_utils.py b/src/diffusers/utils/deprecation_utils.py
index f482deddd2f4..d76623541b9f 100644
--- a/src/diffusers/utils/deprecation_utils.py
+++ b/src/diffusers/utils/deprecation_utils.py
@@ -4,6 +4,54 @@
from packaging import version
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+# Mapping for deprecated Transformers classes to their replacements
+# This is used to handle models that reference deprecated class names in their configs
+# Reference: https://github.com/huggingface/transformers/issues/40822
+# Format: {
+# "DeprecatedClassName": {
+# "new_class": "NewClassName",
+# "transformers_version": (">=", "5.0.0"), # (operation, version) tuple
+# }
+# }
+_TRANSFORMERS_CLASS_REMAPPING = {
+ "CLIPFeatureExtractor": {
+ "new_class": "CLIPImageProcessor",
+ "transformers_version": (">", "4.57.0"),
+ },
+}
+
+
+def _maybe_remap_transformers_class(class_name: str) -> Optional[str]:
+ """
+ Check if a Transformers class should be remapped to a newer version.
+
+ Args:
+ class_name: The name of the class to check
+
+ Returns:
+ The new class name if remapping should occur, None otherwise
+ """
+ if class_name not in _TRANSFORMERS_CLASS_REMAPPING:
+ return None
+
+ from .import_utils import is_transformers_version
+
+ mapping = _TRANSFORMERS_CLASS_REMAPPING[class_name]
+ operation, required_version = mapping["transformers_version"]
+
+ # Only remap if the transformers version meets the requirement
+ if is_transformers_version(operation, required_version):
+ new_class = mapping["new_class"]
+ logger.warning(f"{class_name} appears to have been deprecated in transformers. Using {new_class} instead.")
+ return mapping["new_class"]
+
+ return None
+
def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
from .. import __version__
@@ -40,7 +88,7 @@ def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn
line_number = call_frame.lineno
function = call_frame.function
key, value = next(iter(deprecated_kwargs.items()))
- raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`")
+ raise TypeError(f"{function} in {filename} line {line_number - 1} got an unexpected keyword argument `{key}`")
if len(values) == 0:
return
diff --git a/src/diffusers/utils/doc_utils.py b/src/diffusers/utils/doc_utils.py
index fe633e683642..815083e14258 100644
--- a/src/diffusers/utils/doc_utils.py
+++ b/src/diffusers/utils/doc_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/utils/dummy_nvidia_modelopt_objects.py b/src/diffusers/utils/dummy_nvidia_modelopt_objects.py
new file mode 100644
index 000000000000..046b28223b3d
--- /dev/null
+++ b/src/diffusers/utils/dummy_nvidia_modelopt_objects.py
@@ -0,0 +1,17 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class NVIDIAModelOptConfig(metaclass=DummyObject):
+ _backends = ["nvidia_modelopt"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["nvidia_modelopt"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["nvidia_modelopt"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["nvidia_modelopt"])
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 6edbd737e32c..8628893200fe 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -2,6 +2,171 @@
from ..utils import DummyObject, requires_backends
+class AdaptiveProjectedGuidance(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AdaptiveProjectedMixGuidance(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoGuidance(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class BaseGuidance(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class ClassifierFreeGuidance(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class ClassifierFreeZeroStarGuidance(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class FrequencyDecoupledGuidance(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class PerturbedAttentionGuidance(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class SkipLayerGuidance(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class SmoothedEnergyGuidance(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class TangentialClassifierFreeGuidance(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class FasterCacheConfig(metaclass=DummyObject):
_backends = ["torch"]
@@ -17,7 +182,477 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class HookRegistry(metaclass=DummyObject):
+class FirstBlockCacheConfig(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class HookRegistry(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class LayerSkipConfig(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class SmoothedEnergyGuidanceConfig(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class TaylorSeerCacheConfig(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+def apply_faster_cache(*args, **kwargs):
+ requires_backends(apply_faster_cache, ["torch"])
+
+
+def apply_first_block_cache(*args, **kwargs):
+ requires_backends(apply_first_block_cache, ["torch"])
+
+
+def apply_layer_skip(*args, **kwargs):
+ requires_backends(apply_layer_skip, ["torch"])
+
+
+def apply_pyramid_attention_broadcast(*args, **kwargs):
+ requires_backends(apply_pyramid_attention_broadcast, ["torch"])
+
+
+def apply_taylorseer_cache(*args, **kwargs):
+ requires_backends(apply_taylorseer_cache, ["torch"])
+
+
+class AllegroTransformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AsymmetricAutoencoderKL(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AttentionBackendName(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AuraFlowTransformer2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderDC(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderKL(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderKLAllegro(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderKLCogVideoX(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderKLCosmos(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderKLFlux2(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderKLHunyuanImage(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderKLHunyuanImageRefiner(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderKLHunyuanVideo(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderKLHunyuanVideo15(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderKLLTXVideo(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderKLMagvit(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderKLMochi(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderKLQwenImage(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderKLTemporalDecoder(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderKLWan(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderOobleck(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderTiny(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class BriaFiboTransformer2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class BriaTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -32,7 +667,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
+class CacheMixin(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -47,15 +682,22 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-def apply_faster_cache(*args, **kwargs):
- requires_backends(apply_faster_cache, ["torch"])
+class ChromaTransformer2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
-def apply_pyramid_attention_broadcast(*args, **kwargs):
- requires_backends(apply_pyramid_attention_broadcast, ["torch"])
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
-class AllegroTransformer3DModel(metaclass=DummyObject):
+class ChronoEditTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -70,7 +712,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class AsymmetricAutoencoderKL(metaclass=DummyObject):
+class CogVideoXTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -85,7 +727,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class AuraFlowTransformer2DModel(metaclass=DummyObject):
+class CogView3PlusTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -100,7 +742,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class AutoencoderDC(metaclass=DummyObject):
+class CogView4Transformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -115,7 +757,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class AutoencoderKL(metaclass=DummyObject):
+class ConsisIDTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -130,7 +772,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class AutoencoderKLAllegro(metaclass=DummyObject):
+class ConsistencyDecoderVAE(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -145,7 +787,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class AutoencoderKLCogVideoX(metaclass=DummyObject):
+class ContextParallelConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -160,7 +802,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class AutoencoderKLHunyuanVideo(metaclass=DummyObject):
+class ControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -175,7 +817,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class AutoencoderKLLTXVideo(metaclass=DummyObject):
+class ControlNetUnionModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -190,7 +832,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class AutoencoderKLMagvit(metaclass=DummyObject):
+class ControlNetXSAdapter(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -205,7 +847,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class AutoencoderKLMochi(metaclass=DummyObject):
+class CosmosTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -220,7 +862,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class AutoencoderKLTemporalDecoder(metaclass=DummyObject):
+class DiTTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -235,7 +877,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class AutoencoderKLWan(metaclass=DummyObject):
+class EasyAnimateTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -250,7 +892,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class AutoencoderOobleck(metaclass=DummyObject):
+class Flux2Transformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -265,7 +907,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class AutoencoderTiny(metaclass=DummyObject):
+class FluxControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -280,7 +922,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class CacheMixin(metaclass=DummyObject):
+class FluxMultiControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -295,7 +937,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class CogVideoXTransformer3DModel(metaclass=DummyObject):
+class FluxTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -310,7 +952,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class CogView3PlusTransformer2DModel(metaclass=DummyObject):
+class HiDreamImageTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -325,7 +967,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class CogView4Transformer2DModel(metaclass=DummyObject):
+class HunyuanDiT2DControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -340,7 +982,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class ConsisIDTransformer3DModel(metaclass=DummyObject):
+class HunyuanDiT2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -355,7 +997,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class ConsistencyDecoderVAE(metaclass=DummyObject):
+class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -370,7 +1012,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class ControlNetModel(metaclass=DummyObject):
+class HunyuanImageTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -385,7 +1027,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class ControlNetUnionModel(metaclass=DummyObject):
+class HunyuanVideo15Transformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -400,7 +1042,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class ControlNetXSAdapter(metaclass=DummyObject):
+class HunyuanVideoFramepackTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -415,7 +1057,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class DiTTransformer2DModel(metaclass=DummyObject):
+class HunyuanVideoTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -430,7 +1072,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class EasyAnimateTransformer3DModel(metaclass=DummyObject):
+class I2VGenXLUNet(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -445,7 +1087,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class FluxControlNetModel(metaclass=DummyObject):
+class Kandinsky3UNet(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -460,7 +1102,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class FluxMultiControlNetModel(metaclass=DummyObject):
+class Kandinsky5Transformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -475,7 +1117,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class FluxTransformer2DModel(metaclass=DummyObject):
+class LatteTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -490,7 +1132,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class HunyuanDiT2DControlNetModel(metaclass=DummyObject):
+class LTXVideoTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -505,7 +1147,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class HunyuanDiT2DModel(metaclass=DummyObject):
+class Lumina2Transformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -520,7 +1162,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject):
+class LuminaNextDiT2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -535,7 +1177,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class HunyuanVideoTransformer3DModel(metaclass=DummyObject):
+class MochiTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -550,7 +1192,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class I2VGenXLUNet(metaclass=DummyObject):
+class ModelMixin(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -565,7 +1207,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class Kandinsky3UNet(metaclass=DummyObject):
+class MotionAdapter(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -580,7 +1222,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class LatteTransformer3DModel(metaclass=DummyObject):
+class MultiAdapter(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -595,7 +1237,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class LTXVideoTransformer3DModel(metaclass=DummyObject):
+class MultiControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -610,7 +1252,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class Lumina2Transformer2DModel(metaclass=DummyObject):
+class OmniGenTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -625,7 +1267,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class LuminaNextDiT2DModel(metaclass=DummyObject):
+class OvisImageTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -640,7 +1282,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class MochiTransformer3DModel(metaclass=DummyObject):
+class ParallelConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -655,7 +1297,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class ModelMixin(metaclass=DummyObject):
+class PixArtTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -670,7 +1312,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class MotionAdapter(metaclass=DummyObject):
+class PriorTransformer(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -685,7 +1327,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class MultiAdapter(metaclass=DummyObject):
+class PRXTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -700,7 +1342,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class MultiControlNetModel(metaclass=DummyObject):
+class QwenImageControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -715,7 +1357,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class OmniGenTransformer2DModel(metaclass=DummyObject):
+class QwenImageMultiControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -730,7 +1372,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class PixArtTransformer2DModel(metaclass=DummyObject):
+class QwenImageTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -745,7 +1387,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
-class PriorTransformer(metaclass=DummyObject):
+class SanaControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
@@ -775,6 +1417,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class SanaVideoTransformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class SD3ControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -820,6 +1477,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class SkyReelsV2Transformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class SparseControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -895,6 +1567,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class TransformerTemporalModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class UNet1DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -1030,6 +1717,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class WanAnimateTransformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class WanTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -1045,6 +1747,100 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class WanVACETransformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class ZImageTransformer2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+def attention_backend(*args, **kwargs):
+ requires_backends(attention_backend, ["torch"])
+
+
+class ComponentsManager(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class ComponentSpec(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class ModularPipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class ModularPipelineBlocks(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
def get_constant_schedule(*args, **kwargs):
requires_backends(get_constant_schedule, ["torch"])
@@ -1703,6 +2499,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class FlowMatchLCMScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class HeunDiscreteScheduler(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py
new file mode 100644
index 000000000000..c0c4d9df5eee
--- /dev/null
+++ b/src/diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py
@@ -0,0 +1,17 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class ConsisIDPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers", "opencv"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers", "opencv"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "opencv"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers", "opencv"])
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index b28fba948149..79a21d2ac6e5 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -2,6 +2,231 @@
from ..utils import DummyObject, requires_backends
+class FluxAutoBlocks(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class FluxKontextAutoBlocks(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class FluxKontextModularPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class FluxModularPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class QwenImageAutoBlocks(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class QwenImageEditAutoBlocks(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class QwenImageEditModularPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class QwenImageEditPlusAutoBlocks(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class QwenImageEditPlusModularPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class QwenImageModularPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionXLAutoBlocks(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionXLModularPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class Wan22AutoBlocks(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class WanAutoBlocks(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class WanModularPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class AllegroPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -62,7 +287,652 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AmusedInpaintPipeline(metaclass=DummyObject):
+class AmusedInpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AmusedPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AnimateDiffControlNetPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AnimateDiffPAGPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AnimateDiffPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AnimateDiffSDXLPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AnimateDiffSparseControlNetPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AnimateDiffVideoToVideoControlNetPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AnimateDiffVideoToVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AudioLDM2Pipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AudioLDM2ProjectionModel(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AudioLDM2UNet2DConditionModel(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AudioLDMPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class AuraFlowPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class BriaFiboPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class BriaPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class ChromaImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class ChromaPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class ChronoEditPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class CLIPImageProjection(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class CogVideoXFunControlPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class CogVideoXImageToVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class CogVideoXPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class CogVideoXVideoToVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class CogView3PlusPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class CogView4ControlPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class CogView4Pipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class ConsisIDPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class Cosmos2TextToImagePipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class Cosmos2VideoToWorldPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class CosmosTextToWorldPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class CosmosVideoToWorldPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class CycleDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class EasyAnimateControlPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class EasyAnimateInpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class EasyAnimatePipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class Flux2Pipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class FluxControlImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class FluxControlInpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class FluxControlNetImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class FluxControlNetInpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class FluxControlNetPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class FluxControlPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class FluxFillPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -77,7 +947,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AmusedPipeline(metaclass=DummyObject):
+class FluxImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -92,7 +962,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AnimateDiffControlNetPipeline(metaclass=DummyObject):
+class FluxInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -107,7 +977,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AnimateDiffPAGPipeline(metaclass=DummyObject):
+class FluxKontextInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -122,7 +992,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AnimateDiffPipeline(metaclass=DummyObject):
+class FluxKontextPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -137,7 +1007,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AnimateDiffSDXLPipeline(metaclass=DummyObject):
+class FluxPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -152,7 +1022,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AnimateDiffSparseControlNetPipeline(metaclass=DummyObject):
+class FluxPriorReduxPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -167,7 +1037,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AnimateDiffVideoToVideoControlNetPipeline(metaclass=DummyObject):
+class HiDreamImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -182,7 +1052,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AnimateDiffVideoToVideoPipeline(metaclass=DummyObject):
+class HunyuanDiTControlNetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -197,7 +1067,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AudioLDM2Pipeline(metaclass=DummyObject):
+class HunyuanDiTPAGPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -212,7 +1082,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AudioLDM2ProjectionModel(metaclass=DummyObject):
+class HunyuanDiTPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -227,7 +1097,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AudioLDM2UNet2DConditionModel(metaclass=DummyObject):
+class HunyuanImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -242,7 +1112,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AudioLDMPipeline(metaclass=DummyObject):
+class HunyuanImageRefinerPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -257,7 +1127,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class AuraFlowPipeline(metaclass=DummyObject):
+class HunyuanSkyreelsImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -272,7 +1142,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class CLIPImageProjection(metaclass=DummyObject):
+class HunyuanVideo15ImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -287,7 +1157,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class CogVideoXFunControlPipeline(metaclass=DummyObject):
+class HunyuanVideo15Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -302,7 +1172,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class CogVideoXImageToVideoPipeline(metaclass=DummyObject):
+class HunyuanVideoFramepackPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -317,7 +1187,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class CogVideoXPipeline(metaclass=DummyObject):
+class HunyuanVideoImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -332,7 +1202,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class CogVideoXVideoToVideoPipeline(metaclass=DummyObject):
+class HunyuanVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -347,7 +1217,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class CogView3PlusPipeline(metaclass=DummyObject):
+class I2VGenXLPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -362,7 +1232,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class CogView4ControlPipeline(metaclass=DummyObject):
+class IFImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -377,7 +1247,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class CogView4Pipeline(metaclass=DummyObject):
+class IFImg2ImgSuperResolutionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -392,7 +1262,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class ConsisIDPipeline(metaclass=DummyObject):
+class IFInpaintingPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -407,7 +1277,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class CycleDiffusionPipeline(metaclass=DummyObject):
+class IFInpaintingSuperResolutionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -422,7 +1292,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class EasyAnimateControlPipeline(metaclass=DummyObject):
+class IFPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -437,7 +1307,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class EasyAnimateInpaintPipeline(metaclass=DummyObject):
+class IFSuperResolutionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -452,7 +1322,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class EasyAnimatePipeline(metaclass=DummyObject):
+class ImageTextPipelineOutput(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -467,7 +1337,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class FluxControlImg2ImgPipeline(metaclass=DummyObject):
+class Kandinsky3Img2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -482,7 +1352,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class FluxControlInpaintPipeline(metaclass=DummyObject):
+class Kandinsky3Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -497,7 +1367,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class FluxControlNetImg2ImgPipeline(metaclass=DummyObject):
+class Kandinsky5I2IPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -512,7 +1382,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class FluxControlNetInpaintPipeline(metaclass=DummyObject):
+class Kandinsky5I2VPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -527,7 +1397,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class FluxControlNetPipeline(metaclass=DummyObject):
+class Kandinsky5T2IPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -542,7 +1412,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class FluxControlPipeline(metaclass=DummyObject):
+class Kandinsky5T2VPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -557,7 +1427,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class FluxFillPipeline(metaclass=DummyObject):
+class KandinskyCombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -572,7 +1442,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class FluxImg2ImgPipeline(metaclass=DummyObject):
+class KandinskyImg2ImgCombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -587,7 +1457,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class FluxInpaintPipeline(metaclass=DummyObject):
+class KandinskyImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -602,7 +1472,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class FluxPipeline(metaclass=DummyObject):
+class KandinskyInpaintCombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -617,7 +1487,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class FluxPriorReduxPipeline(metaclass=DummyObject):
+class KandinskyInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -632,7 +1502,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class HunyuanDiTControlNetPipeline(metaclass=DummyObject):
+class KandinskyPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -647,7 +1517,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class HunyuanDiTPAGPipeline(metaclass=DummyObject):
+class KandinskyPriorPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -662,7 +1532,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class HunyuanDiTPipeline(metaclass=DummyObject):
+class KandinskyV22CombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -677,7 +1547,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class HunyuanSkyreelsImageToVideoPipeline(metaclass=DummyObject):
+class KandinskyV22ControlnetImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -692,7 +1562,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class HunyuanVideoImageToVideoPipeline(metaclass=DummyObject):
+class KandinskyV22ControlnetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -707,7 +1577,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class HunyuanVideoPipeline(metaclass=DummyObject):
+class KandinskyV22Img2ImgCombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -722,7 +1592,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class I2VGenXLPipeline(metaclass=DummyObject):
+class KandinskyV22Img2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -737,7 +1607,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class IFImg2ImgPipeline(metaclass=DummyObject):
+class KandinskyV22InpaintCombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -752,7 +1622,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class IFImg2ImgSuperResolutionPipeline(metaclass=DummyObject):
+class KandinskyV22InpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -767,7 +1637,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class IFInpaintingPipeline(metaclass=DummyObject):
+class KandinskyV22Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -782,7 +1652,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class IFInpaintingSuperResolutionPipeline(metaclass=DummyObject):
+class KandinskyV22PriorEmb2EmbPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -797,7 +1667,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class IFPipeline(metaclass=DummyObject):
+class KandinskyV22PriorPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -812,7 +1682,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class IFSuperResolutionPipeline(metaclass=DummyObject):
+class LatentConsistencyModelImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -827,7 +1697,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class ImageTextPipelineOutput(metaclass=DummyObject):
+class LatentConsistencyModelPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -842,7 +1712,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class Kandinsky3Img2ImgPipeline(metaclass=DummyObject):
+class LattePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -857,7 +1727,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class Kandinsky3Pipeline(metaclass=DummyObject):
+class LDMTextToImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -872,7 +1742,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyCombinedPipeline(metaclass=DummyObject):
+class LEditsPPPipelineStableDiffusion(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -887,7 +1757,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyImg2ImgCombinedPipeline(metaclass=DummyObject):
+class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -902,7 +1772,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyImg2ImgPipeline(metaclass=DummyObject):
+class LTXConditionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -917,7 +1787,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyInpaintCombinedPipeline(metaclass=DummyObject):
+class LTXImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -932,7 +1802,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyInpaintPipeline(metaclass=DummyObject):
+class LTXLatentUpsamplePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -947,7 +1817,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyPipeline(metaclass=DummyObject):
+class LTXPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -962,7 +1832,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyPriorPipeline(metaclass=DummyObject):
+class LucyEditPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -977,7 +1847,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyV22CombinedPipeline(metaclass=DummyObject):
+class Lumina2Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -992,7 +1862,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyV22ControlnetImg2ImgPipeline(metaclass=DummyObject):
+class Lumina2Text2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1007,7 +1877,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyV22ControlnetPipeline(metaclass=DummyObject):
+class LuminaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1022,7 +1892,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyV22Img2ImgCombinedPipeline(metaclass=DummyObject):
+class LuminaText2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1037,7 +1907,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyV22Img2ImgPipeline(metaclass=DummyObject):
+class MarigoldDepthPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1052,7 +1922,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyV22InpaintCombinedPipeline(metaclass=DummyObject):
+class MarigoldIntrinsicsPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1067,7 +1937,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyV22InpaintPipeline(metaclass=DummyObject):
+class MarigoldNormalsPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1082,7 +1952,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyV22Pipeline(metaclass=DummyObject):
+class MochiPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1097,7 +1967,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyV22PriorEmb2EmbPipeline(metaclass=DummyObject):
+class MusicLDMPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1112,7 +1982,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class KandinskyV22PriorPipeline(metaclass=DummyObject):
+class OmniGenPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1127,7 +1997,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class LatentConsistencyModelImg2ImgPipeline(metaclass=DummyObject):
+class OvisImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1142,7 +2012,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class LatentConsistencyModelPipeline(metaclass=DummyObject):
+class PaintByExamplePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1157,7 +2027,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class LattePipeline(metaclass=DummyObject):
+class PIAPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1172,7 +2042,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class LDMTextToImagePipeline(metaclass=DummyObject):
+class PixArtAlphaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1187,7 +2057,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class LEditsPPPipelineStableDiffusion(metaclass=DummyObject):
+class PixArtSigmaPAGPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1202,7 +2072,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
+class PixArtSigmaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1217,7 +2087,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class LTXConditionPipeline(metaclass=DummyObject):
+class PRXPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1232,7 +2102,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class LTXImageToVideoPipeline(metaclass=DummyObject):
+class QwenImageControlNetInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1247,7 +2117,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class LTXPipeline(metaclass=DummyObject):
+class QwenImageControlNetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1262,7 +2132,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class Lumina2Pipeline(metaclass=DummyObject):
+class QwenImageEditInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1277,7 +2147,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class Lumina2Text2ImgPipeline(metaclass=DummyObject):
+class QwenImageEditPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1292,7 +2162,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class LuminaPipeline(metaclass=DummyObject):
+class QwenImageEditPlusPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1307,7 +2177,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class LuminaText2ImgPipeline(metaclass=DummyObject):
+class QwenImageImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1322,7 +2192,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class MarigoldDepthPipeline(metaclass=DummyObject):
+class QwenImageInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1337,7 +2207,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class MarigoldIntrinsicsPipeline(metaclass=DummyObject):
+class QwenImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1352,7 +2222,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class MarigoldNormalsPipeline(metaclass=DummyObject):
+class ReduxImageEncoder(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1367,7 +2237,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class MochiPipeline(metaclass=DummyObject):
+class SanaControlNetPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1382,7 +2252,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class MusicLDMPipeline(metaclass=DummyObject):
+class SanaImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1397,7 +2267,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class OmniGenPipeline(metaclass=DummyObject):
+class SanaPAGPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1412,7 +2282,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class PaintByExamplePipeline(metaclass=DummyObject):
+class SanaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1427,7 +2297,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class PIAPipeline(metaclass=DummyObject):
+class SanaSprintImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1442,7 +2312,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class PixArtAlphaPipeline(metaclass=DummyObject):
+class SanaSprintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1457,7 +2327,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class PixArtSigmaPAGPipeline(metaclass=DummyObject):
+class SanaVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1472,7 +2342,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class PixArtSigmaPipeline(metaclass=DummyObject):
+class SemanticStableDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1487,7 +2357,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class ReduxImageEncoder(metaclass=DummyObject):
+class ShapEImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1502,7 +2372,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class SanaPAGPipeline(metaclass=DummyObject):
+class ShapEPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1517,7 +2387,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class SanaPipeline(metaclass=DummyObject):
+class SkyReelsV2DiffusionForcingImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1532,7 +2402,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class SanaSprintPipeline(metaclass=DummyObject):
+class SkyReelsV2DiffusionForcingPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1547,7 +2417,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class SemanticStableDiffusionPipeline(metaclass=DummyObject):
+class SkyReelsV2DiffusionForcingVideoToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1562,7 +2432,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class ShapEImg2ImgPipeline(metaclass=DummyObject):
+class SkyReelsV2ImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -1577,7 +2447,7 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
-class ShapEPipeline(metaclass=DummyObject):
+class SkyReelsV2Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -2717,6 +3587,36 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class VisualClozeGenerationPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class VisualClozePipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class VQDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -2732,6 +3632,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class WanAnimatePipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class WanImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -2762,6 +3677,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class WanVACEPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class WanVideoToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -2820,3 +3750,33 @@ def from_config(cls, *args, **kwargs):
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+
+
+class ZImageImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class ZImagePipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py
index 5d0752af8983..1c0734cf35bb 100644
--- a/src/diffusers/utils/dynamic_modules_utils.py
+++ b/src/diffusers/utils/dynamic_modules_utils.py
@@ -21,7 +21,9 @@
import re
import shutil
import sys
+import threading
from pathlib import Path
+from types import ModuleType
from typing import Dict, Optional, Union
from urllib import request
@@ -31,12 +33,15 @@
from .. import __version__
from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
+from .constants import DIFFUSERS_DISABLE_REMOTE_CODE
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# See https://huggingface.co/datasets/diffusers/community-pipelines-mirror
COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror"
+TIME_OUT_REMOTE_CODE = int(os.getenv("DIFFUSERS_TIMEOUT_REMOTE_CODE", 15))
+_HF_REMOTE_CODE_LOCK = threading.Lock()
def get_diffusers_versions():
@@ -146,23 +151,68 @@ def check_imports(filename):
missing_packages.append(imp)
if len(missing_packages) > 0:
- raise ImportError(
- "This modeling file requires the following packages that were not found in your environment: "
+ logger.warning(
+ "This modeling file might require the following packages that were not found in your environment: "
f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
)
return get_relative_imports(filename)
-def get_class_in_module(class_name, module_path):
+def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code):
+ trust_remote_code = trust_remote_code and not DIFFUSERS_DISABLE_REMOTE_CODE
+ if DIFFUSERS_DISABLE_REMOTE_CODE:
+ logger.warning(
+ "Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable. Ignoring `trust_remote_code`."
+ )
+
+ if has_remote_code and not trust_remote_code:
+ error_msg = f"The repository for {model_name} contains custom code. "
+ error_msg += (
+ "Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable."
+ if DIFFUSERS_DISABLE_REMOTE_CODE
+ else "Pass `trust_remote_code=True` to allow loading remote code modules."
+ )
+ raise ValueError(error_msg)
+
+ elif has_remote_code and trust_remote_code:
+ logger.warning(
+ f"`trust_remote_code` is enabled. Downloading code from {model_name}. Please ensure you trust the contents of this repository"
+ )
+
+ return trust_remote_code
+
+
+def get_class_in_module(class_name, module_path, force_reload=False):
"""
Import a module on the cache directory for modules and extract a class from it.
"""
- module_path = module_path.replace(os.path.sep, ".")
- module = importlib.import_module(module_path)
+ name = os.path.normpath(module_path)
+ if name.endswith(".py"):
+ name = name[:-3]
+ name = name.replace(os.path.sep, ".")
+ module_file: Path = Path(HF_MODULES_CACHE) / module_path
+
+ with _HF_REMOTE_CODE_LOCK:
+ if force_reload:
+ sys.modules.pop(name, None)
+ importlib.invalidate_caches()
+ cached_module: Optional[ModuleType] = sys.modules.get(name)
+ module_spec = importlib.util.spec_from_file_location(name, location=module_file)
+
+ module: ModuleType
+ if cached_module is None:
+ module = importlib.util.module_from_spec(module_spec)
+ # insert it into sys.modules before any loading begins
+ sys.modules[name] = module
+ else:
+ module = cached_module
+
+ module_spec.loader.exec_module(module)
if class_name is None:
return find_pipeline_class(module)
+
return getattr(module, class_name)
@@ -197,12 +247,14 @@ def find_pipeline_class(loaded_module):
def get_cached_module_file(
pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str,
+ subfolder: Optional[str] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
proxies: Optional[Dict[str, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
+ local_dir: Optional[str] = None,
):
"""
Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
@@ -239,12 +291,8 @@ def get_cached_module_file(
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
-
-
- You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private or
- [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
-
-
+ > [!TIP] > You may pass a token in `token` if you are not logged in (`hf auth login`) and want to use private or
+ [gated > models](https://huggingface.co/docs/hub/models-gated#gated-models).
Returns:
`str`: The path to the module inside the cache.
@@ -285,6 +333,7 @@ def get_cached_module_file(
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
+ local_dir=local_dir,
)
submodule = "git"
module_file = pretrained_model_name_or_path + ".py"
@@ -303,10 +352,13 @@ def get_cached_module_file(
resolved_module_file = hf_hub_download(
pretrained_model_name_or_path,
module_file,
+ subfolder=subfolder,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
+ local_dir=local_dir,
+ revision=revision,
token=token,
)
submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
@@ -360,12 +412,14 @@ def get_cached_module_file(
get_cached_module_file(
pretrained_model_name_or_path,
f"{module_needed}.py",
+ subfolder=subfolder,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
+ local_dir=local_dir,
)
return os.path.join(full_submodule, module_file)
@@ -374,6 +428,7 @@ def get_cached_module_file(
def get_class_from_dynamic_module(
pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str,
+ subfolder: Optional[str] = None,
class_name: Optional[str] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
@@ -381,17 +436,13 @@ def get_class_from_dynamic_module(
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
- **kwargs,
+ local_dir: Optional[str] = None,
):
"""
Extracts a class from a module file, present in the local folder or repository of a model.
-
-
- Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
- therefore only be called on trusted repos.
-
-
+ > [!WARNING] > Calling this function will execute the code in the module file found locally or downloaded from the
+ Hub. It should > therefore only be called on trusted repos.
Args:
pretrained_model_name_or_path (`str` or `os.PathLike`):
@@ -426,12 +477,8 @@ def get_class_from_dynamic_module(
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
-
-
- You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private or
- [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
-
-
+ > [!TIP] > You may pass a token in `token` if you are not logged in (`hf auth login`) and want to use private or
+ [gated > models](https://huggingface.co/docs/hub/models-gated#gated-models).
Returns:
`type`: The class, dynamically imported from the module.
@@ -447,11 +494,13 @@ def get_class_from_dynamic_module(
final_module = get_cached_module_file(
pretrained_model_name_or_path,
module_file,
+ subfolder=subfolder,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
+ local_dir=local_dir,
)
- return get_class_in_module(class_name, final_module.replace(".py", ""))
+ return get_class_in_module(class_name, final_module)
diff --git a/src/diffusers/utils/export_utils.py b/src/diffusers/utils/export_utils.py
index 30d2c8bebd8e..07cf46928a44 100644
--- a/src/diffusers/utils/export_utils.py
+++ b/src/diffusers/utils/export_utils.py
@@ -155,7 +155,7 @@ def export_to_video(
bitrate:
Set a constant bitrate for the video encoding. Default is None causing `quality` parameter to be used instead.
Better quality videos with smaller file sizes will result from using the `quality` variable bitrate parameter
- rather than specifiying a fixed bitrate with this parameter.
+ rather than specifying a fixed bitrate with this parameter.
macro_block_size:
Size constraint for video. Width and height, must be divisible by this number. If not divisible by this number
diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py
index f80f96a3425d..d0b05c7d9541 100644
--- a/src/diffusers/utils/hub_utils.py
+++ b/src/diffusers/utils/hub_utils.py
@@ -38,13 +38,13 @@
from huggingface_hub.file_download import REGEX_COMMIT_HASH
from huggingface_hub.utils import (
EntryNotFoundError,
+ HfHubHTTPError,
RepositoryNotFoundError,
RevisionNotFoundError,
is_jinja_available,
validate_hf_hub_args,
)
from packaging import version
-from requests import HTTPError
from .. import __version__
from .constants import (
@@ -113,7 +113,8 @@ def load_or_create_model_card(
Args:
repo_id_or_path (`str`):
- The repo id (e.g., "runwayml/stable-diffusion-v1-5") or local path where to look for the model card.
+ The repo id (e.g., "stable-diffusion-v1-5/stable-diffusion-v1-5") or local path where to look for the model
+ card.
token (`str`, *optional*):
Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more
details.
@@ -304,8 +305,7 @@ def _get_model_file(
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
- "token having permission to this repo with `token` or log in with `huggingface-cli "
- "login`."
+ "token having permission to this repo with `token` or log in with `hf auth login`."
) from e
except RevisionNotFoundError as e:
raise EnvironmentError(
@@ -317,7 +317,7 @@ def _get_model_file(
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
) from e
- except HTTPError as e:
+ except HfHubHTTPError as e:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{e}"
) from e
@@ -403,15 +403,17 @@ def _get_checkpoint_shard_files(
allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]
ignore_patterns = ["*.json", "*.md"]
- # `model_info` call must guarded with the above condition.
- model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
- for shard_file in original_shard_filenames:
- shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
- if not shard_file_present:
- raise EnvironmentError(
- f"{shards_path} does not appear to have a file named {shard_file} which is "
- "required according to the checkpoint index."
- )
+
+ # If the repo doesn't have the required shards, error out early even before downloading anything.
+ if not local_files_only:
+ model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
+ for shard_file in original_shard_filenames:
+ shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
+ if not shard_file_present:
+ raise EnvironmentError(
+ f"{shards_path} does not appear to have a file named {shard_file} which is "
+ "required according to the checkpoint index."
+ )
try:
# Load from URL
@@ -431,13 +433,18 @@ def _get_checkpoint_shard_files(
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
- except HTTPError as e:
+ except HfHubHTTPError as e:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
" again after checking your internet connection."
) from e
cached_filenames = [os.path.join(cached_folder, f) for f in original_shard_filenames]
+ for cached_file in cached_filenames:
+ if not os.path.isfile(cached_file):
+ raise EnvironmentError(
+ f"{cached_folder} does not have a file named {cached_file} which is required according to the checkpoint index."
+ )
return cached_filenames, sharded_metadata
@@ -467,6 +474,7 @@ def _upload_folder(
token: Optional[str] = None,
commit_message: Optional[str] = None,
create_pr: bool = False,
+ subfolder: Optional[str] = None,
):
"""
Uploads all files in `working_dir` to `repo_id`.
@@ -481,7 +489,12 @@ def _upload_folder(
logger.info(f"Uploading the files of {working_dir} to {repo_id}.")
return upload_folder(
- repo_id=repo_id, folder_path=working_dir, token=token, commit_message=commit_message, create_pr=create_pr
+ repo_id=repo_id,
+ folder_path=working_dir,
+ token=token,
+ commit_message=commit_message,
+ create_pr=create_pr,
+ path_in_repo=subfolder,
)
def push_to_hub(
@@ -493,6 +506,7 @@ def push_to_hub(
create_pr: bool = False,
safe_serialization: bool = True,
variant: Optional[str] = None,
+ subfolder: Optional[str] = None,
) -> str:
"""
Upload model, scheduler, or pipeline files to the 🤗 Hugging Face Hub.
@@ -508,8 +522,8 @@ def push_to_hub(
Whether to make the repo private. If `None` (default), the repo will be public unless the
organization's default is private. This value is ignored if the repo already exists.
token (`str`, *optional*):
- The token to use as HTTP bearer authorization for remote files. The token generated when running
- `huggingface-cli login` (stored in `~/.huggingface`).
+ The token to use as HTTP bearer authorization for remote files. The token generated when running `hf
+ auth login` (stored in `~/.huggingface`).
create_pr (`bool`, *optional*, defaults to `False`):
Whether or not to create a PR with the uploaded files or directly commit.
safe_serialization (`bool`, *optional*, defaults to `True`):
@@ -534,8 +548,9 @@ def push_to_hub(
repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id
# Create a new empty model card and eventually tag it
- model_card = load_or_create_model_card(repo_id, token=token)
- model_card = populate_model_card(model_card)
+ if not subfolder:
+ model_card = load_or_create_model_card(repo_id, token=token)
+ model_card = populate_model_card(model_card)
# Save all files.
save_kwargs = {"safe_serialization": safe_serialization}
@@ -546,7 +561,8 @@ def push_to_hub(
self.save_pretrained(tmpdir, **save_kwargs)
# Update model card if needed:
- model_card.save(os.path.join(tmpdir, "README.md"))
+ if not subfolder:
+ model_card.save(os.path.join(tmpdir, "README.md"))
return self._upload_folder(
tmpdir,
@@ -554,4 +570,5 @@ def push_to_hub(
token=token,
commit_message=commit_message,
create_pr=create_pr,
+ subfolder=subfolder,
)
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index f61116aaaf6c..57b0a337922a 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,13 +16,15 @@
"""
import importlib.util
+import inspect
import operator as op
import os
import sys
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
+from functools import lru_cache as cache
from itertools import chain
from types import ModuleType
-from typing import Any, Union
+from typing import Any, Tuple, Union
from huggingface_hub.utils import is_jinja_available # noqa: F401
from packaging.version import Version, parse
@@ -35,7 +37,10 @@
import importlib_metadata
else:
import importlib.metadata as importlib_metadata
-
+try:
+ _package_map = importlib_metadata.packages_distributions() # load-once to avoid expensive calls
+except Exception:
+ _package_map = None
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -54,12 +59,34 @@
_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)
-def _is_package_available(pkg_name: str):
+def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[bool, str]:
+ global _package_map
pkg_exists = importlib.util.find_spec(pkg_name) is not None
pkg_version = "N/A"
if pkg_exists:
+ if _package_map is None:
+ _package_map = defaultdict(list)
+ try:
+ # Fallback for Python < 3.10
+ for dist in importlib_metadata.distributions():
+ _top_level_declared = (dist.read_text("top_level.txt") or "").split()
+ # Infer top-level package names from file structure
+ _inferred_opt_names = {
+ f.parts[0] if len(f.parts) > 1 else inspect.getmodulename(f) for f in (dist.files or [])
+ } - {None}
+ _top_level_inferred = filter(lambda name: "." not in name, _inferred_opt_names)
+ for pkg in _top_level_declared or _top_level_inferred:
+ _package_map[pkg].append(dist.metadata["Name"])
+ except Exception as _:
+ pass
try:
+ if get_dist_name and pkg_name in _package_map and _package_map[pkg_name]:
+ if len(_package_map[pkg_name]) > 1:
+ logger.warning(
+ f"Multiple distributions found for package {pkg_name}. Picked distribution: {_package_map[pkg_name][0]}"
+ )
+ pkg_name = _package_map[pkg_name][0]
pkg_version = importlib_metadata.version(pkg_name)
logger.debug(f"Successfully imported {pkg_name} version {pkg_version}")
except (ImportError, importlib_metadata.PackageNotFoundError):
@@ -74,6 +101,7 @@ def _is_package_available(pkg_name: str):
else:
logger.info("Disabling PyTorch because USE_TORCH is set")
_torch_available = False
+ _torch_version = "N/A"
_jax_version = "N/A"
_flax_version = "N/A"
@@ -93,7 +121,7 @@ def _is_package_available(pkg_name: str):
_safetensors_available, _safetensors_version = _is_package_available("safetensors")
else:
- logger.info("Disabling Safetensors because USE_TF is set")
+ logger.info("Disabling Safetensors because USE_SAFETENSORS is set")
_safetensors_available = False
_onnxruntime_version = "N/A"
@@ -101,18 +129,20 @@ def _is_package_available(pkg_name: str):
if _onnx_available:
candidates = (
"onnxruntime",
+ "onnxruntime-cann",
+ "onnxruntime-directml",
+ "ort_nightly_directml",
"onnxruntime-gpu",
"ort_nightly_gpu",
- "onnxruntime-directml",
+ "onnxruntime-migraphx",
"onnxruntime-openvino",
- "ort_nightly_directml",
+ "onnxruntime-qnn",
"onnxruntime-rocm",
- "onnxruntime-migraphx",
"onnxruntime-training",
"onnxruntime-vitisai",
)
_onnxruntime_version = None
- # For the metadata, we have to look for both onnxruntime and onnxruntime-gpu
+ # For the metadata, we have to look for both onnxruntime and onnxruntime-x
for pkg in candidates:
try:
_onnxruntime_version = importlib_metadata.version(pkg)
@@ -162,8 +192,10 @@ def _is_package_available(pkg_name: str):
_torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
+_torch_mlu_available, _torch_mlu_version = _is_package_available("torch_mlu")
_transformers_available, _transformers_version = _is_package_available("transformers")
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
+_kernels_available, _kernels_version = _is_package_available("kernels")
_inflect_available, _inflect_version = _is_package_available("inflect")
_unidecode_available, _unidecode_version = _is_package_available("unidecode")
_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
@@ -187,15 +219,17 @@ def _is_package_available(pkg_name: str):
_gguf_available, _gguf_version = _is_package_available("gguf")
_torchao_available, _torchao_version = _is_package_available("torchao")
_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
-_torchao_available, _torchao_version = _is_package_available("torchao")
-
-_optimum_quanto_available = importlib.util.find_spec("optimum") is not None
-if _optimum_quanto_available:
- try:
- _optimum_quanto_version = importlib_metadata.version("optimum_quanto")
- logger.debug(f"Successfully import optimum-quanto version {_optimum_quanto_version}")
- except importlib_metadata.PackageNotFoundError:
- _optimum_quanto_available = False
+_optimum_quanto_available, _optimum_quanto_version = _is_package_available("optimum", get_dist_name=True)
+_pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_available("pytorch_retinaface")
+_better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
+_nltk_available, _nltk_version = _is_package_available("nltk")
+_cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
+_sageattention_available, _sageattention_version = _is_package_available("sageattention")
+_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
+_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
+_aiter_available, _aiter_version = _is_package_available("aiter")
+_kornia_available, _kornia_version = _is_package_available("kornia")
+_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
def is_torch_available():
@@ -210,6 +244,10 @@ def is_torch_npu_available():
return _torch_npu_available
+def is_torch_mlu_available():
+ return _torch_mlu_available
+
+
def is_flax_available():
return _flax_available
@@ -250,6 +288,10 @@ def is_accelerate_available():
return _accelerate_available
+def is_kernels_available():
+ return _kernels_available
+
+
def is_k_diffusion_available():
return _k_diffusion_available
@@ -330,10 +372,54 @@ def is_optimum_quanto_available():
return _optimum_quanto_available
+def is_nvidia_modelopt_available():
+ return _nvidia_modelopt_available
+
+
def is_timm_available():
return _timm_available
+def is_pytorch_retinaface_available():
+ return _pytorch_retinaface_available
+
+
+def is_better_profanity_available():
+ return _better_profanity_available
+
+
+def is_nltk_available():
+ return _nltk_available
+
+
+def is_cosmos_guardrail_available():
+ return _cosmos_guardrail_available
+
+
+def is_hpu_available():
+ return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch"))
+
+
+def is_sageattention_available():
+ return _sageattention_available
+
+
+def is_flash_attn_available():
+ return _flash_attn_available
+
+
+def is_flash_attn_3_available():
+ return _flash_attn_3_available
+
+
+def is_aiter_available():
+ return _aiter_available
+
+
+def is_kornia_available():
+ return _kornia_available
+
+
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -482,6 +568,22 @@ def is_timm_available():
install optimum-quanto`
"""
+# docstyle-ignore
+PYTORCH_RETINAFACE_IMPORT_ERROR = """
+{0} requires the pytorch_retinaface library but it was not found in your environment. You can install it with pip: `pip install pytorch_retinaface`
+"""
+
+# docstyle-ignore
+BETTER_PROFANITY_IMPORT_ERROR = """
+{0} requires the better_profanity library but it was not found in your environment. You can install it with pip: `pip install better_profanity`
+"""
+
+# docstyle-ignore
+NLTK_IMPORT_ERROR = """
+{0} requires the nltk library but it was not found in your environment. You can install it with pip: `pip install nltk`
+"""
+
+
BACKENDS_MAPPING = OrderedDict(
[
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
@@ -510,6 +612,9 @@ def is_timm_available():
("gguf", (is_gguf_available, GGUF_IMPORT_ERROR)),
("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)),
("quanto", (is_optimum_quanto_available, QUANTO_IMPORT_ERROR)),
+ ("pytorch_retinaface", (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR)),
+ ("better_profanity", (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR)),
+ ("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)),
]
)
@@ -579,6 +684,7 @@ def compare_versions(library_or_version: Union[str, Version], operation: str, re
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
+@cache
def is_torch_version(operation: str, version: str):
"""
Compares the current PyTorch version to a given reference with an operation.
@@ -592,6 +698,7 @@ def is_torch_version(operation: str, version: str):
return compare_versions(parse(_torch_version), operation, version)
+@cache
def is_torch_xla_version(operation: str, version: str):
"""
Compares the current torch_xla version to a given reference with an operation.
@@ -607,6 +714,7 @@ def is_torch_xla_version(operation: str, version: str):
return compare_versions(parse(_torch_xla_version), operation, version)
+@cache
def is_transformers_version(operation: str, version: str):
"""
Compares the current Transformers version to a given reference with an operation.
@@ -622,6 +730,7 @@ def is_transformers_version(operation: str, version: str):
return compare_versions(parse(_transformers_version), operation, version)
+@cache
def is_hf_hub_version(operation: str, version: str):
"""
Compares the current Hugging Face Hub version to a given reference with an operation.
@@ -637,6 +746,7 @@ def is_hf_hub_version(operation: str, version: str):
return compare_versions(parse(_hf_hub_version), operation, version)
+@cache
def is_accelerate_version(operation: str, version: str):
"""
Compares the current Accelerate version to a given reference with an operation.
@@ -652,6 +762,7 @@ def is_accelerate_version(operation: str, version: str):
return compare_versions(parse(_accelerate_version), operation, version)
+@cache
def is_peft_version(operation: str, version: str):
"""
Compares the current PEFT version to a given reference with an operation.
@@ -667,6 +778,7 @@ def is_peft_version(operation: str, version: str):
return compare_versions(parse(_peft_version), operation, version)
+@cache
def is_bitsandbytes_version(operation: str, version: str):
"""
Args:
@@ -681,6 +793,7 @@ def is_bitsandbytes_version(operation: str, version: str):
return compare_versions(parse(_bitsandbytes_version), operation, version)
+@cache
def is_gguf_version(operation: str, version: str):
"""
Compares the current Accelerate version to a given reference with an operation.
@@ -696,6 +809,7 @@ def is_gguf_version(operation: str, version: str):
return compare_versions(parse(_gguf_version), operation, version)
+@cache
def is_torchao_version(operation: str, version: str):
"""
Compares the current torchao version to a given reference with an operation.
@@ -711,6 +825,7 @@ def is_torchao_version(operation: str, version: str):
return compare_versions(parse(_torchao_version), operation, version)
+@cache
def is_k_diffusion_version(operation: str, version: str):
"""
Compares the current k-diffusion version to a given reference with an operation.
@@ -726,6 +841,7 @@ def is_k_diffusion_version(operation: str, version: str):
return compare_versions(parse(_k_diffusion_version), operation, version)
+@cache
def is_optimum_quanto_version(operation: str, version: str):
"""
Compares the current Accelerate version to a given reference with an operation.
@@ -741,6 +857,86 @@ def is_optimum_quanto_version(operation: str, version: str):
return compare_versions(parse(_optimum_quanto_version), operation, version)
+@cache
+def is_nvidia_modelopt_version(operation: str, version: str):
+ """
+ Compares the current Nvidia ModelOpt version to a given reference with an operation.
+
+ Args:
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A version string
+ """
+ if not _nvidia_modelopt_available:
+ return False
+ return compare_versions(parse(_nvidia_modelopt_version), operation, version)
+
+
+@cache
+def is_xformers_version(operation: str, version: str):
+ """
+ Compares the current xformers version to a given reference with an operation.
+
+ Args:
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A version string
+ """
+ if not _xformers_available:
+ return False
+ return compare_versions(parse(_xformers_version), operation, version)
+
+
+@cache
+def is_sageattention_version(operation: str, version: str):
+ """
+ Compares the current sageattention version to a given reference with an operation.
+
+ Args:
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A version string
+ """
+ if not _sageattention_available:
+ return False
+ return compare_versions(parse(_sageattention_version), operation, version)
+
+
+@cache
+def is_flash_attn_version(operation: str, version: str):
+ """
+ Compares the current flash-attention version to a given reference with an operation.
+
+ Args:
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A version string
+ """
+ if not _flash_attn_available:
+ return False
+ return compare_versions(parse(_flash_attn_version), operation, version)
+
+
+@cache
+def is_aiter_version(operation: str, version: str):
+ """
+ Compares the current aiter version to a given reference with an operation.
+
+ Args:
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A version string
+ """
+ if not _aiter_available:
+ return False
+ return compare_versions(parse(_aiter_version), operation, version)
+
+
def get_objects_from_module(module):
"""
Returns a dict of object names and values in a module, while skipping private/internal objects
diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py
index fd66aaa4da6e..dd23ae73c861 100644
--- a/src/diffusers/utils/loading_utils.py
+++ b/src/diffusers/utils/loading_utils.py
@@ -7,6 +7,7 @@
import PIL.ImageOps
import requests
+from .constants import DIFFUSERS_REQUEST_TIMEOUT
from .import_utils import BACKENDS_MAPPING, is_imageio_available
@@ -29,7 +30,7 @@ def load_image(
"""
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
- image = PIL.Image.open(requests.get(image, stream=True).raw)
+ image = PIL.Image.open(requests.get(image, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw)
elif os.path.isfile(image):
image = PIL.Image.open(image)
else:
diff --git a/src/diffusers/utils/logging.py b/src/diffusers/utils/logging.py
index 6f93450c410c..2ad6d3a47607 100644
--- a/src/diffusers/utils/logging.py
+++ b/src/diffusers/utils/logging.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 Optuna, Hugging Face
+# Copyright 2025 Optuna, Hugging Face
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -60,8 +60,7 @@ def _get_default_logging_level() -> int:
return log_levels[env_level_str]
else:
logging.getLogger().warning(
- f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, "
- f"has to be one of: { ', '.join(log_levels.keys()) }"
+ f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, has to be one of: {', '.join(log_levels.keys())}"
)
return _default_log_level
diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py
index 6080a86b871a..2b20f6120ce3 100644
--- a/src/diffusers/utils/outputs.py
+++ b/src/diffusers/utils/outputs.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -43,12 +43,8 @@ class BaseOutput(OrderedDict):
tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
Python dictionary.
-
-
- You can't unpack a [`BaseOutput`] directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
- first.
-
-
+ > [!WARNING] > You can't unpack a [`BaseOutput`] directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert
+ it to a tuple > first.
"""
def __init_subclass__(cls) -> None:
@@ -71,6 +67,7 @@ def __init_subclass__(cls) -> None:
cls,
torch.utils._pytree._dict_flatten,
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
+ serialized_type_name=f"{cls.__module__}.{cls.__name__}",
)
def __post_init__(self) -> None:
diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index d1269fbc5f20..12066ee3f89b 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,9 +21,13 @@
from packaging import version
-from .import_utils import is_peft_available, is_torch_available
+from . import logging
+from .import_utils import is_peft_available, is_peft_version, is_torch_available
+from .torch_utils import empty_device_cache
+logger = logging.get_logger(__name__)
+
if is_torch_available():
import torch
@@ -95,8 +99,7 @@ def recurse_remove_peft_layers(model):
setattr(model, name, new_module)
del module
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
+ empty_device_cache()
return model
@@ -147,25 +150,27 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
module.set_scale(adapter_name, 1.0)
-def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
+def get_peft_kwargs(
+ rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
+):
rank_pattern = {}
alpha_pattern = {}
r = lora_alpha = list(rank_dict.values())[0]
if len(set(rank_dict.values())) > 1:
- # get the rank occuring the most number of times
+ # get the rank occurring the most number of times
r = collections.Counter(rank_dict.values()).most_common()[0][0]
- # for modules with rank different from the most occuring rank, add it to the `rank_pattern`
+ # for modules with rank different from the most occurring rank, add it to the `rank_pattern`
rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}
if network_alpha_dict is not None and len(network_alpha_dict) > 0:
if len(set(network_alpha_dict.values())) > 1:
- # get the alpha occuring the most number of times
+ # get the alpha occurring the most number of times
lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
- # for modules with alpha different from the most occuring alpha, add it to the `alpha_pattern`
+ # for modules with alpha different from the most occurring alpha, add it to the `alpha_pattern`
alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
if is_unet:
alpha_pattern = {
@@ -177,7 +182,6 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
else:
lora_alpha = set(network_alpha_dict.values()).pop()
- # layer names without the Diffusers specific
target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
# for now we know that the "bias" keys are only associated with `lora_B`.
@@ -192,6 +196,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
"use_dora": use_dora,
"lora_bias": lora_bias,
}
+
return lora_config_kwargs
@@ -288,3 +293,84 @@ def check_peft_version(min_version: str) -> None:
f"The version of PEFT you are using is not compatible, please use a version that is greater"
f" than {min_version}"
)
+
+
+def _create_lora_config(
+ state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None
+):
+ from peft import LoraConfig
+
+ if metadata is not None:
+ lora_config_kwargs = metadata
+ else:
+ lora_config_kwargs = get_peft_kwargs(
+ rank_pattern_dict,
+ network_alpha_dict=network_alphas,
+ peft_state_dict=state_dict,
+ is_unet=is_unet,
+ model_state_dict=model_state_dict,
+ adapter_name=adapter_name,
+ )
+
+ _maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)
+
+ # Version checks for DoRA and lora_bias
+ if "use_dora" in lora_config_kwargs and lora_config_kwargs["use_dora"]:
+ if is_peft_version("<", "0.9.0"):
+ raise ValueError("DoRA requires PEFT >= 0.9.0. Please upgrade.")
+
+ if "lora_bias" in lora_config_kwargs and lora_config_kwargs["lora_bias"]:
+ if is_peft_version("<=", "0.13.2"):
+ raise ValueError("lora_bias requires PEFT >= 0.14.0. Please upgrade.")
+
+ try:
+ return LoraConfig(**lora_config_kwargs)
+ except TypeError as e:
+ raise TypeError("`LoraConfig` class could not be instantiated.") from e
+
+
+def _maybe_raise_error_for_ambiguous_keys(config):
+ rank_pattern = config["rank_pattern"].copy()
+ target_modules = config["target_modules"]
+
+ for key in list(rank_pattern.keys()):
+ # try to detect ambiguity
+ # `target_modules` can also be a str, in which case this loop would loop
+ # over the chars of the str. The technically correct way to match LoRA keys
+ # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
+ # But this cuts it for now.
+ exact_matches = [mod for mod in target_modules if mod == key]
+ substring_matches = [mod for mod in target_modules if key in mod and mod != key]
+
+ if exact_matches and substring_matches:
+ if is_peft_version("<", "0.14.1"):
+ raise ValueError(
+ "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
+ )
+
+
+def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
+ warn_msg = ""
+ if incompatible_keys is not None:
+ # Check only for unexpected keys.
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+ if lora_unexpected_keys:
+ warn_msg = (
+ f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+ f" {', '.join(lora_unexpected_keys)}. "
+ )
+
+ # Filter missing keys specific to the current adapter.
+ missing_keys = getattr(incompatible_keys, "missing_keys", None)
+ if missing_keys:
+ lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+ if lora_missing_keys:
+ warn_msg += (
+ f"Loading adapter weights from state_dict led to missing keys in the model:"
+ f" {', '.join(lora_missing_keys)}."
+ )
+
+ if warn_msg:
+ logger.warning(warn_msg)
diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py
index 62b114ba67e3..50bfce8b15eb 100644
--- a/src/diffusers/utils/state_dict_utils.py
+++ b/src/diffusers/utils/state_dict_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,10 +16,16 @@
"""
import enum
+import json
+from .import_utils import is_torch_available
from .logging import get_logger
+if is_torch_available():
+ import torch
+
+
logger = get_logger(__name__)
@@ -214,7 +220,7 @@ def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs):
kwargs (`dict`, *args*):
Additional arguments to pass to the method.
- - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended
+ - **adapter_name**: For example, in case of PEFT, some keys will be prepended
with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
`get_peft_model_state_dict` method:
https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
@@ -285,7 +291,7 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
kwargs (`dict`, *args*):
Additional arguments to pass to the method.
- - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended
+ - **adapter_name**: For example, in case of PEFT, some keys will be prepended
with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
`get_peft_model_state_dict` method:
https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
@@ -329,7 +335,32 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
kohya_key = kohya_key.replace(peft_adapter_name, "") # Kohya doesn't take names
kohya_ss_state_dict[kohya_key] = weight
if "lora_down" in kohya_key:
- alpha_key = f'{kohya_key.split(".")[0]}.alpha'
+ alpha_key = f"{kohya_key.split('.')[0]}.alpha"
kohya_ss_state_dict[alpha_key] = torch.tensor(len(weight))
return kohya_ss_state_dict
+
+
+def state_dict_all_zero(state_dict, filter_str=None):
+ if filter_str is not None:
+ if isinstance(filter_str, str):
+ filter_str = [filter_str]
+ state_dict = {k: v for k, v in state_dict.items() if any(f in k for f in filter_str)}
+
+ return all(torch.all(param == 0).item() for param in state_dict.values())
+
+
+def _load_sft_state_dict_metadata(model_file: str):
+ import safetensors.torch
+
+ from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+
+ with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
+ metadata = f.metadata() or {}
+
+ metadata.pop("format", None)
+ if metadata:
+ raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+ return json.loads(raw) if raw else None
+ else:
+ return None
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index e62f245f9ed1..3297bb5fdcd6 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -1,4 +1,5 @@
import functools
+import glob
import importlib
import importlib.metadata
import inspect
@@ -14,10 +15,11 @@
import time
import unittest
import urllib.parse
+from collections import UserDict
from contextlib import contextmanager
from io import BytesIO, StringIO
from pathlib import Path
-from typing import Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union
import numpy as np
import PIL.Image
@@ -26,6 +28,7 @@
from numpy.linalg import norm
from packaging import version
+from .constants import DIFFUSERS_REQUEST_TIMEOUT
from .import_utils import (
BACKENDS_MAPPING,
is_accelerate_available,
@@ -33,9 +36,12 @@
is_compel_available,
is_flax_available,
is_gguf_available,
+ is_kernels_available,
is_note_seq_available,
+ is_nvidia_modelopt_available,
is_onnx_available,
is_opencv_available,
+ is_optimum_quanto_available,
is_peft_available,
is_timm_available,
is_torch_available,
@@ -47,10 +53,24 @@
from .logging import get_logger
+if is_torch_available():
+ import torch
+
+ IS_ROCM_SYSTEM = torch.version.hip is not None
+ IS_CUDA_SYSTEM = torch.version.cuda is not None
+ IS_XPU_SYSTEM = getattr(torch.version, "xpu", None) is not None
+else:
+ IS_ROCM_SYSTEM = False
+ IS_CUDA_SYSTEM = False
+ IS_XPU_SYSTEM = False
+
global_rng = random.Random()
logger = get_logger(__name__)
-
+logger.warning(
+ "diffusers.utils.testing_utils' is deprecated and will be removed in a future version. "
+ "Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. "
+)
_required_peft_version = is_peft_available() and version.parse(
version.parse(importlib.metadata.version("peft")).base_version
) > version.parse("0.5")
@@ -119,6 +139,29 @@ def numpy_cosine_similarity_distance(a, b):
return distance
+def check_if_dicts_are_equal(dict1, dict2):
+ dict1, dict2 = dict1.copy(), dict2.copy()
+
+ for key, value in dict1.items():
+ if isinstance(value, set):
+ dict1[key] = sorted(value)
+ for key, value in dict2.items():
+ if isinstance(value, set):
+ dict2[key] = sorted(value)
+
+ for key in dict1:
+ if key not in dict2:
+ return False
+ if dict1[key] != dict2[key]:
+ return False
+
+ for key in dict2:
+ if key not in dict1:
+ return False
+
+ return True
+
+
def print_tensor_test(
tensor,
limit_to_slices=None,
@@ -277,6 +320,18 @@ def decorator(test_case):
return decorator
+def require_torch_version_greater(torch_version):
+ """Decorator marking a test that requires torch with a specific version greater."""
+
+ def decorator(test_case):
+ correct_torch_version = is_torch_available() and is_torch_version(">", torch_version)
+ return unittest.skipUnless(
+ correct_torch_version, f"test requires torch with the version greater than {torch_version}"
+ )(test_case)
+
+ return decorator
+
+
def require_torch_gpu(test_case):
"""Decorator marking a test that requires CUDA and PyTorch."""
return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
@@ -286,9 +341,7 @@ def require_torch_gpu(test_case):
def require_torch_cuda_compatibility(expected_compute_capability):
def decorator(test_case):
- if not torch.cuda.is_available():
- return unittest.skip(test_case)
- else:
+ if torch.cuda.is_available():
current_compute_capability = get_torch_cuda_device_capability()
return unittest.skipUnless(
float(current_compute_capability) == float(expected_compute_capability),
@@ -374,6 +427,10 @@ def require_big_accelerator(test_case):
Decorator marking a test that requires a bigger hardware accelerator (24GB) for execution. Some example pipelines:
Flux, SD3, Cog, etc.
"""
+ import pytest
+
+ test_case = pytest.mark.big_accelerator(test_case)
+
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
@@ -473,6 +530,13 @@ def require_bitsandbytes(test_case):
return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case)
+def require_quanto(test_case):
+ """
+ Decorator marking a test that requires quanto. These tests are skipped when quanto isn't installed.
+ """
+ return unittest.skipUnless(is_optimum_quanto_available(), "test requires quanto")(test_case)
+
+
def require_accelerate(test_case):
"""
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
@@ -575,6 +639,30 @@ def decorator(test_case):
return decorator
+def require_modelopt_version_greater_or_equal(modelopt_version):
+ def decorator(test_case):
+ correct_nvidia_modelopt_version = is_nvidia_modelopt_available() and version.parse(
+ version.parse(importlib.metadata.version("modelopt")).base_version
+ ) >= version.parse(modelopt_version)
+ return unittest.skipUnless(
+ correct_nvidia_modelopt_version, f"Test requires modelopt with version greater than {modelopt_version}."
+ )(test_case)
+
+ return decorator
+
+
+def require_kernels_version_greater_or_equal(kernels_version):
+ def decorator(test_case):
+ correct_kernels_version = is_kernels_available() and version.parse(
+ version.parse(importlib.metadata.version("kernels")).base_version
+ ) >= version.parse(kernels_version)
+ return unittest.skipUnless(
+ correct_kernels_version, f"Test requires kernels with version greater than {kernels_version}."
+ )(test_case)
+
+ return decorator
+
+
def deprecate_after_peft_backend(test_case):
"""
Decorator marking a test that will be skipped after PEFT backend
@@ -594,7 +682,7 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -
# local_path can be passed to correct images of tests
return Path(local_path, arry.split("/")[-5], arry.split("/")[-2], arry.split("/")[-1]).as_posix()
elif arry.startswith("http://") or arry.startswith("https://"):
- response = requests.get(arry)
+ response = requests.get(arry, timeout=DIFFUSERS_REQUEST_TIMEOUT)
response.raise_for_status()
arry = np.load(BytesIO(response.content))
elif os.path.isfile(arry):
@@ -614,10 +702,10 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -
return arry
-def load_pt(url: str, map_location: str):
- response = requests.get(url)
+def load_pt(url: str, map_location: Optional[str] = None, weights_only: Optional[bool] = True):
+ response = requests.get(url, timeout=DIFFUSERS_REQUEST_TIMEOUT)
response.raise_for_status()
- arry = torch.load(BytesIO(response.content), map_location=map_location)
+ arry = torch.load(BytesIO(response.content), map_location=map_location, weights_only=weights_only)
return arry
@@ -634,7 +722,7 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
"""
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
- image = PIL.Image.open(requests.get(image, stream=True).raw)
+ image = PIL.Image.open(requests.get(image, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw)
elif os.path.isfile(image):
image = PIL.Image.open(image)
else:
@@ -729,10 +817,9 @@ def export_to_ply(mesh, output_ply_path: str = None):
f.write(format.pack(*vertex))
if faces is not None:
- format = struct.Struct(" DeviceProperties:
+ """
+ Get environment device properties.
+ """
+ if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+ import torch
+
+ major, _ = torch.cuda.get_device_capability()
+ if IS_ROCM_SYSTEM:
+ return ("rocm", major)
+ else:
+ return ("cuda", major)
+ elif IS_XPU_SYSTEM:
+ import torch
+
+ # To get more info of the architecture meaning and bit allocation, refer to https://github.com/intel/llvm/blob/sycl/sycl/include/sycl/ext/oneapi/experimental/device_architecture.def
+ arch = torch.xpu.get_device_capability()["architecture"]
+ gen_mask = 0x000000FF00000000
+ gen = (arch & gen_mask) >> 32
+ return ("xpu", gen)
+ else:
+ return (torch_device, None)
+
+
+if TYPE_CHECKING:
+ DevicePropertiesUserDict = UserDict[DeviceProperties, Any]
+else:
+ DevicePropertiesUserDict = UserDict
+
+if is_torch_available():
+ from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
+ from diffusers.hooks.group_offloading import (
+ _GROUP_ID_LAZY_LEAF,
+ _compute_group_hash,
+ _find_parent_module_in_module_dict,
+ _gather_buffers_with_no_group_offloading_parent,
+ _gather_parameters_with_no_group_offloading_parent,
+ )
+
+ def _get_expected_safetensors_files(
+ module: torch.nn.Module,
+ offload_to_disk_path: str,
+ offload_type: str,
+ num_blocks_per_group: Optional[int] = None,
+ ) -> Set[str]:
+ expected_files = set()
+
+ def get_hashed_filename(group_id: str) -> str:
+ short_hash = _compute_group_hash(group_id)
+ return os.path.join(offload_to_disk_path, f"group_{short_hash}.safetensors")
+
+ if offload_type == "block_level":
+ if num_blocks_per_group is None:
+ raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.")
+
+ # Handle groups of ModuleList and Sequential blocks
+ unmatched_modules = []
+ for name, submodule in module.named_children():
+ if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+ unmatched_modules.append(module)
+ continue
+
+ for i in range(0, len(submodule), num_blocks_per_group):
+ current_modules = submodule[i : i + num_blocks_per_group]
+ if not current_modules:
+ continue
+ group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
+ expected_files.add(get_hashed_filename(group_id))
+
+ # Handle the group for unmatched top-level modules and parameters
+ for module in unmatched_modules:
+ expected_files.add(get_hashed_filename(f"{module.__class__.__name__}_unmatched_group"))
+
+ elif offload_type == "leaf_level":
+ # Handle leaf-level module groups
+ for name, submodule in module.named_modules():
+ if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
+ # These groups will always have parameters, so a file is expected
+ expected_files.add(get_hashed_filename(name))
+
+ # Handle groups for non-leaf parameters/buffers
+ modules_with_group_offloading = {
+ name for name, sm in module.named_modules() if isinstance(sm, _GO_LC_SUPPORTED_PYTORCH_LAYERS)
+ }
+ parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
+ buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
+
+ all_orphans = parameters + buffers
+ if all_orphans:
+ parent_to_tensors = {}
+ module_dict = dict(module.named_modules())
+ for tensor_name, _ in all_orphans:
+ parent_name = _find_parent_module_in_module_dict(tensor_name, module_dict)
+ if parent_name not in parent_to_tensors:
+ parent_to_tensors[parent_name] = []
+ parent_to_tensors[parent_name].append(tensor_name)
+
+ for parent_name in parent_to_tensors:
+ # A file is expected for each parent that gathers orphaned tensors
+ expected_files.add(get_hashed_filename(parent_name))
+ expected_files.add(get_hashed_filename(_GROUP_ID_LAZY_LEAF))
+
+ else:
+ raise ValueError(f"Unsupported offload_type: {offload_type}")
+
+ return expected_files
+
+ def _check_safetensors_serialization(
+ module: torch.nn.Module,
+ offload_to_disk_path: str,
+ offload_type: str,
+ num_blocks_per_group: Optional[int] = None,
+ ) -> bool:
+ if not os.path.isdir(offload_to_disk_path):
+ return False, None, None
+
+ expected_files = _get_expected_safetensors_files(
+ module, offload_to_disk_path, offload_type, num_blocks_per_group
+ )
+ actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors")))
+ missing_files = expected_files - actual_files
+ extra_files = actual_files - expected_files
+
+ is_correct = not missing_files and not extra_files
+ return is_correct, extra_files, missing_files
+
+
+class Expectations(DevicePropertiesUserDict):
+ def get_expectation(self) -> Any:
+ """
+ Find best matching expectation based on environment device properties.
+ """
+ return self.find_expectation(get_device_properties())
+
+ @staticmethod
+ def is_default(key: DeviceProperties) -> bool:
+ return all(p is None for p in key)
+
+ @staticmethod
+ def score(key: DeviceProperties, other: DeviceProperties) -> int:
+ """
+ Returns score indicating how similar two instances of the `Properties` tuple are. Points are calculated using
+ bits, but documented as int. Rules are as follows:
+ * Matching `type` gives 8 points.
+ * Semi-matching `type`, for example cuda and rocm, gives 4 points.
+ * Matching `major` (compute capability major version) gives 2 points.
+ * Default expectation (if present) gives 1 points.
+ """
+ (device_type, major) = key
+ (other_device_type, other_major) = other
+
+ score = 0b0
+ if device_type == other_device_type:
+ score |= 0b1000
+ elif device_type in ["cuda", "rocm"] and other_device_type in ["cuda", "rocm"]:
+ score |= 0b100
+
+ if major == other_major and other_major is not None:
+ score |= 0b10
+
+ if Expectations.is_default(other):
+ score |= 0b1
+
+ return int(score)
+
+ def find_expectation(self, key: DeviceProperties = (None, None)) -> Any:
+ """
+ Find best matching expectation based on provided device properties.
+ """
+ (result_key, result) = max(self.data.items(), key=lambda x: Expectations.score(key, x[0]))
+
+ if Expectations.score(key, result_key) == 0:
+ raise ValueError(f"No matching expectation found for {key}")
+
+ return result
+
+ def __repr__(self):
+ return f"{self.data}"
diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index 06be5cb961ac..c5273ddd9e9f 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,16 +14,68 @@
"""
PyTorch utilities: Utilities related to PyTorch
"""
+
import functools
-from typing import List, Optional, Tuple, Union
+import os
+from typing import Callable, Dict, List, Optional, Tuple, Union
from . import logging
-from .import_utils import is_torch_available, is_torch_version
+from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version
if is_torch_available():
import torch
from torch.fft import fftn, fftshift, ifftn, ifftshift
+ BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}
+ BACKEND_EMPTY_CACHE = {
+ "cuda": torch.cuda.empty_cache,
+ "xpu": torch.xpu.empty_cache,
+ "cpu": None,
+ "mps": torch.mps.empty_cache,
+ "default": None,
+ }
+ BACKEND_DEVICE_COUNT = {
+ "cuda": torch.cuda.device_count,
+ "xpu": torch.xpu.device_count,
+ "cpu": lambda: 0,
+ "mps": lambda: 0,
+ "default": 0,
+ }
+ BACKEND_MANUAL_SEED = {
+ "cuda": torch.cuda.manual_seed,
+ "xpu": torch.xpu.manual_seed,
+ "cpu": torch.manual_seed,
+ "mps": torch.mps.manual_seed,
+ "default": torch.manual_seed,
+ }
+ BACKEND_RESET_PEAK_MEMORY_STATS = {
+ "cuda": torch.cuda.reset_peak_memory_stats,
+ "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
+ "cpu": None,
+ "mps": None,
+ "default": None,
+ }
+ BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
+ "cuda": torch.cuda.reset_max_memory_allocated,
+ "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
+ "cpu": None,
+ "mps": None,
+ "default": None,
+ }
+ BACKEND_MAX_MEMORY_ALLOCATED = {
+ "cuda": torch.cuda.max_memory_allocated,
+ "xpu": getattr(torch.xpu, "max_memory_allocated", None),
+ "cpu": 0,
+ "mps": 0,
+ "default": 0,
+ }
+ BACKEND_SYNCHRONIZE = {
+ "cuda": torch.cuda.synchronize,
+ "xpu": getattr(torch.xpu, "synchronize", None),
+ "cpu": None,
+ "mps": None,
+ "default": None,
+ }
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
try:
@@ -34,10 +86,66 @@ def maybe_allow_in_graph(cls):
return cls
+# This dispatches a defined function according to the accelerator from the function definitions.
+def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs):
+ if device not in dispatch_table:
+ return dispatch_table["default"](*args, **kwargs)
+
+ fn = dispatch_table[device]
+
+ # Some device agnostic functions return values. Need to guard against 'None' instead at
+ # user level
+ if not callable(fn):
+ return fn
+
+ return fn(*args, **kwargs)
+
+
+# These are callables which automatically dispatch the function specific to the accelerator
+def backend_manual_seed(device: str, seed: int):
+ return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
+
+
+def backend_synchronize(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)
+
+
+def backend_empty_cache(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
+
+
+def backend_device_count(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
+
+
+def backend_reset_peak_memory_stats(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)
+
+
+def backend_reset_max_memory_allocated(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
+
+
+def backend_max_memory_allocated(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
+
+
+# These are callables which return boolean behaviour flags and can be used to specify some
+# device agnostic alternative where the feature is unsupported.
+def backend_supports_training(device: str):
+ if not is_torch_available():
+ return False
+
+ if device not in BACKEND_SUPPORTS_TRAINING:
+ device = "default"
+
+ return BACKEND_SUPPORTS_TRAINING[device]
+
+
def randn_tensor(
shape: Union[Tuple, List],
generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
- device: Optional["torch.device"] = None,
+ device: Optional[Union[str, "torch.device"]] = None,
dtype: Optional["torch.dtype"] = None,
layout: Optional["torch.layout"] = None,
):
@@ -46,6 +154,8 @@ def randn_tensor(
is always created on the CPU.
"""
# device on which tensor is created defaults to device
+ if isinstance(device, str):
+ device = torch.device(device)
rand_device = device
batch_size = shape[0]
@@ -64,7 +174,7 @@ def randn_tensor(
logger.info(
f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
- f" slighly speed up this function by passing a generator that was created on the {device} device."
+ f" slightly speed up this function by passing a generator that was created on the {device} device."
)
elif gen_device_type != device.type and gen_device_type == "cuda":
raise ValueError(
@@ -103,8 +213,13 @@ def is_compiled_module(module) -> bool:
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
+def unwrap_module(module):
+ """Unwraps a module if it was compiled with torch.compile()"""
+ return module._orig_mod if is_compiled_module(module) else module
+
+
def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
- """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
+ """Fourier filter as introduced in FreeU (https://huggingface.co/papers/2309.11497).
This version of the method comes from here:
https://github.com/huggingface/diffusers/pull/5164#issuecomment-1732638706
@@ -145,8 +260,8 @@ def apply_freeu(
res_hidden_states: "torch.Tensor",
**freeu_kwargs,
) -> Tuple["torch.Tensor", "torch.Tensor"]:
- """Applies the FreeU mechanism as introduced in https:
- //arxiv.org/abs/2309.11497. Adapted from the official code repository: https://github.com/ChenyangSi/FreeU.
+ """Applies the FreeU mechanism as introduced in https://huggingface.co/papers/2309.11497. Adapted from the official
+ code repository: https://github.com/ChenyangSi/FreeU.
Args:
resolution_idx (`int`): Integer denoting the UNet block where FreeU is being applied.
@@ -197,5 +312,51 @@ def get_device():
return "xpu"
elif torch.backends.mps.is_available():
return "mps"
+ elif is_torch_mlu_available():
+ return "mlu"
else:
return "cpu"
+
+
+def empty_device_cache(device_type: Optional[str] = None):
+ if device_type is None:
+ device_type = get_device()
+ if device_type in ["cpu"]:
+ return
+ device_mod = getattr(torch, device_type, torch.cuda)
+ device_mod.empty_cache()
+
+
+def device_synchronize(device_type: Optional[str] = None):
+ if device_type is None:
+ device_type = get_device()
+ device_mod = getattr(torch, device_type, torch.cuda)
+ device_mod.synchronize()
+
+
+def enable_full_determinism():
+ """
+ Helper function for reproducible behavior during distributed training. See
+ - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
+ """
+ # Enable PyTorch deterministic mode. This potentially requires either the environment
+ # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
+ # depending on the CUDA version, so we set them both here
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+ torch.use_deterministic_algorithms(True)
+
+ # Enable CUDNN deterministic mode
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cuda.matmul.allow_tf32 = False
+
+
+def disable_full_determinism():
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
+ torch.use_deterministic_algorithms(False)
+
+
+if is_torch_available():
+ torch_device = get_device()
diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py
index 2da782b463d4..abeb30bca102 100644
--- a/src/diffusers/video_processor.py
+++ b/src/diffusers/video_processor.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,11 +13,12 @@
# limitations under the License.
import warnings
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
import numpy as np
import PIL
import torch
+import torch.nn.functional as F
from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist
@@ -67,7 +68,7 @@ def preprocess_video(self, video, height: Optional[int] = None, width: Optional[
# ensure the input is a list of videos:
# - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray)
- # - if it is a single video, it is convereted to a list of one video.
+ # - if it is a single video, it is converted to a list of one video.
if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5:
video = list(video)
elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video):
@@ -111,3 +112,65 @@ def postprocess_video(
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
return outputs
+
+ @staticmethod
+ def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
+ r"""
+ Returns the binned height and width based on the aspect ratio.
+
+ Args:
+ height (`int`): The height of the image.
+ width (`int`): The width of the image.
+ ratios (`dict`): A dictionary where keys are aspect ratios and values are tuples of (height, width).
+
+ Returns:
+ `Tuple[int, int]`: The closest binned height and width.
+ """
+ ar = float(height / width)
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
+ default_hw = ratios[closest_ratio]
+ return int(default_hw[0]), int(default_hw[1])
+
+ @staticmethod
+ def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
+ r"""
+ Resizes and crops a tensor of videos to the specified dimensions.
+
+ Args:
+ samples (`torch.Tensor`):
+ A tensor of shape (N, C, T, H, W) where N is the batch size, C is the number of channels, T is the
+ number of frames, H is the height, and W is the width.
+ new_width (`int`): The desired width of the output videos.
+ new_height (`int`): The desired height of the output videos.
+
+ Returns:
+ `torch.Tensor`: A tensor containing the resized and cropped videos.
+ """
+ orig_height, orig_width = samples.shape[3], samples.shape[4]
+
+ # Check if resizing is needed
+ if orig_height != new_height or orig_width != new_width:
+ ratio = max(new_height / orig_height, new_width / orig_width)
+ resized_width = int(orig_width * ratio)
+ resized_height = int(orig_height * ratio)
+
+ # Reshape to (N*T, C, H, W) for interpolation
+ n, c, t, h, w = samples.shape
+ samples = samples.permute(0, 2, 1, 3, 4).reshape(n * t, c, h, w)
+
+ # Resize
+ samples = F.interpolate(
+ samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
+ )
+
+ # Center Crop
+ start_x = (resized_width - new_width) // 2
+ end_x = start_x + new_width
+ start_y = (resized_height - new_height) // 2
+ end_y = start_y + new_height
+ samples = samples[:, :, start_y:end_y, start_x:end_x]
+
+ # Reshape back to (N, C, T, H, W)
+ samples = samples.reshape(n, t, c, new_height, new_width).permute(0, 2, 1, 3, 4)
+
+ return samples
diff --git a/tests/conftest.py b/tests/conftest.py
index 4993ed94e8f1..fd76d1c84ee7 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -30,14 +30,18 @@
warnings.simplefilter(action="ignore", category=FutureWarning)
+def pytest_configure(config):
+ config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
+
+
def pytest_addoption(parser):
- from diffusers.utils.testing_utils import pytest_addoption_shared
+ from .testing_utils import pytest_addoption_shared
pytest_addoption_shared(parser)
def pytest_terminal_summary(terminalreporter):
- from diffusers.utils.testing_utils import pytest_terminal_summary_main
+ from .testing_utils import pytest_terminal_summary_main
make_reports = terminalreporter.config.getoption("--make-reports")
if make_reports:
diff --git a/tests/fixtures/custom_pipeline/pipeline.py b/tests/fixtures/custom_pipeline/pipeline.py
index e197cb6859fa..25673e566549 100644
--- a/tests/fixtures/custom_pipeline/pipeline.py
+++ b/tests/fixtures/custom_pipeline/pipeline.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -10,6 +10,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
# limitations under the License.
diff --git a/tests/fixtures/custom_pipeline/what_ever.py b/tests/fixtures/custom_pipeline/what_ever.py
index bbe7f4f16bd8..7504940780e8 100644
--- a/tests/fixtures/custom_pipeline/what_ever.py
+++ b/tests/fixtures/custom_pipeline/what_ever.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -10,6 +10,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
# limitations under the License.
diff --git a/tests/pipelines/audioldm/__init__.py b/tests/hooks/__init__.py
similarity index 100%
rename from tests/pipelines/audioldm/__init__.py
rename to tests/hooks/__init__.py
diff --git a/tests/hooks/test_group_offloading.py b/tests/hooks/test_group_offloading.py
index d8f41fc2b1ae..236094109d07 100644
--- a/tests/hooks/test_group_offloading.py
+++ b/tests/hooks/test_group_offloading.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,15 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import contextlib
import gc
import unittest
import torch
+from parameterized import parameterized
+from diffusers import AutoencoderKL
+from diffusers.hooks import HookRegistry, ModelHook
from diffusers.models import ModelMixin
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import get_logger
-from diffusers.utils.testing_utils import require_torch_gpu, torch_device
+from diffusers.utils.import_utils import compare_versions
+
+from ..testing_utils import (
+ backend_empty_cache,
+ backend_max_memory_allocated,
+ backend_reset_peak_memory_stats,
+ require_torch_accelerator,
+ torch_device,
+)
class DummyBlock(torch.nn.Module):
@@ -58,6 +70,62 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x
+# This model implementation contains one type of block (single_blocks) instantiated before another type of block (double_blocks).
+# The invocation order of these blocks, however, is first the double_blocks and then the single_blocks.
+# With group offloading implementation before https://github.com/huggingface/diffusers/pull/11375, such a modeling implementation
+# would result in a device mismatch error because of the assumptions made by the code. The failure case occurs when using:
+# offload_type="block_level", num_blocks_per_group=2, use_stream=True
+# Post the linked PR, the implementation will work as expected.
+class DummyModelWithMultipleBlocks(ModelMixin):
+ def __init__(
+ self, in_features: int, hidden_features: int, out_features: int, num_layers: int, num_single_layers: int
+ ) -> None:
+ super().__init__()
+
+ self.linear_1 = torch.nn.Linear(in_features, hidden_features)
+ self.activation = torch.nn.ReLU()
+ self.single_blocks = torch.nn.ModuleList(
+ [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_single_layers)]
+ )
+ self.double_blocks = torch.nn.ModuleList(
+ [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
+ )
+ self.linear_2 = torch.nn.Linear(hidden_features, out_features)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.linear_1(x)
+ x = self.activation(x)
+ for block in self.double_blocks:
+ x = block(x)
+ for block in self.single_blocks:
+ x = block(x)
+ x = self.linear_2(x)
+ return x
+
+
+# Test for https://github.com/huggingface/diffusers/pull/12077
+class DummyModelWithLayerNorm(ModelMixin):
+ def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
+ super().__init__()
+
+ self.linear_1 = torch.nn.Linear(in_features, hidden_features)
+ self.activation = torch.nn.ReLU()
+ self.blocks = torch.nn.ModuleList(
+ [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
+ )
+ self.layer_norm = torch.nn.LayerNorm(hidden_features, elementwise_affine=True)
+ self.linear_2 = torch.nn.Linear(hidden_features, out_features)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.linear_1(x)
+ x = self.activation(x)
+ for block in self.blocks:
+ x = block(x)
+ x = self.layer_norm(x)
+ x = self.linear_2(x)
+ return x
+
+
class DummyPipeline(DiffusionPipeline):
model_cpu_offload_seq = "model"
@@ -72,7 +140,85 @@ def __call__(self, x: torch.Tensor) -> torch.Tensor:
return x
-@require_torch_gpu
+class LayerOutputTrackerHook(ModelHook):
+ def __init__(self):
+ super().__init__()
+ self.outputs = []
+
+ def post_forward(self, module, output):
+ self.outputs.append(output)
+ return output
+
+
+# Model with only standalone computational layers at top level
+class DummyModelWithStandaloneLayers(ModelMixin):
+ def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
+ super().__init__()
+
+ self.layer1 = torch.nn.Linear(in_features, hidden_features)
+ self.activation = torch.nn.ReLU()
+ self.layer2 = torch.nn.Linear(hidden_features, hidden_features)
+ self.layer3 = torch.nn.Linear(hidden_features, out_features)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.layer1(x)
+ x = self.activation(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ return x
+
+
+# Model with deeply nested structure
+class DummyModelWithDeeplyNestedBlocks(ModelMixin):
+ def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None:
+ super().__init__()
+
+ self.input_layer = torch.nn.Linear(in_features, hidden_features)
+ self.container = ContainerWithNestedModuleList(hidden_features)
+ self.output_layer = torch.nn.Linear(hidden_features, out_features)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.input_layer(x)
+ x = self.container(x)
+ x = self.output_layer(x)
+ return x
+
+
+class ContainerWithNestedModuleList(torch.nn.Module):
+ def __init__(self, features: int) -> None:
+ super().__init__()
+
+ # Top-level computational layer
+ self.proj_in = torch.nn.Linear(features, features)
+
+ # Nested container with ModuleList
+ self.nested_container = NestedContainer(features)
+
+ # Another top-level computational layer
+ self.proj_out = torch.nn.Linear(features, features)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.proj_in(x)
+ x = self.nested_container(x)
+ x = self.proj_out(x)
+ return x
+
+
+class NestedContainer(torch.nn.Module):
+ def __init__(self, features: int) -> None:
+ super().__init__()
+
+ self.blocks = torch.nn.ModuleList([torch.nn.Linear(features, features), torch.nn.Linear(features, features)])
+ self.norm = torch.nn.LayerNorm(features)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for block in self.blocks:
+ x = block(x)
+ x = self.norm(x)
+ return x
+
+
+@require_torch_accelerator
class GroupOffloadTests(unittest.TestCase):
in_features = 64
hidden_features = 256
@@ -90,8 +236,8 @@ def tearDown(self):
del self.model
del self.input
gc.collect()
- torch.cuda.empty_cache()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
def get_model(self):
torch.manual_seed(0)
@@ -106,8 +252,8 @@ def test_offloading_forward_pass(self):
@torch.no_grad()
def run_forward(model):
gc.collect()
- torch.cuda.empty_cache()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
self.assertTrue(
all(
module._diffusers_hook.get_hook("group_offloading") is not None
@@ -117,7 +263,7 @@ def run_forward(model):
)
model.eval()
output = model(self.input)[0].cpu()
- max_memory_allocated = torch.cuda.max_memory_allocated()
+ max_memory_allocated = backend_max_memory_allocated(torch_device)
return output, max_memory_allocated
self.model.to(torch_device)
@@ -152,10 +298,10 @@ def run_forward(model):
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5))
# Memory assertions - offloading should reduce memory usage
- self.assertTrue(mem4 <= mem5 < mem2 < mem3 < mem1 < mem_baseline)
+ self.assertTrue(mem4 <= mem5 < mem2 <= mem3 < mem1 < mem_baseline)
- def test_warning_logged_if_group_offloaded_module_moved_to_cuda(self):
- if torch.device(torch_device).type != "cuda":
+ def test_warning_logged_if_group_offloaded_module_moved_to_accelerator(self):
+ if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
logger = get_logger("diffusers.models.modeling_utils")
@@ -164,8 +310,8 @@ def test_warning_logged_if_group_offloaded_module_moved_to_cuda(self):
self.model.to(torch_device)
self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
- def test_warning_logged_if_group_offloaded_pipe_moved_to_cuda(self):
- if torch.device(torch_device).type != "cuda":
+ def test_warning_logged_if_group_offloaded_pipe_moved_to_accelerator(self):
+ if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
pipe = DummyPipeline(self.model)
self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
@@ -175,19 +321,20 @@ def test_warning_logged_if_group_offloaded_pipe_moved_to_cuda(self):
pipe.to(torch_device)
self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
- def test_error_raised_if_streams_used_and_no_cuda_device(self):
- original_is_available = torch.cuda.is_available
- torch.cuda.is_available = lambda: False
+ def test_error_raised_if_streams_used_and_no_accelerator_device(self):
+ torch_accelerator_module = getattr(torch, torch_device, torch.cuda)
+ original_is_available = torch_accelerator_module.is_available
+ torch_accelerator_module.is_available = lambda: False
with self.assertRaises(ValueError):
self.model.enable_group_offload(
- onload_device=torch.device("cuda"), offload_type="leaf_level", use_stream=True
+ onload_device=torch.device(torch_device), offload_type="leaf_level", use_stream=True
)
- torch.cuda.is_available = original_is_available
+ torch_accelerator_module.is_available = original_is_available
def test_error_raised_if_supports_group_offloading_false(self):
self.model._supports_group_offloading = False
with self.assertRaisesRegex(ValueError, "does not support group offloading"):
- self.model.enable_group_offload(onload_device=torch.device("cuda"))
+ self.model.enable_group_offload(onload_device=torch.device(torch_device))
def test_error_raised_if_model_offloading_applied_on_group_offloaded_module(self):
pipe = DummyPipeline(self.model)
@@ -212,3 +359,210 @@ def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module
pipe.enable_sequential_cpu_offload()
with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"):
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
+
+ def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
+ if torch.device(torch_device).type not in ["cuda", "xpu"]:
+ return
+
+ model = DummyModelWithMultipleBlocks(
+ in_features=self.in_features,
+ hidden_features=self.hidden_features,
+ out_features=self.out_features,
+ num_layers=self.num_layers,
+ num_single_layers=self.num_layers + 1,
+ )
+ model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)
+
+ context = contextlib.nullcontext()
+ if compare_versions("diffusers", "<=", "0.33.0"):
+ # Will raise a device mismatch RuntimeError mentioning weights are on CPU but input is on device
+ context = self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device")
+
+ with context:
+ model(self.input)
+
+ @parameterized.expand([("block_level",), ("leaf_level",)])
+ def test_block_level_offloading_with_parameter_only_module_group(self, offload_type: str):
+ if torch.device(torch_device).type not in ["cuda", "xpu"]:
+ return
+
+ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
+ for name, module in model.named_modules():
+ registry = HookRegistry.check_if_exists_or_initialize(module)
+ hook = LayerOutputTrackerHook()
+ registry.register_hook(hook, "layer_output_tracker")
+
+ model_ref = DummyModelWithLayerNorm(128, 256, 128, 2)
+ model = DummyModelWithLayerNorm(128, 256, 128, 2)
+
+ model.load_state_dict(model_ref.state_dict(), strict=True)
+
+ model_ref.to(torch_device)
+ model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True)
+
+ apply_layer_output_tracker_hook(model_ref)
+ apply_layer_output_tracker_hook(model)
+
+ x = torch.randn(2, 128).to(torch_device)
+
+ out_ref = model_ref(x)
+ out = model(x)
+ self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.")
+
+ num_repeats = 2
+ for i in range(num_repeats):
+ out_ref = model_ref(x)
+ out = model(x)
+
+ self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations.")
+
+ for (ref_name, ref_module), (name, module) in zip(model_ref.named_modules(), model.named_modules()):
+ assert ref_name == name
+ ref_outputs = (
+ HookRegistry.check_if_exists_or_initialize(ref_module).get_hook("layer_output_tracker").outputs
+ )
+ outputs = HookRegistry.check_if_exists_or_initialize(module).get_hook("layer_output_tracker").outputs
+ cumulated_absmax = 0.0
+ for i in range(len(outputs)):
+ diff = ref_outputs[0] - outputs[i]
+ absdiff = diff.abs()
+ absmax = absdiff.max().item()
+ cumulated_absmax += absmax
+ self.assertLess(
+ cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
+ )
+
+ def test_vae_like_model_without_streams(self):
+ """Test VAE-like model with block-level offloading but without streams."""
+ if torch.device(torch_device).type not in ["cuda", "xpu"]:
+ return
+
+ config = self.get_autoencoder_kl_config()
+ model = AutoencoderKL(**config)
+
+ model_ref = AutoencoderKL(**config)
+ model_ref.load_state_dict(model.state_dict(), strict=True)
+ model_ref.to(torch_device)
+
+ model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=False)
+
+ x = torch.randn(2, 3, 32, 32).to(torch_device)
+
+ with torch.no_grad():
+ out_ref = model_ref(x).sample
+ out = model(x).sample
+
+ self.assertTrue(
+ torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams."
+ )
+
+ def test_model_with_only_standalone_layers(self):
+ """Test that models with only standalone layers (no ModuleList/Sequential) work with block-level offloading."""
+ if torch.device(torch_device).type not in ["cuda", "xpu"]:
+ return
+
+ model = DummyModelWithStandaloneLayers(in_features=64, hidden_features=128, out_features=64)
+
+ model_ref = DummyModelWithStandaloneLayers(in_features=64, hidden_features=128, out_features=64)
+ model_ref.load_state_dict(model.state_dict(), strict=True)
+ model_ref.to(torch_device)
+
+ model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)
+
+ x = torch.randn(2, 64).to(torch_device)
+
+ with torch.no_grad():
+ for i in range(2):
+ out_ref = model_ref(x)
+ out = model(x)
+ self.assertTrue(
+ torch.allclose(out_ref, out, atol=1e-5),
+ f"Outputs do not match at iteration {i} for model with standalone layers.",
+ )
+
+ @parameterized.expand([("block_level",), ("leaf_level",)])
+ def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str):
+ """Test that standalone Conv2d layers work correctly with both block-level and leaf-level offloading."""
+ if torch.device(torch_device).type not in ["cuda", "xpu"]:
+ return
+
+ config = self.get_autoencoder_kl_config()
+ model = AutoencoderKL(**config)
+
+ model_ref = AutoencoderKL(**config)
+ model_ref.load_state_dict(model.state_dict(), strict=True)
+ model_ref.to(torch_device)
+
+ model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True)
+
+ x = torch.randn(2, 3, 32, 32).to(torch_device)
+
+ with torch.no_grad():
+ out_ref = model_ref(x).sample
+ out = model(x).sample
+
+ self.assertTrue(
+ torch.allclose(out_ref, out, atol=1e-5),
+ f"Outputs do not match for standalone Conv layers with {offload_type}.",
+ )
+
+ def test_multiple_invocations_with_vae_like_model(self):
+ """Test that multiple forward passes work correctly with VAE-like model."""
+ if torch.device(torch_device).type not in ["cuda", "xpu"]:
+ return
+
+ config = self.get_autoencoder_kl_config()
+ model = AutoencoderKL(**config)
+
+ model_ref = AutoencoderKL(**config)
+ model_ref.load_state_dict(model.state_dict(), strict=True)
+ model_ref.to(torch_device)
+
+ model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)
+
+ x = torch.randn(2, 3, 32, 32).to(torch_device)
+
+ with torch.no_grad():
+ for i in range(2):
+ out_ref = model_ref(x).sample
+ out = model(x).sample
+ self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}.")
+
+ def test_nested_container_parameters_offloading(self):
+ """Test that parameters from non-computational layers in nested containers are handled correctly."""
+ if torch.device(torch_device).type not in ["cuda", "xpu"]:
+ return
+
+ model = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64)
+
+ model_ref = DummyModelWithDeeplyNestedBlocks(in_features=64, hidden_features=128, out_features=64)
+ model_ref.load_state_dict(model.state_dict(), strict=True)
+ model_ref.to(torch_device)
+
+ model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True)
+
+ x = torch.randn(2, 64).to(torch_device)
+
+ with torch.no_grad():
+ for i in range(2):
+ out_ref = model_ref(x)
+ out = model(x)
+ self.assertTrue(
+ torch.allclose(out_ref, out, atol=1e-5),
+ f"Outputs do not match at iteration {i} for nested parameters.",
+ )
+
+ def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
+ block_out_channels = block_out_channels or [2, 4]
+ norm_num_groups = norm_num_groups or 2
+ init_dict = {
+ "block_out_channels": block_out_channels,
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels),
+ "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels),
+ "latent_channels": 4,
+ "norm_num_groups": norm_num_groups,
+ "layers_per_block": 1,
+ }
+ return init_dict
diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py
index 74bd43c52315..8a83f60ff278 100644
--- a/tests/hooks/test_hooks.py
+++ b/tests/hooks/test_hooks.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,7 +20,8 @@
from diffusers.hooks import HookRegistry, ModelHook
from diffusers.training_utils import free_memory
from diffusers.utils.logging import get_logger
-from diffusers.utils.testing_utils import CaptureLogger, torch_device
+
+from ..testing_utils import CaptureLogger, torch_device
logger = get_logger(__name__) # pylint: disable=invalid-name
@@ -168,9 +169,7 @@ def test_hook_registry(self):
registry.register_hook(MultiplyHook(2), "multiply_hook")
registry_repr = repr(registry)
- expected_repr = (
- "HookRegistry(\n" " (0) add_hook - AddHook\n" " (1) multiply_hook - MultiplyHook(value=2)\n" ")"
- )
+ expected_repr = "HookRegistry(\n (0) add_hook - AddHook\n (1) multiply_hook - MultiplyHook(value=2)\n)"
self.assertEqual(len(registry.hooks), 2)
self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"])
@@ -221,6 +220,7 @@ def test_inference(self):
self.assertAlmostEqual(output1, output2, places=5)
self.assertAlmostEqual(output1, output3, places=5)
+ self.assertAlmostEqual(output2, output3, places=5)
def test_skip_layer_hook(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
@@ -285,12 +285,7 @@ def test_invocation_order_stateful_first(self):
self.model(input)
output = cap_logger.out.replace(" ", "").replace("\n", "")
expected_invocation_order_log = (
- (
- "MultiplyHook pre_forward\n"
- "AddHook pre_forward\n"
- "AddHook post_forward\n"
- "MultiplyHook post_forward\n"
- )
+ ("MultiplyHook pre_forward\nAddHook pre_forward\nAddHook post_forward\nMultiplyHook post_forward\n")
.replace(" ", "")
.replace("\n", "")
)
diff --git a/tests/pipelines/blipdiffusion/__init__.py b/tests/lora/__init__.py
similarity index 100%
rename from tests/pipelines/blipdiffusion/__init__.py
rename to tests/lora/__init__.py
diff --git a/tests/lora/test_deprecated_utilities.py b/tests/lora/test_deprecated_utilities.py
deleted file mode 100644
index 4275ef8089a3..000000000000
--- a/tests/lora/test_deprecated_utilities.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import os
-import tempfile
-import unittest
-
-import torch
-
-from diffusers.loaders.lora_base import LoraBaseMixin
-
-
-class UtilityMethodDeprecationTests(unittest.TestCase):
- def test_fetch_state_dict_cls_method_raises_warning(self):
- state_dict = torch.nn.Linear(3, 3).state_dict()
- with self.assertWarns(FutureWarning) as warning:
- _ = LoraBaseMixin._fetch_state_dict(
- state_dict,
- weight_name=None,
- use_safetensors=False,
- local_files_only=True,
- cache_dir=None,
- force_download=False,
- proxies=None,
- token=None,
- revision=None,
- subfolder=None,
- user_agent=None,
- allow_pickle=None,
- )
- warning_message = str(warning.warnings[0].message)
- assert "Using the `_fetch_state_dict()` method from" in warning_message
-
- def test_best_guess_weight_name_cls_method_raises_warning(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- state_dict = torch.nn.Linear(3, 3).state_dict()
- torch.save(state_dict, os.path.join(tmpdir, "pytorch_lora_weights.bin"))
-
- with self.assertWarns(FutureWarning) as warning:
- _ = LoraBaseMixin._best_guess_weight_name(pretrained_model_name_or_path_or_dict=tmpdir)
- warning_message = str(warning.warnings[0].message)
- assert "Using the `_best_guess_weight_name()` method from" in warning_message
diff --git a/tests/lora/test_lora_layers_auraflow.py b/tests/lora/test_lora_layers_auraflow.py
new file mode 100644
index 000000000000..91f63c4b56c4
--- /dev/null
+++ b/tests/lora/test_lora_layers_auraflow.py
@@ -0,0 +1,136 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+import unittest
+
+import torch
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from diffusers import (
+ AuraFlowPipeline,
+ AuraFlowTransformer2DModel,
+ FlowMatchEulerDiscreteScheduler,
+)
+
+from ..testing_utils import (
+ floats_tensor,
+ is_peft_available,
+ require_peft_backend,
+)
+
+
+if is_peft_available():
+ pass
+
+sys.path.append(".")
+
+from .utils import PeftLoraLoaderMixinTests # noqa: E402
+
+
+@require_peft_backend
+class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+ pipeline_class = AuraFlowPipeline
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
+ scheduler_kwargs = {}
+
+ transformer_kwargs = {
+ "sample_size": 64,
+ "patch_size": 1,
+ "in_channels": 4,
+ "num_mmdit_layers": 1,
+ "num_single_dit_layers": 1,
+ "attention_head_dim": 16,
+ "num_attention_heads": 2,
+ "joint_attention_dim": 32,
+ "caption_projection_dim": 32,
+ "pos_embed_max_size": 64,
+ }
+ transformer_cls = AuraFlowTransformer2DModel
+ vae_kwargs = {
+ "sample_size": 32,
+ "in_channels": 3,
+ "out_channels": 3,
+ "block_out_channels": (4,),
+ "layers_per_block": 1,
+ "latent_channels": 4,
+ "norm_num_groups": 1,
+ "use_quant_conv": False,
+ "use_post_quant_conv": False,
+ "shift_factor": 0.0609,
+ "scaling_factor": 1.5035,
+ }
+ tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
+ text_encoder_cls, text_encoder_id = UMT5EncoderModel, "hf-internal-testing/tiny-random-umt5"
+ text_encoder_target_modules = ["q", "k", "v", "o"]
+ denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0", "linear_1"]
+
+ @property
+ def output_shape(self):
+ return (1, 8, 8, 3)
+
+ def get_dummy_inputs(self, with_generator=True):
+ batch_size = 1
+ sequence_length = 10
+ num_channels = 4
+ sizes = (32, 32)
+
+ generator = torch.manual_seed(0)
+ noise = floats_tensor((batch_size, num_channels) + sizes)
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+ pipeline_inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "num_inference_steps": 4,
+ "guidance_scale": 0.0,
+ "height": 8,
+ "width": 8,
+ "output_type": "np",
+ }
+ if with_generator:
+ pipeline_inputs.update({"generator": generator})
+
+ return noise, input_ids, pipeline_inputs
+
+ @unittest.skip("Not supported in AuraFlow.")
+ def test_simple_inference_with_text_denoiser_block_scale(self):
+ pass
+
+ @unittest.skip("Not supported in AuraFlow.")
+ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+ pass
+
+ @unittest.skip("Not supported in AuraFlow.")
+ def test_modify_padding_mode(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
+ def test_simple_inference_with_partial_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
+ def test_simple_inference_with_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
+ def test_simple_inference_with_text_lora_and_scale(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
+ def test_simple_inference_with_text_lora_fused(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
+ def test_simple_inference_with_text_lora_save_load(self):
+ pass
diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py
index dc2695452c2f..fa57b4c9c2f9 100644
--- a/tests/lora/test_lora_layers_cogvideox.py
+++ b/tests/lora/test_lora_layers_cogvideox.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,24 +16,26 @@
import unittest
import torch
+from parameterized import parameterized
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKLCogVideoX,
- CogVideoXDDIMScheduler,
CogVideoXDPMScheduler,
CogVideoXPipeline,
CogVideoXTransformer3DModel,
)
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
floats_tensor,
require_peft_backend,
+ require_torch_accelerator,
)
sys.path.append(".")
-from utils import PeftLoraLoaderMixinTests # noqa: E402
+from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@@ -41,7 +43,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = CogVideoXPipeline
scheduler_cls = CogVideoXDPMScheduler
scheduler_kwargs = {"timestep_spacing": "trailing"}
- scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler]
transformer_kwargs = {
"num_attention_heads": 4,
@@ -124,6 +125,16 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
+ def test_lora_scale_kwargs_match_fusion(self):
+ super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)
+
+ @parameterized.expand([("block_level", True), ("leaf_level", False)])
+ @require_torch_accelerator
+ def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+ # TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
+ # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
+ super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
+
@unittest.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py
index 178de2069b7e..30eb8fbb6367 100644
--- a/tests/lora/test_lora_layers_cogview4.py
+++ b/tests/lora/test_lora_layers_cogview4.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,15 +18,23 @@
import numpy as np
import torch
+from parameterized import parameterized
from transformers import AutoTokenizer, GlmModel
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device
+
+from ..testing_utils import (
+ floats_tensor,
+ require_peft_backend,
+ require_torch_accelerator,
+ skip_mps,
+ torch_device,
+)
sys.path.append(".")
-from utils import PeftLoraLoaderMixinTests # noqa: E402
+from .utils import PeftLoraLoaderMixinTests # noqa: E402
class TokenizerWrapper:
@@ -42,7 +50,6 @@ def from_pretrained(*args, **kwargs):
class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = CogView4Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
- scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
@@ -116,30 +123,33 @@ def test_simple_inference_save_pretrained(self):
"""
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
"""
- for scheduler_cls in self.scheduler_classes:
- components, _, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, _, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ pipe.save_pretrained(tmpdirname)
- with tempfile.TemporaryDirectory() as tmpdirname:
- pipe.save_pretrained(tmpdirname)
+ pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
+ pipe_from_pretrained.to(torch_device)
- pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
- pipe_from_pretrained.to(torch_device)
+ images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
- images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints should give same results.",
+ )
- self.assertTrue(
- np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
- "Loading from saved checkpoints should give same results.",
- )
+ @parameterized.expand([("block_level", True), ("leaf_level", False)])
+ @require_torch_accelerator
+ def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+ # TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
+ # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
+ super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
@unittest.skip("Not supported in CogView4.")
def test_simple_inference_with_text_denoiser_block_scale(self):
diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py
index 860aa6511689..b840d7ac72ce 100644
--- a/tests/lora/test_lora_layers_flux.py
+++ b/tests/lora/test_lora_layers_flux.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,7 +20,6 @@
import unittest
import numpy as np
-import pytest
import safetensors.torch
import torch
from parameterized import parameterized
@@ -29,15 +28,17 @@
from diffusers import FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxPipeline, FluxTransformer2DModel
from diffusers.utils import load_image, logging
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
CaptureLogger,
+ backend_empty_cache,
floats_tensor,
is_peft_available,
nightly,
numpy_cosine_similarity_distance,
- require_big_gpu_with_torch_cuda,
+ require_big_accelerator,
require_peft_backend,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -48,15 +49,14 @@
sys.path.append(".")
-from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
+from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
@require_peft_backend
class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = FluxPipeline
- scheduler_cls = FlowMatchEulerDiscreteScheduler()
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
- scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = {
"patch_size": 1,
"in_channels": 4,
@@ -122,9 +122,6 @@ def test_with_alpha_in_state_dict(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == self.output_shape)
-
pipe.transformer.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
@@ -170,8 +167,7 @@ def test_lora_expansion_works_for_absent_keys(self):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ output_no_lora = self.get_base_pipe_output()
# Modify the config to have a layer which won't be present in the second LoRA we will load.
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
@@ -218,9 +214,7 @@ def test_lora_expansion_works_for_extra_keys(self):
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ output_no_lora = self.get_base_pipe_output()
# Modify the config to have a layer which won't be present in the first LoRA we will load.
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
@@ -281,9 +275,8 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = FluxControlPipeline
- scheduler_cls = FlowMatchEulerDiscreteScheduler()
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
- scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = {
"patch_size": 1,
"in_channels": 8,
@@ -330,6 +323,7 @@ def get_dummy_inputs(self, with_generator=True):
noise = floats_tensor((batch_size, num_channels) + sizes)
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+ np.random.seed(0)
pipeline_inputs = {
"prompt": "A painting of a squirrel eating a burger",
"control_image": Image.fromarray(np.random.randint(0, 255, size=(32, 32, 3), dtype="uint8")),
@@ -809,10 +803,9 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
@slow
@nightly
-@require_torch_gpu
+@require_torch_accelerator
@require_peft_backend
-@require_big_gpu_with_torch_cuda
-@pytest.mark.big_gpu_with_torch_cuda
+@require_big_accelerator
class FluxLoRAIntegrationTests(unittest.TestCase):
"""internal note: The integration slices were obtained on audace.
@@ -827,7 +820,7 @@ def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
@@ -836,13 +829,13 @@ def tearDown(self):
del self.pipeline
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_flux_the_last_ben(self):
self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
- # Instead of calling `enable_model_cpu_offload()`, we do a cuda placement here because the CI
+ # Instead of calling `enable_model_cpu_offload()`, we do a accelerator placement here because the CI
# run supports it. We have about 34GB RAM in the CI runner which kills the test when run with
# `enable_model_cpu_offload()`. We repeat this for the other tests, too.
self.pipeline = self.pipeline.to(torch_device)
@@ -907,6 +900,13 @@ def test_flux_kohya_with_text_encoder(self):
assert max_diff < 1e-3
+ def test_flux_kohya_embedders_conversion(self):
+ """Test that embedders load without throwing errors"""
+ self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
+ self.pipeline.unload_lora_weights()
+
+ assert True
+
def test_flux_xlabs(self):
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
self.pipeline.fuse_lora()
@@ -956,10 +956,9 @@ def test_flux_xlabs_load_lora_with_single_blocks(self):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
@require_peft_backend
-@require_big_gpu_with_torch_cuda
-@pytest.mark.big_gpu_with_torch_cuda
+@require_big_accelerator
class FluxControlLoRAIntegrationTests(unittest.TestCase):
num_inference_steps = 10
seed = 0
@@ -969,17 +968,17 @@ def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
self.pipeline = FluxControlPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
- ).to("cuda")
+ ).to(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
@parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"])
def test_lora(self, lora_ckpt_id):
diff --git a/tests/lora/test_lora_layers_flux2.py b/tests/lora/test_lora_layers_flux2.py
new file mode 100644
index 000000000000..4ae189aceb66
--- /dev/null
+++ b/tests/lora/test_lora_layers_flux2.py
@@ -0,0 +1,168 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoProcessor, Mistral3ForConditionalGeneration
+
+from diffusers import AutoencoderKLFlux2, FlowMatchEulerDiscreteScheduler, Flux2Pipeline, Flux2Transformer2DModel
+
+from ..testing_utils import floats_tensor, require_peft_backend, torch_device
+
+
+sys.path.append(".")
+
+from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
+
+
+@require_peft_backend
+class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+ pipeline_class = Flux2Pipeline
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
+ scheduler_kwargs = {}
+
+ transformer_kwargs = {
+ "patch_size": 1,
+ "in_channels": 4,
+ "num_layers": 1,
+ "num_single_layers": 1,
+ "attention_head_dim": 16,
+ "num_attention_heads": 2,
+ "joint_attention_dim": 16,
+ "timestep_guidance_channels": 256,
+ "axes_dims_rope": [4, 4, 4, 4],
+ }
+ transformer_cls = Flux2Transformer2DModel
+ vae_kwargs = {
+ "sample_size": 32,
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": ("DownEncoderBlock2D",),
+ "up_block_types": ("UpDecoderBlock2D",),
+ "block_out_channels": (4,),
+ "layers_per_block": 1,
+ "latent_channels": 1,
+ "norm_num_groups": 1,
+ "use_quant_conv": False,
+ "use_post_quant_conv": False,
+ }
+ vae_cls = AutoencoderKLFlux2
+
+ tokenizer_cls, tokenizer_id = AutoProcessor, "hf-internal-testing/tiny-mistral3-diffusers"
+ text_encoder_cls, text_encoder_id = Mistral3ForConditionalGeneration, "hf-internal-testing/tiny-mistral3-diffusers"
+ denoiser_target_modules = ["to_qkv_mlp_proj", "to_k"]
+
+ @property
+ def output_shape(self):
+ return (1, 8, 8, 3)
+
+ def get_dummy_inputs(self, with_generator=True):
+ batch_size = 1
+ sequence_length = 10
+ num_channels = 4
+ sizes = (32, 32)
+
+ generator = torch.manual_seed(0)
+ noise = floats_tensor((batch_size, num_channels) + sizes)
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+ pipeline_inputs = {
+ "prompt": "a dog is dancing",
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 8,
+ "width": 8,
+ "max_sequence_length": 8,
+ "output_type": "np",
+ "text_encoder_out_layers": (1,),
+ }
+ if with_generator:
+ pipeline_inputs.update({"generator": generator})
+
+ return noise, input_ids, pipeline_inputs
+
+ # Overriding because (1) text encoder LoRAs are not supported in Flux 2 and (2) because the Flux 2 single block
+ # QKV projections are always fused, it has no `to_q` param as expected by the original test.
+ def test_lora_fuse_nan(self):
+ components, _, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+ # corrupt one LoRA weight with `inf` values
+ with torch.no_grad():
+ possible_tower_names = ["transformer_blocks", "single_transformer_blocks"]
+ filtered_tower_names = [
+ tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
+ ]
+ if len(filtered_tower_names) == 0:
+ reason = f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}."
+ raise ValueError(reason)
+ for tower_name in filtered_tower_names:
+ transformer_tower = getattr(pipe.transformer, tower_name)
+ is_single = "single" in tower_name
+ if is_single:
+ transformer_tower[0].attn.to_qkv_mlp_proj.lora_A["adapter-1"].weight += float("inf")
+ else:
+ transformer_tower[0].attn.to_k.lora_A["adapter-1"].weight += float("inf")
+
+ # with `safe_fusing=True` we should see an Error
+ with self.assertRaises(ValueError):
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
+
+ # without we should not see an error, but every image will be black
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
+ out = pipe(**inputs)[0]
+
+ self.assertTrue(np.isnan(out).all())
+
+ @unittest.skip("Not supported in Flux2.")
+ def test_simple_inference_with_text_denoiser_block_scale(self):
+ pass
+
+ @unittest.skip("Not supported in Flux2.")
+ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+ pass
+
+ @unittest.skip("Not supported in Flux2.")
+ def test_modify_padding_mode(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Flux2.")
+ def test_simple_inference_with_partial_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Flux2.")
+ def test_simple_inference_with_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Flux2.")
+ def test_simple_inference_with_text_lora_and_scale(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Flux2.")
+ def test_simple_inference_with_text_lora_fused(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Flux2.")
+ def test_simple_inference_with_text_lora_save_load(self):
+ pass
diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py
index d2015d8b0711..cfd5d3146a91 100644
--- a/tests/lora/test_lora_layers_hunyuanvideo.py
+++ b/tests/lora/test_lora_layers_hunyuanvideo.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@
import unittest
import numpy as np
-import pytest
import torch
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
@@ -27,20 +26,24 @@
HunyuanVideoPipeline,
HunyuanVideoTransformer3DModel,
)
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
+ Expectations,
+ backend_empty_cache,
floats_tensor,
nightly,
numpy_cosine_similarity_distance,
- require_big_gpu_with_torch_cuda,
+ require_big_accelerator,
require_peft_backend,
- require_torch_gpu,
+ require_torch_accelerator,
skip_mps,
+ torch_device,
)
sys.path.append(".")
-from utils import PeftLoraLoaderMixinTests # noqa: E402
+from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@@ -48,7 +51,6 @@
class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = HunyuanVideoPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
- scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
@@ -192,10 +194,9 @@ def test_simple_inference_with_text_lora_save_load(self):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
@require_peft_backend
-@require_big_gpu_with_torch_cuda
-@pytest.mark.big_gpu_with_torch_cuda
+@require_big_accelerator
class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
"""internal note: The integration slices were obtained on DGX.
@@ -210,7 +211,7 @@ def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
@@ -218,13 +219,13 @@ def setUp(self):
)
self.pipeline = HunyuanVideoPipeline.from_pretrained(
model_id, transformer=transformer, torch_dtype=torch.float16
- ).to("cuda")
+ ).to(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_original_format_cseti(self):
self.pipeline.load_lora_weights(
@@ -249,8 +250,14 @@ def test_original_format_cseti(self):
out_slice = np.concatenate((out[:8], out[-8:]))
# fmt: off
- expected_slice = np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815])
+ expected_slices = Expectations(
+ {
+ ("cuda", 7): np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815]),
+ ("xpu", 3): np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815]),
+ }
+ )
# fmt: on
+ expected_slice = expected_slices.get_expectation()
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
diff --git a/tests/lora/test_lora_layers_ltx_video.py b/tests/lora/test_lora_layers_ltx_video.py
index 0eccaa73ad42..6ab51a5e513f 100644
--- a/tests/lora/test_lora_layers_ltx_video.py
+++ b/tests/lora/test_lora_layers_ltx_video.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,19 +24,19 @@
LTXPipeline,
LTXVideoTransformer3DModel,
)
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
+
+from ..testing_utils import floats_tensor, require_peft_backend
sys.path.append(".")
-from utils import PeftLoraLoaderMixinTests # noqa: E402
+from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = LTXPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
- scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
diff --git a/tests/lora/test_lora_layers_lumina2.py b/tests/lora/test_lora_layers_lumina2.py
index 07b1cda2f79f..0417b05b33a1 100644
--- a/tests/lora/test_lora_layers_lumina2.py
+++ b/tests/lora/test_lora_layers_lumina2.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -23,22 +23,22 @@
from diffusers import (
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
- Lumina2Text2ImgPipeline,
+ Lumina2Pipeline,
Lumina2Transformer2DModel,
)
-from diffusers.utils.testing_utils import floats_tensor, is_torch_version, require_peft_backend, skip_mps, torch_device
+
+from ..testing_utils import floats_tensor, is_torch_version, require_peft_backend, skip_mps, torch_device
sys.path.append(".")
-from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
+from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
@require_peft_backend
class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
- pipeline_class = Lumina2Text2ImgPipeline
+ pipeline_class = Lumina2Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
- scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
@@ -140,33 +140,30 @@ def test_simple_inference_with_text_lora_save_load(self):
strict=False,
)
def test_lora_fuse_nan(self):
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
-
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config, "adapter-1")
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
- # corrupt one LoRA weight with `inf` values
- with torch.no_grad():
- pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
-
- # with `safe_fusing=True` we should see an Error
- with self.assertRaises(ValueError):
- pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
-
- # without we should not see an error, but every image will be black
- pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
- out = pipe(**inputs)[0]
-
- self.assertTrue(np.isnan(out).all())
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+ # corrupt one LoRA weight with `inf` values
+ with torch.no_grad():
+ pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
+
+ # with `safe_fusing=True` we should see an Error
+ with self.assertRaises(ValueError):
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
+
+ # without we should not see an error, but every image will be black
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
+ out = pipe(**inputs)[0]
+
+ self.assertTrue(np.isnan(out).all())
diff --git a/tests/lora/test_lora_layers_mochi.py b/tests/lora/test_lora_layers_mochi.py
index 671f1277f99f..7be81273db77 100644
--- a/tests/lora/test_lora_layers_mochi.py
+++ b/tests/lora/test_lora_layers_mochi.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,7 +19,8 @@
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
floats_tensor,
require_peft_backend,
skip_mps,
@@ -28,7 +29,7 @@
sys.path.append(".")
-from utils import PeftLoraLoaderMixinTests # noqa: E402
+from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@@ -36,7 +37,6 @@
class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = MochiPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
- scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
diff --git a/tests/lora/test_lora_layers_qwenimage.py b/tests/lora/test_lora_layers_qwenimage.py
new file mode 100644
index 000000000000..51de2f8e20e1
--- /dev/null
+++ b/tests/lora/test_lora_layers_qwenimage.py
@@ -0,0 +1,129 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+import unittest
+
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
+
+from diffusers import (
+ AutoencoderKLQwenImage,
+ FlowMatchEulerDiscreteScheduler,
+ QwenImagePipeline,
+ QwenImageTransformer2DModel,
+)
+
+from ..testing_utils import floats_tensor, require_peft_backend
+
+
+sys.path.append(".")
+
+from .utils import PeftLoraLoaderMixinTests # noqa: E402
+
+
+@require_peft_backend
+class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+ pipeline_class = QwenImagePipeline
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
+ scheduler_kwargs = {}
+
+ transformer_kwargs = {
+ "patch_size": 2,
+ "in_channels": 16,
+ "out_channels": 4,
+ "num_layers": 2,
+ "attention_head_dim": 16,
+ "num_attention_heads": 3,
+ "joint_attention_dim": 16,
+ "guidance_embeds": False,
+ "axes_dims_rope": (8, 4, 4),
+ }
+ transformer_cls = QwenImageTransformer2DModel
+ z_dim = 4
+ vae_kwargs = {
+ "base_dim": z_dim * 6,
+ "z_dim": z_dim,
+ "dim_mult": [1, 2, 4],
+ "num_res_blocks": 1,
+ "temperal_downsample": [False, True],
+ "latents_mean": [0.0] * 4,
+ "latents_std": [1.0] * 4,
+ }
+ vae_cls = AutoencoderKLQwenImage
+ tokenizer_cls, tokenizer_id = Qwen2Tokenizer, "hf-internal-testing/tiny-random-Qwen25VLForCondGen"
+ text_encoder_cls, text_encoder_id = (
+ Qwen2_5_VLForConditionalGeneration,
+ "hf-internal-testing/tiny-random-Qwen25VLForCondGen",
+ )
+ denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
+
+ @property
+ def output_shape(self):
+ return (1, 8, 8, 3)
+
+ def get_dummy_inputs(self, with_generator=True):
+ batch_size = 1
+ sequence_length = 10
+ num_channels = 4
+ sizes = (32, 32)
+
+ generator = torch.manual_seed(0)
+ noise = floats_tensor((batch_size, num_channels) + sizes)
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+ pipeline_inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "num_inference_steps": 4,
+ "guidance_scale": 0.0,
+ "height": 8,
+ "width": 8,
+ "output_type": "np",
+ }
+ if with_generator:
+ pipeline_inputs.update({"generator": generator})
+
+ return noise, input_ids, pipeline_inputs
+
+ @unittest.skip("Not supported in Qwen Image.")
+ def test_simple_inference_with_text_denoiser_block_scale(self):
+ pass
+
+ @unittest.skip("Not supported in Qwen Image.")
+ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+ pass
+
+ @unittest.skip("Not supported in Qwen Image.")
+ def test_modify_padding_mode(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
+ def test_simple_inference_with_partial_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
+ def test_simple_inference_with_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
+ def test_simple_inference_with_text_lora_and_scale(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
+ def test_simple_inference_with_text_lora_fused(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
+ def test_simple_inference_with_text_lora_save_load(self):
+ pass
diff --git a/tests/lora/test_lora_layers_sana.py b/tests/lora/test_lora_layers_sana.py
index 78f71527cb7e..a860b7b44f2c 100644
--- a/tests/lora/test_lora_layers_sana.py
+++ b/tests/lora/test_lora_layers_sana.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,20 +19,20 @@
from transformers import Gemma2Model, GemmaTokenizer
from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
+
+from ..testing_utils import IS_GITHUB_ACTIONS, floats_tensor, require_peft_backend
sys.path.append(".")
-from utils import PeftLoraLoaderMixinTests # noqa: E402
+from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = SanaPipeline
- scheduler_cls = FlowMatchEulerDiscreteScheduler(shift=7.0)
- scheduler_kwargs = {}
- scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
+ scheduler_kwargs = {"shift": 7.0}
transformer_kwargs = {
"patch_size": 1,
"in_channels": 4,
@@ -136,3 +136,7 @@ def test_simple_inference_with_text_lora_fused(self):
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_save_load(self):
pass
+
+ @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
+ def test_layerwise_casting_inference_denoiser(self):
+ return super().test_layerwise_casting_inference_denoiser()
diff --git a/tests/lora/test_lora_layers_sd.py b/tests/lora/test_lora_layers_sd.py
index 3eefa97663e6..933bf2336a59 100644
--- a/tests/lora/test_lora_layers_sd.py
+++ b/tests/lora/test_lora_layers_sd.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -32,7 +32,9 @@
StableDiffusionPipeline,
)
from diffusers.utils.import_utils import is_accelerate_available
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
+ Expectations,
backend_empty_cache,
load_image,
nightly,
@@ -46,7 +48,7 @@
sys.path.append(".")
-from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
+from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
if is_accelerate_available():
@@ -92,12 +94,12 @@ def output_shape(self):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
# Keeping this test here makes sense because it doesn't look any integration
# (value assertions on logits).
@@ -119,7 +121,7 @@ def test_integration_move_lora_cpu(self):
self.assertTrue(
check_if_lora_correctly_set(pipe.unet),
- "Lora not correctly set in text encoder",
+ "Lora not correctly set in unet",
)
# We will offload the first adapter in CPU and check if the offloading
@@ -186,7 +188,7 @@ def test_integration_move_lora_dora_cpu(self):
self.assertTrue(
check_if_lora_correctly_set(pipe.unet),
- "Lora not correctly set in text encoder",
+ "Lora not correctly set in unet",
)
for name, param in pipe.unet.named_parameters():
@@ -207,6 +209,53 @@ def test_integration_move_lora_dora_cpu(self):
if "lora_" in name:
self.assertNotEqual(param.device, torch.device("cpu"))
+ @slow
+ @require_torch_accelerator
+ def test_integration_set_lora_device_different_target_layers(self):
+ # fixes a bug that occurred when calling set_lora_device with multiple adapters loaded that target different
+ # layers, see #11833
+ from peft import LoraConfig
+
+ path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+ pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
+ # configs partly target the same, partly different layers
+ config0 = LoraConfig(target_modules=["to_k", "to_v"])
+ config1 = LoraConfig(target_modules=["to_k", "to_q"])
+ pipe.unet.add_adapter(config0, adapter_name="adapter-0")
+ pipe.unet.add_adapter(config1, adapter_name="adapter-1")
+ pipe = pipe.to(torch_device)
+
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.unet),
+ "Lora not correctly set in unet",
+ )
+
+ # sanity check that the adapters don't target the same layers, otherwise the test passes even without the fix
+ modules_adapter_0 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-0")}
+ modules_adapter_1 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-1")}
+ self.assertNotEqual(modules_adapter_0, modules_adapter_1)
+ self.assertTrue(modules_adapter_0 - modules_adapter_1)
+ self.assertTrue(modules_adapter_1 - modules_adapter_0)
+
+ # setting both separately works
+ pipe.set_lora_device(["adapter-0"], "cpu")
+ pipe.set_lora_device(["adapter-1"], "cpu")
+
+ for name, module in pipe.unet.named_modules():
+ if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+ self.assertTrue(module.weight.device == torch.device("cpu"))
+ elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+ self.assertTrue(module.weight.device == torch.device("cpu"))
+
+ # setting both at once also works
+ pipe.set_lora_device(["adapter-0", "adapter-1"], torch_device)
+
+ for name, module in pipe.unet.named_modules():
+ if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+ self.assertTrue(module.weight.device != torch.device("cpu"))
+ elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+ self.assertTrue(module.weight.device != torch.device("cpu"))
+
@slow
@nightly
@@ -455,11 +504,54 @@ def test_vanilla_funetuning(self):
images = pipe("A pokemon with blue eyes.", output_type="np", generator=generator, num_inference_steps=2).images
- images = images[0, -3:, -3:, -1].flatten()
-
- expected = np.array([0.7406, 0.699, 0.5963, 0.7493, 0.7045, 0.6096, 0.6886, 0.6388, 0.583])
+ image_slice = images[0, -3:, -3:, -1].flatten()
+
+ expected_slices = Expectations(
+ {
+ ("xpu", 3): np.array(
+ [
+ 0.6544,
+ 0.6127,
+ 0.5397,
+ 0.6845,
+ 0.6047,
+ 0.5469,
+ 0.6349,
+ 0.5906,
+ 0.5382,
+ ]
+ ),
+ ("cuda", 7): np.array(
+ [
+ 0.7406,
+ 0.699,
+ 0.5963,
+ 0.7493,
+ 0.7045,
+ 0.6096,
+ 0.6886,
+ 0.6388,
+ 0.583,
+ ]
+ ),
+ ("cuda", 8): np.array(
+ [
+ 0.6542,
+ 0.61253,
+ 0.5396,
+ 0.6843,
+ 0.6044,
+ 0.5468,
+ 0.6349,
+ 0.5905,
+ 0.5381,
+ ]
+ ),
+ }
+ )
+ expected_slice = expected_slices.get_expectation()
- max_diff = numpy_cosine_similarity_distance(expected, images)
+ max_diff = numpy_cosine_similarity_distance(expected_slice, image_slice)
assert max_diff < 1e-4
pipe.unload_lora_weights()
diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py
index 90aaa3bcfe78..228460eaad90 100644
--- a/tests/lora/test_lora_layers_sd3.py
+++ b/tests/lora/test_lora_layers_sd3.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@
import unittest
import numpy as np
-import pytest
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -29,12 +28,13 @@
)
from diffusers.utils import load_image
from diffusers.utils.import_utils import is_accelerate_available
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
is_flaky,
nightly,
numpy_cosine_similarity_distance,
- require_big_gpu_with_torch_cuda,
+ require_big_accelerator,
require_peft_backend,
require_torch_accelerator,
torch_device,
@@ -43,7 +43,7 @@
sys.path.append(".")
-from utils import PeftLoraLoaderMixinTests # noqa: E402
+from .utils import PeftLoraLoaderMixinTests # noqa: E402
if is_accelerate_available():
@@ -55,7 +55,6 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusion3Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
- scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = {
"sample_size": 32,
"patch_size": 1,
@@ -138,8 +137,7 @@ def test_multiple_wrong_adapter_name_raises_error(self):
@nightly
@require_torch_accelerator
@require_peft_backend
-@require_big_gpu_with_torch_cuda
-@pytest.mark.big_gpu_with_torch_cuda
+@require_big_accelerator
class SD3LoraIntegrationTests(unittest.TestCase):
pipeline_class = StableDiffusion3Img2ImgPipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py
index 76d6dc48602b..ac1d65abdaa7 100644
--- a/tests/lora/test_lora_layers_sdxl.py
+++ b/tests/lora/test_lora_layers_sdxl.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -35,14 +35,16 @@
)
from diffusers.utils import logging
from diffusers.utils.import_utils import is_accelerate_available
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
CaptureLogger,
+ backend_empty_cache,
is_flaky,
load_image,
nightly,
numpy_cosine_similarity_distance,
require_peft_backend,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
@@ -50,7 +52,7 @@
sys.path.append(".")
-from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set, state_dicts_almost_equal # noqa: E402
+from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set, state_dicts_almost_equal # noqa: E402
if is_accelerate_available():
@@ -105,32 +107,66 @@ def output_shape(self):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
@is_flaky
def test_multiple_wrong_adapter_name_raises_error(self):
super().test_multiple_wrong_adapter_name_raises_error()
+ def test_simple_inference_with_text_denoiser_lora_unfused(self):
+ if torch.cuda.is_available():
+ expected_atol = 9e-2
+ expected_rtol = 9e-2
+ else:
+ expected_atol = 1e-3
+ expected_rtol = 1e-3
+
+ super().test_simple_inference_with_text_denoiser_lora_unfused(
+ expected_atol=expected_atol, expected_rtol=expected_rtol
+ )
+
+ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
+ if torch.cuda.is_available():
+ expected_atol = 9e-2
+ expected_rtol = 9e-2
+ else:
+ expected_atol = 1e-3
+ expected_rtol = 1e-3
+
+ super().test_simple_inference_with_text_lora_denoiser_fused_multi(
+ expected_atol=expected_atol, expected_rtol=expected_rtol
+ )
+
+ def test_lora_scale_kwargs_match_fusion(self):
+ if torch.cuda.is_available():
+ expected_atol = 9e-2
+ expected_rtol = 9e-2
+ else:
+ expected_atol = 1e-3
+ expected_rtol = 1e-3
+
+ super().test_lora_scale_kwargs_match_fusion(expected_atol=expected_atol, expected_rtol=expected_rtol)
+
@slow
@nightly
-@require_torch_gpu
+@require_torch_accelerator
@require_peft_backend
class LoraSDXLIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_sdxl_1_0_lora(self):
generator = torch.Generator("cpu").manual_seed(0)
diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py
index c2498fa68c3d..5734509b410f 100644
--- a/tests/lora/test_lora_layers_wan.py
+++ b/tests/lora/test_lora_layers_wan.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,7 +24,8 @@
WanPipeline,
WanTransformer3DModel,
)
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
floats_tensor,
require_peft_backend,
skip_mps,
@@ -33,7 +34,7 @@
sys.path.append(".")
-from utils import PeftLoraLoaderMixinTests # noqa: E402
+from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@@ -41,7 +42,6 @@
class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = WanPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
- scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
diff --git a/tests/lora/test_lora_layers_wanvace.py b/tests/lora/test_lora_layers_wanvace.py
new file mode 100644
index 000000000000..ab1f57bfc9da
--- /dev/null
+++ b/tests/lora/test_lora_layers_wanvace.py
@@ -0,0 +1,215 @@
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import tempfile
+import unittest
+
+import numpy as np
+import safetensors.torch
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
+from diffusers.utils.import_utils import is_peft_available
+
+from ..testing_utils import (
+ floats_tensor,
+ is_flaky,
+ require_peft_backend,
+ require_peft_version_greater,
+ skip_mps,
+ torch_device,
+)
+
+
+if is_peft_available():
+ from peft.utils import get_peft_model_state_dict
+
+sys.path.append(".")
+
+from .utils import PeftLoraLoaderMixinTests # noqa: E402
+
+
+@require_peft_backend
+@skip_mps
+@is_flaky(max_attempts=10, description="very flaky class")
+class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+ pipeline_class = WanVACEPipeline
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
+ scheduler_kwargs = {}
+
+ transformer_kwargs = {
+ "patch_size": (1, 2, 2),
+ "num_attention_heads": 2,
+ "attention_head_dim": 8,
+ "in_channels": 4,
+ "out_channels": 4,
+ "text_dim": 32,
+ "freq_dim": 16,
+ "ffn_dim": 16,
+ "num_layers": 2,
+ "cross_attn_norm": True,
+ "qk_norm": "rms_norm_across_heads",
+ "rope_max_seq_len": 16,
+ "vace_layers": [0],
+ "vace_in_channels": 72,
+ }
+ transformer_cls = WanVACETransformer3DModel
+ vae_kwargs = {
+ "base_dim": 3,
+ "z_dim": 4,
+ "dim_mult": [1, 1, 1, 1],
+ "latents_mean": torch.randn(4).numpy().tolist(),
+ "latents_std": torch.randn(4).numpy().tolist(),
+ "num_res_blocks": 1,
+ "temperal_downsample": [False, True, True],
+ }
+ vae_cls = AutoencoderKLWan
+ has_two_text_encoders = True
+ tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
+ text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
+
+ text_encoder_target_modules = ["q", "k", "v", "o"]
+
+ @property
+ def output_shape(self):
+ return (1, 9, 16, 16, 3)
+
+ def get_dummy_inputs(self, with_generator=True):
+ batch_size = 1
+ sequence_length = 16
+ num_channels = 4
+ num_frames = 9
+ num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1
+ sizes = (4, 4)
+ height, width = 16, 16
+
+ generator = torch.manual_seed(0)
+ noise = floats_tensor((batch_size, num_latent_frames, num_channels) + sizes)
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+ video = [Image.new("RGB", (height, width))] * num_frames
+ mask = [Image.new("L", (height, width), 0)] * num_frames
+
+ pipeline_inputs = {
+ "video": video,
+ "mask": mask,
+ "prompt": "",
+ "num_frames": num_frames,
+ "num_inference_steps": 1,
+ "guidance_scale": 6.0,
+ "height": height,
+ "width": height,
+ "max_sequence_length": sequence_length,
+ "output_type": "np",
+ }
+ if with_generator:
+ pipeline_inputs.update({"generator": generator})
+
+ return noise, input_ids, pipeline_inputs
+
+ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
+ super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
+
+ def test_simple_inference_with_text_denoiser_lora_unfused(self):
+ super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
+
+ @unittest.skip("Not supported in Wan VACE.")
+ def test_simple_inference_with_text_denoiser_block_scale(self):
+ pass
+
+ @unittest.skip("Not supported in Wan VACE.")
+ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+ pass
+
+ @unittest.skip("Not supported in Wan VACE.")
+ def test_modify_padding_mode(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+ def test_simple_inference_with_partial_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+ def test_simple_inference_with_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+ def test_simple_inference_with_text_lora_and_scale(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+ def test_simple_inference_with_text_lora_fused(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+ def test_simple_inference_with_text_lora_save_load(self):
+ pass
+
+ def test_layerwise_casting_inference_denoiser(self):
+ super().test_layerwise_casting_inference_denoiser()
+
+ @require_peft_version_greater("0.13.2")
+ def test_lora_exclude_modules_wanvace(self):
+ exclude_module_name = "vace_blocks.0.proj_out"
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components).to(torch_device)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_lora = self.get_base_pipe_output()
+ self.assertTrue(output_no_lora.shape == self.output_shape)
+
+ # only supported for `denoiser` now
+ denoiser_lora_config.target_modules = ["proj_out"]
+ denoiser_lora_config.exclude_modules = [exclude_module_name]
+ pipe, _ = self.add_adapters_to_pipeline(
+ pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+ )
+ # The state dict shouldn't contain the modules to be excluded from LoRA.
+ state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default")
+ self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
+ self.assertTrue(any("proj_out" in k for k in state_dict_from_model))
+ output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts)
+ pipe.unload_lora_weights()
+
+ # Check in the loaded state dict.
+ loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ self.assertTrue(not any(exclude_module_name in k for k in loaded_state_dict))
+ self.assertTrue(any("proj_out" in k for k in loaded_state_dict))
+
+ # Check in the state dict obtained after loading LoRA.
+ pipe.load_lora_weights(tmpdir)
+ state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0")
+ self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
+ self.assertTrue(any("proj_out" in k for k in state_dict_from_model))
+
+ output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
+ "LoRA should change outputs.",
+ )
+ self.assertTrue(
+ np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
+ "Lora outputs should match.",
+ )
+
+ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
+ super().test_simple_inference_with_text_denoiser_lora_and_scale()
diff --git a/tests/lora/test_lora_layers_z_image.py b/tests/lora/test_lora_layers_z_image.py
new file mode 100644
index 000000000000..35d1389d9612
--- /dev/null
+++ b/tests/lora/test_lora_layers_z_image.py
@@ -0,0 +1,285 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+import unittest
+
+import numpy as np
+import torch
+from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model
+
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImagePipeline, ZImageTransformer2DModel
+
+from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend, skip_mps, torch_device
+
+
+if is_peft_available():
+ from peft import LoraConfig
+
+
+sys.path.append(".")
+
+from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
+
+
+@require_peft_backend
+class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+ pipeline_class = ZImagePipeline
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
+ scheduler_kwargs = {}
+
+ transformer_kwargs = {
+ "all_patch_size": (2,),
+ "all_f_patch_size": (1,),
+ "in_channels": 16,
+ "dim": 32,
+ "n_layers": 2,
+ "n_refiner_layers": 1,
+ "n_heads": 2,
+ "n_kv_heads": 2,
+ "norm_eps": 1e-5,
+ "qk_norm": True,
+ "cap_feat_dim": 16,
+ "rope_theta": 256.0,
+ "t_scale": 1000.0,
+ "axes_dims": [8, 4, 4],
+ "axes_lens": [256, 32, 32],
+ }
+ transformer_cls = ZImageTransformer2DModel
+ vae_kwargs = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ "block_out_channels": [32, 64],
+ "layers_per_block": 1,
+ "latent_channels": 16,
+ "norm_num_groups": 32,
+ "sample_size": 32,
+ "scaling_factor": 0.3611,
+ "shift_factor": 0.1159,
+ }
+ vae_cls = AutoencoderKL
+ tokenizer_cls, tokenizer_id = Qwen2Tokenizer, "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
+ text_encoder_cls, text_encoder_id = Qwen3Model, None # Will be created inline
+ denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
+
+ @property
+ def output_shape(self):
+ return (1, 32, 32, 3)
+
+ def get_dummy_inputs(self, with_generator=True):
+ batch_size = 1
+ sequence_length = 10
+ num_channels = 4
+ sizes = (32, 32)
+
+ generator = torch.manual_seed(0)
+ noise = floats_tensor((batch_size, num_channels) + sizes)
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+
+ pipeline_inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "num_inference_steps": 4,
+ "guidance_scale": 0.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 16,
+ "output_type": "np",
+ }
+ if with_generator:
+ pipeline_inputs.update({"generator": generator})
+
+ return noise, input_ids, pipeline_inputs
+
+ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
+ # Override to create Qwen3Model inline since it doesn't have a pretrained tiny model
+ torch.manual_seed(0)
+ config = Qwen3Config(
+ hidden_size=16,
+ intermediate_size=16,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ num_key_value_heads=2,
+ vocab_size=151936,
+ max_position_embeddings=512,
+ )
+ text_encoder = Qwen3Model(config)
+ tokenizer = Qwen2Tokenizer.from_pretrained(self.tokenizer_id)
+
+ transformer = self.transformer_cls(**self.transformer_kwargs)
+ # `x_pad_token` and `cap_pad_token` are initialized with `torch.empty`.
+ # This can cause NaN data values in our testing environment. Fixating them
+ # helps prevent that issue.
+ with torch.no_grad():
+ transformer.x_pad_token.copy_(torch.ones_like(transformer.x_pad_token.data))
+ transformer.cap_pad_token.copy_(torch.ones_like(transformer.cap_pad_token.data))
+ vae = self.vae_cls(**self.vae_kwargs)
+
+ if scheduler_cls is None:
+ scheduler_cls = self.scheduler_cls
+ scheduler = scheduler_cls(**self.scheduler_kwargs)
+
+ rank = 4
+ lora_alpha = rank if lora_alpha is None else lora_alpha
+
+ text_lora_config = LoraConfig(
+ r=rank,
+ lora_alpha=lora_alpha,
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+ init_lora_weights=False,
+ use_dora=use_dora,
+ )
+
+ denoiser_lora_config = LoraConfig(
+ r=rank,
+ lora_alpha=lora_alpha,
+ target_modules=self.denoiser_target_modules,
+ init_lora_weights=False,
+ use_dora=use_dora,
+ )
+
+ pipeline_components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+
+ return pipeline_components, text_lora_config, denoiser_lora_config
+
+ def test_correct_lora_configs_with_different_ranks(self):
+ components, _, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+
+ lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ pipe.transformer.delete_adapters("adapter-1")
+
+ denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+ for name, _ in denoiser.named_modules():
+ if "to_k" in name and "attention" in name and "lora" not in name:
+ module_name_to_rank_update = name.replace(".base_layer.", ".")
+ break
+
+ # change the rank_pattern
+ updated_rank = denoiser_lora_config.r * 2
+ denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank}
+
+ pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+ updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern
+
+ self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank})
+
+ lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
+ self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
+
+ pipe.transformer.delete_adapters("adapter-1")
+
+ # similarly change the alpha_pattern
+ updated_alpha = denoiser_lora_config.lora_alpha * 2
+ denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha}
+
+ pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+ self.assertTrue(
+ pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
+ )
+
+ lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
+ self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
+
+ @skip_mps
+ def test_lora_fuse_nan(self):
+ components, _, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+ # corrupt one LoRA weight with `inf` values
+ with torch.no_grad():
+ possible_tower_names = ["noise_refiner"]
+ filtered_tower_names = [
+ tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
+ ]
+ for tower_name in filtered_tower_names:
+ transformer_tower = getattr(pipe.transformer, tower_name)
+ transformer_tower[0].attention.to_q.lora_A["adapter-1"].weight += float("inf")
+
+ # with `safe_fusing=True` we should see an Error
+ with self.assertRaises(ValueError):
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
+
+ # without we should not see an error, but every image will be black
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
+ out = pipe(**inputs)[0]
+
+ self.assertTrue(np.isnan(out).all())
+
+ def test_lora_scale_kwargs_match_fusion(self):
+ super().test_lora_scale_kwargs_match_fusion(5e-2, 5e-2)
+
+ @unittest.skip("Needs to be debugged.")
+ def test_set_adapters_match_attention_kwargs(self):
+ super().test_set_adapters_match_attention_kwargs()
+
+ @unittest.skip("Needs to be debugged.")
+ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
+ super().test_simple_inference_with_text_denoiser_lora_and_scale()
+
+ @unittest.skip("Not supported in ZImage.")
+ def test_simple_inference_with_text_denoiser_block_scale(self):
+ pass
+
+ @unittest.skip("Not supported in ZImage.")
+ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+ pass
+
+ @unittest.skip("Not supported in ZImage.")
+ def test_modify_padding_mode(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+ def test_simple_inference_with_partial_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+ def test_simple_inference_with_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+ def test_simple_inference_with_text_lora_and_scale(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+ def test_simple_inference_with_text_lora_fused(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+ def test_simple_inference_with_text_lora_save_load(self):
+ pass
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 8cdb43c9d085..5fae6cac0a7f 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,21 +22,24 @@
import numpy as np
import pytest
import torch
+from parameterized import parameterized
from diffusers import (
AutoencoderKL,
- DDIMScheduler,
- LCMScheduler,
UNet2DConditionModel,
)
+from diffusers.hooks.group_offloading import _GROUP_OFFLOADING, apply_group_offloading
from diffusers.utils import logging
from diffusers.utils.import_utils import is_peft_available
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
CaptureLogger,
+ check_if_dicts_are_equal,
floats_tensor,
is_torch_version,
require_peft_backend,
require_peft_version_greater,
+ require_torch_accelerator,
require_transformers_version_greater,
skip_mps,
torch_device,
@@ -71,6 +74,13 @@ def check_if_lora_correctly_set(model) -> bool:
return False
+def check_module_lora_metadata(parsed_metadata: dict, lora_metadatas: dict, module_key: str):
+ extracted = {
+ k.removeprefix(f"{module_key}."): v for k, v in parsed_metadata.items() if k.startswith(f"{module_key}.")
+ }
+ check_if_dicts_are_equal(extracted, lora_metadatas[f"{module_key}_lora_adapter_metadata"])
+
+
def initialize_dummy_state_dict(state_dict):
if not all(v.device.type == "meta" for _, v in state_dict.items()):
raise ValueError("`state_dict` has non-meta values.")
@@ -80,13 +90,24 @@ def initialize_dummy_state_dict(state_dict):
POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]
+def determine_attention_kwargs_name(pipeline_class):
+ call_signature_keys = inspect.signature(pipeline_class.__call__).parameters.keys()
+
+ # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
+ for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
+ if possible_attention_kwargs in call_signature_keys:
+ attention_kwargs_name = possible_attention_kwargs
+ break
+ assert attention_kwargs_name is not None
+ return attention_kwargs_name
+
+
@require_peft_backend
class PeftLoraLoaderMixinTests:
pipeline_class = None
scheduler_cls = None
scheduler_kwargs = None
- scheduler_classes = [DDIMScheduler, LCMScheduler]
has_two_text_encoders = False
has_three_text_encoders = False
@@ -104,15 +125,24 @@ class PeftLoraLoaderMixinTests:
vae_kwargs = None
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
+ denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
+
+ cached_non_lora_output = None
+
+ def get_base_pipe_output(self):
+ if self.cached_non_lora_output is None:
+ self.cached_non_lora_output = self._compute_baseline_output()
+ return self.cached_non_lora_output
- def get_dummy_components(self, scheduler_cls=None, use_dora=False):
+ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
if self.unet_kwargs and self.transformer_kwargs:
raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
if self.has_two_text_encoders and self.has_three_text_encoders:
raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")
- scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
+ scheduler_cls = scheduler_cls if scheduler_cls is not None else self.scheduler_cls
rank = 4
+ lora_alpha = rank if lora_alpha is None else lora_alpha
torch.manual_seed(0)
if self.unet_kwargs is not None:
@@ -148,7 +178,7 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False):
text_lora_config = LoraConfig(
r=rank,
- lora_alpha=rank,
+ lora_alpha=lora_alpha,
target_modules=self.text_encoder_target_modules,
init_lora_weights=False,
use_dora=use_dora,
@@ -156,8 +186,8 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False):
denoiser_lora_config = LoraConfig(
r=rank,
- lora_alpha=rank,
- target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+ lora_alpha=lora_alpha,
+ target_modules=self.denoiser_target_modules,
init_lora_weights=False,
use_dora=use_dora,
)
@@ -216,15 +246,16 @@ def get_dummy_inputs(self, with_generator=True):
return noise, input_ids, pipeline_inputs
- # Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
- def get_dummy_tokens(self):
- max_seq_length = 77
-
- inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))
+ def _compute_baseline_output(self):
+ components, _, _ = self.get_dummy_components(self.scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
- prepared_inputs = {}
- prepared_inputs["input_ids"] = inputs
- return prepared_inputs
+ # Always ensure the inputs are without the `generator`. Make sure to pass the `generator`
+ # explicitly.
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ return pipe(**inputs, generator=torch.manual_seed(0))[0]
def _get_lora_state_dicts(self, modules_to_save):
state_dicts = {}
@@ -233,6 +264,13 @@ def _get_lora_state_dicts(self, modules_to_save):
state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module)
return state_dicts
+ def _get_lora_adapter_metadata(self, modules_to_save):
+ metadatas = {}
+ for module_name, module in modules_to_save.items():
+ if module is not None:
+ metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
+ return metadatas
+
def _get_modules_to_save(self, pipe, has_denoiser=False):
modules_to_save = {}
lora_loadable_modules = self.pipeline_class._lora_loadable_modules
@@ -260,375 +298,294 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):
return modules_to_save
+ def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
+ if text_lora_config is not None:
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ )
+
+ if denoiser_lora_config is not None:
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name)
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ else:
+ denoiser = None
+
+ if text_lora_config is not None and self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name)
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
+ return pipe, denoiser
+
def test_simple_inference(self):
"""
Tests a simple inference and makes sure it works as expected
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- _, _, inputs = self.get_dummy_inputs()
- output_no_lora = pipe(**inputs)[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ output_no_lora = self.get_base_pipe_output()
+ assert output_no_lora.shape == self.output_shape
def test_simple_inference_with_text_lora(self):
"""
Tests a simple inference with lora attached on the text encoder
and makes sure it works as expected
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ output_no_lora = self.get_base_pipe_output()
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+ )
- if self.has_two_text_encoders or self.has_three_text_encoders:
- lora_loadable_components = self.pipeline_class._lora_loadable_modules
- if "text_encoder_2" in lora_loadable_components:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ @require_peft_version_greater("0.13.1")
+ def test_low_cpu_mem_usage_with_injection(self):
+ """Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
- output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True)
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder.")
self.assertTrue(
- not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+ "meta" in {p.device.type for p in pipe.text_encoder.parameters()},
+ "The LoRA params should be on 'meta' device.",
)
- @require_peft_version_greater("0.13.1")
- def test_low_cpu_mem_usage_with_injection(self):
- """Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
+ te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder))
+ set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True)
+ self.assertTrue(
+ "meta" not in {p.device.type for p in pipe.text_encoder.parameters()},
+ "No param should be on 'meta' device.",
+ )
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True)
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True)
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ self.assertTrue(
+ "meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device."
+ )
+
+ denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser))
+ set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True)
+ self.assertTrue(
+ "meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device."
+ )
+
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True)
self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder."
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
self.assertTrue(
- "meta" in {p.device.type for p in pipe.text_encoder.parameters()},
+ "meta" in {p.device.type for p in pipe.text_encoder_2.parameters()},
"The LoRA params should be on 'meta' device.",
)
- te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder))
- set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True)
+ te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2))
+ set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True)
self.assertTrue(
- "meta" not in {p.device.type for p in pipe.text_encoder.parameters()},
+ "meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()},
"No param should be on 'meta' device.",
)
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True)
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
- self.assertTrue(
- "meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device."
- )
-
- denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser))
- set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True)
- self.assertTrue(
- "meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device."
- )
-
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
- self.assertTrue(
- "meta" in {p.device.type for p in pipe.text_encoder_2.parameters()},
- "The LoRA params should be on 'meta' device.",
- )
-
- te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2))
- set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True)
- self.assertTrue(
- "meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()},
- "No param should be on 'meta' device.",
- )
-
- _, _, inputs = self.get_dummy_inputs()
- output_lora = pipe(**inputs)[0]
- self.assertTrue(output_lora.shape == self.output_shape)
+ _, _, inputs = self.get_dummy_inputs()
+ output_lora = pipe(**inputs)[0]
+ self.assertTrue(output_lora.shape == self.output_shape)
@require_peft_version_greater("0.13.1")
@require_transformers_version_greater("4.45.2")
def test_low_cpu_mem_usage_with_loading(self):
"""Tests if we can load LoRA state dict with low_cpu_mem_usage."""
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
-
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
-
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config)
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
- images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- with tempfile.TemporaryDirectory() as tmpdirname:
- modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
- lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
- self.pipeline_class.save_lora_weights(
- save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
- )
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
+ )
- self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
- pipe.unload_lora_weights()
- pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False)
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+ pipe.unload_lora_weights()
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False)
- for module_name, module in modules_to_save.items():
- self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
+ for module_name, module in modules_to_save.items():
+ self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
- images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
- "Loading from saved checkpoints should give same results.",
- )
+ images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints should give same results.",
+ )
- # Now, check for `low_cpu_mem_usage.`
- pipe.unload_lora_weights()
- pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True)
+ # Now, check for `low_cpu_mem_usage.`
+ pipe.unload_lora_weights()
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True)
- for module_name, module in modules_to_save.items():
- self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
+ for module_name, module in modules_to_save.items():
+ self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
- images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- np.allclose(
- images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3
- ),
- "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.",
- )
+ images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ np.allclose(images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.",
+ )
def test_simple_inference_with_text_lora_and_scale(self):
"""
Tests a simple inference with lora attached on the text encoder + scale argument
and makes sure it works as expected
"""
- call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
-
- # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
- for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
- if possible_attention_kwargs in call_signature_keys:
- attention_kwargs_name = possible_attention_kwargs
- break
- assert attention_kwargs_name is not None
-
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
+ components, text_lora_config, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ output_no_lora = self.get_base_pipe_output()
- if self.has_two_text_encoders or self.has_three_text_encoders:
- lora_loadable_components = self.pipeline_class._lora_loadable_modules
- if "text_encoder_2" in lora_loadable_components:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
- output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
- )
+ output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+ )
- attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
- output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+ attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
+ output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
- self.assertTrue(
- not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
- "Lora + scale should change the output",
- )
+ self.assertTrue(
+ not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
+ "Lora + scale should change the output",
+ )
- attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
- output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+ attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
+ output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
- self.assertTrue(
- np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
- "Lora + 0 scale should lead to same result as no LoRA",
- )
+ self.assertTrue(
+ np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
+ "Lora + 0 scale should lead to same result as no LoRA",
+ )
def test_simple_inference_with_text_lora_fused(self):
"""
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ components, text_lora_config, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ output_no_lora = self.get_base_pipe_output()
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
- pipe.fuse_lora()
- # Fusing should still keep the LoRA layers
- self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ pipe.fuse_lora()
+ # Fusing should still keep the LoRA layers
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
- ouput_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertFalse(
- np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
- )
+ ouput_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertFalse(
+ np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
+ )
def test_simple_inference_with_text_lora_unloaded(self):
"""
Tests a simple inference with lora attached to text encoder, then unloads the lora weights
and makes sure it works as expected
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ components, text_lora_config, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
+ output_no_lora = self.get_base_pipe_output()
- if self.has_two_text_encoders or self.has_three_text_encoders:
- lora_loadable_components = self.pipeline_class._lora_loadable_modules
- if "text_encoder_2" in lora_loadable_components:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
- pipe.unload_lora_weights()
- # unloading should remove the LoRA layers
- self.assertFalse(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
- )
+ pipe.unload_lora_weights()
+ # unloading should remove the LoRA layers
+ self.assertFalse(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder")
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- self.assertFalse(
- check_if_lora_correctly_set(pipe.text_encoder_2),
- "Lora not correctly unloaded in text encoder 2",
- )
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ self.assertFalse(
+ check_if_lora_correctly_set(pipe.text_encoder_2),
+ "Lora not correctly unloaded in text encoder 2",
+ )
- ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
- "Fused lora should change the output",
- )
+ ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
+ "Fused lora should change the output",
+ )
def test_simple_inference_with_text_lora_save_load(self):
"""
Tests a simple usecase where users could use saving utilities for LoRA.
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
+ images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ modules_to_save = self._get_modules_to_save(pipe)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
- images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
+ )
- with tempfile.TemporaryDirectory() as tmpdirname:
- modules_to_save = self._get_modules_to_save(pipe)
- lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+ pipe.unload_lora_weights()
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
- self.pipeline_class.save_lora_weights(
- save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
- )
+ for module_name, module in modules_to_save.items():
+ self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
- self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
- pipe.unload_lora_weights()
- pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
+ images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
- for module_name, module in modules_to_save.items():
- self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
-
- images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
- self.assertTrue(
- np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
- "Loading from saved checkpoints should give same results.",
- )
+ self.assertTrue(
+ np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints should give same results.",
+ )
def test_simple_inference_with_partial_text_lora(self):
"""
@@ -636,27 +593,27 @@ def test_simple_inference_with_partial_text_lora(self):
with different ranks and some adapters removed
and makes sure it works as expected
"""
- for scheduler_cls in self.scheduler_classes:
- components, _, _ = self.get_dummy_components(scheduler_cls)
- # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
- text_lora_config = LoraConfig(
- r=4,
- rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3},
- lora_alpha=4,
- target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
- init_lora_weights=False,
- use_dora=False,
- )
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, _, _ = self.get_dummy_components()
+ # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
+ text_lora_config = LoraConfig(
+ r=4,
+ rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)},
+ lora_alpha=4,
+ target_modules=self.text_encoder_target_modules,
+ init_lora_weights=False,
+ use_dora=False,
+ )
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ output_no_lora = self.get_base_pipe_output()
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
+
+ state_dict = {}
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
# Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
# supports missing layers (PR#8324).
state_dict = {
@@ -665,309 +622,212 @@ def test_simple_inference_with_partial_text_lora(self):
if "text_model.encoder.layers.4" not in module_name
}
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
- state_dict.update(
- {
- f"text_encoder_2.{module_name}": param
- for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
- if "text_model.encoder.layers.4" not in module_name
- }
- )
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ state_dict.update(
+ {
+ f"text_encoder_2.{module_name}": param
+ for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
+ if "text_model.encoder.layers.4" not in module_name
+ }
+ )
- output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
- )
+ output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+ )
- # Unload lora and load it back using the pipe.load_lora_weights machinery
- pipe.unload_lora_weights()
- pipe.load_lora_weights(state_dict)
+ # Unload lora and load it back using the pipe.load_lora_weights machinery
+ pipe.unload_lora_weights()
+ pipe.load_lora_weights(state_dict)
- output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3),
- "Removing adapters should change the output",
- )
+ output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3),
+ "Removing adapters should change the output",
+ )
- def test_simple_inference_save_pretrained(self):
+ def test_simple_inference_save_pretrained_with_text_lora(self):
"""
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
-
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ components, text_lora_config, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
+ images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- with tempfile.TemporaryDirectory() as tmpdirname:
- pipe.save_pretrained(tmpdirname)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ pipe.save_pretrained(tmpdirname)
- pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
- pipe_from_pretrained.to(torch_device)
+ pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
+ pipe_from_pretrained.to(torch_device)
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
"Lora not correctly set in text encoder",
)
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- self.assertTrue(
- check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
- "Lora not correctly set in text encoder 2",
- )
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
+ "Lora not correctly set in text encoder 2",
+ )
- images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
+ images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
- "Loading from saved checkpoints should give same results.",
- )
+ self.assertTrue(
+ np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints should give same results.",
+ )
def test_simple_inference_with_text_denoiser_lora_save_load(self):
"""
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
-
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
-
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config)
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
- images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- with tempfile.TemporaryDirectory() as tmpdirname:
- modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
- lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
- self.pipeline_class.save_lora_weights(
- save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
- )
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
+ )
- self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
- pipe.unload_lora_weights()
- pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+ pipe.unload_lora_weights()
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
- for module_name, module in modules_to_save.items():
- self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
+ for module_name, module in modules_to_save.items():
+ self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
- images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
- "Loading from saved checkpoints should give same results.",
- )
+ images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints should give same results.",
+ )
def test_simple_inference_with_text_denoiser_lora_and_scale(self):
"""
Tests a simple inference with lora attached on the text encoder + Unet + scale argument
and makes sure it works as expected
"""
- call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
- for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
- if possible_attention_kwargs in call_signature_keys:
- attention_kwargs_name = possible_attention_kwargs
- break
- assert attention_kwargs_name is not None
-
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
-
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
+ attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config)
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ output_no_lora = self.get_base_pipe_output()
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+ )
- output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
- )
+ attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
+ output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
- attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
- output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+ self.assertTrue(
+ not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
+ "Lora + scale should change the output",
+ )
- self.assertTrue(
- not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
- "Lora + scale should change the output",
- )
+ attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
+ output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
- attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
- output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+ self.assertTrue(
+ np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
+ "Lora + 0 scale should lead to same result as no LoRA",
+ )
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
- np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
- "Lora + 0 scale should lead to same result as no LoRA",
+ pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
+ "The scaling parameter has not been correctly restored!",
)
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- self.assertTrue(
- pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
- "The scaling parameter has not been correctly restored!",
- )
-
def test_simple_inference_with_text_lora_denoiser_fused(self):
"""
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected - with unet
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ output_no_lora = self.get_base_pipe_output()
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
+ pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config)
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ # Fusing should still keep the LoRA layers
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")
- # Fusing should still keep the LoRA layers
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")
-
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
-
- output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertFalse(
- np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
- )
+ output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertFalse(
+ np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
+ )
def test_simple_inference_with_text_denoiser_lora_unloaded(self):
"""
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
-
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config)
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ output_no_lora = self.get_base_pipe_output()
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
- pipe.unload_lora_weights()
- # unloading should remove the LoRA layers
- self.assertFalse(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
- )
- self.assertFalse(check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser")
+ pipe.unload_lora_weights()
+ # unloading should remove the LoRA layers
+ self.assertFalse(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder")
+ self.assertFalse(check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser")
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- self.assertFalse(
- check_if_lora_correctly_set(pipe.text_encoder_2),
- "Lora not correctly unloaded in text encoder 2",
- )
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ self.assertFalse(
+ check_if_lora_correctly_set(pipe.text_encoder_2),
+ "Lora not correctly unloaded in text encoder 2",
+ )
- output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
- "Fused lora should change the output",
- )
+ output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
+ "Fused lora should change the output",
+ )
def test_simple_inference_with_text_denoiser_lora_unfused(
self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
@@ -976,200 +836,162 @@ def test_simple_inference_with_text_denoiser_lora_unfused(
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
-
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config)
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
- pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
- output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
+ self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
+ output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
- output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
+ self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
+ output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- # unloading should remove the LoRA layers
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
+ # unloading should remove the LoRA layers
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
- )
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
+ )
- # Fuse and unfuse should lead to the same results
- self.assertTrue(
- np.allclose(output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol),
- "Fused lora should not change the output",
- )
+ # Fuse and unfuse should lead to the same results
+ self.assertTrue(
+ np.allclose(output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol),
+ "Fused lora should not change the output",
+ )
def test_simple_inference_with_text_denoiser_multi_adapter(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set them
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
-
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config, "adapter-1")
- denoiser.add_adapter(denoiser_lora_config, "adapter-2")
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
-
- pipe.set_adapters("adapter-1")
- output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertFalse(
- np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3),
- "Adapter outputs should be different.",
- )
-
- pipe.set_adapters("adapter-2")
- output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertFalse(
- np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3),
- "Adapter outputs should be different.",
- )
-
- pipe.set_adapters(["adapter-1", "adapter-2"])
- output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertFalse(
- np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter outputs should be different.",
- )
-
- # Fuse and unfuse should lead to the same results
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
- "Adapter 1 and 2 should give different results",
- )
-
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 1 and mixed adapters should give different results",
- )
-
- self.assertFalse(
- np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 2 and mixed adapters should give different results",
- )
-
- pipe.disable_lora()
- output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
- self.assertTrue(
- np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
- "output with no lora and output with lora disabled should give same results",
- )
-
- def test_wrong_adapter_name_raises_error(self):
- scheduler_cls = self.scheduler_classes[0]
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
+ output_no_lora = self.get_base_pipe_output()
+
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+ denoiser.add_adapter(denoiser_lora_config, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
+ pipe.set_adapters("adapter-1")
+ output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertFalse(
+ np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3),
+ "Adapter outputs should be different.",
+ )
+
+ pipe.set_adapters("adapter-2")
+ output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertFalse(
+ np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3),
+ "Adapter outputs should be different.",
+ )
+
+ pipe.set_adapters(["adapter-1", "adapter-2"])
+ output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertFalse(
+ np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter outputs should be different.",
+ )
+
+ # Fuse and unfuse should lead to the same results
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and 2 should give different results",
+ )
+
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and mixed adapters should give different results",
+ )
+
+ self.assertFalse(
+ np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 2 and mixed adapters should give different results",
+ )
+
+ pipe.disable_lora()
+ output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertTrue(
+ np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
+ "output with no lora and output with lora disabled should give same results",
+ )
+
+ def test_wrong_adapter_name_raises_error(self):
+ adapter_name = "adapter-1"
+
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ pipe, _ = self.add_adapters_to_pipeline(
+ pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
+ )
+
with self.assertRaises(ValueError) as err_context:
pipe.set_adapters("test")
self.assertTrue("not in the list of present adapters" in str(err_context.exception))
# test this works.
- pipe.set_adapters("adapter-1")
+ pipe.set_adapters(adapter_name)
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
def test_multiple_wrong_adapter_name_raises_error(self):
- scheduler_cls = self.scheduler_classes[0]
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ adapter_name = "adapter-1"
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config, "adapter-1")
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ pipe, _ = self.add_adapters_to_pipeline(
+ pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
+ )
scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0}
logger = logging.get_logger("diffusers.loaders.lora_base")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
- pipe.set_adapters("adapter-1", adapter_weights=scale_with_wrong_components)
+ pipe.set_adapters(adapter_name, adapter_weights=scale_with_wrong_components)
wrong_components = sorted(set(scale_with_wrong_components.keys()))
msg = f"The following components in `adapter_weights` are not part of the pipeline: {wrong_components}. "
self.assertTrue(msg in str(cap_logger.out))
# test this works.
- pipe.set_adapters("adapter-1")
+ pipe.set_adapters(adapter_name)
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
def test_simple_inference_with_text_denoiser_block_scale(self):
@@ -1177,131 +999,127 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
Tests a simple inference with lora attached to text encoder and unet, attaches
one adapter and set different weights for different blocks (i.e. block lora)
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ output_no_lora = self.get_base_pipe_output()
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config)
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
- weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
- pipe.set_adapters("adapter-1", weights_1)
- output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
+ pipe.set_adapters("adapter-1", weights_1)
+ output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- weights_2 = {"unet": {"up": 5}}
- pipe.set_adapters("adapter-1", weights_2)
- output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ weights_2 = {"unet": {"up": 5}}
+ pipe.set_adapters("adapter-1", weights_2)
+ output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertFalse(
- np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3),
- "LoRA weights 1 and 2 should give different results",
- )
- self.assertFalse(
- np.allclose(output_no_lora, output_weights_1, atol=1e-3, rtol=1e-3),
- "No adapter and LoRA weights 1 should give different results",
- )
- self.assertFalse(
- np.allclose(output_no_lora, output_weights_2, atol=1e-3, rtol=1e-3),
- "No adapter and LoRA weights 2 should give different results",
- )
+ self.assertFalse(
+ np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3),
+ "LoRA weights 1 and 2 should give different results",
+ )
+ self.assertFalse(
+ np.allclose(output_no_lora, output_weights_1, atol=1e-3, rtol=1e-3),
+ "No adapter and LoRA weights 1 should give different results",
+ )
+ self.assertFalse(
+ np.allclose(output_no_lora, output_weights_2, atol=1e-3, rtol=1e-3),
+ "No adapter and LoRA weights 2 should give different results",
+ )
- pipe.disable_lora()
- output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.disable_lora()
+ output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
- "output with no lora and output with lora disabled should give same results",
- )
+ self.assertTrue(
+ np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
+ "output with no lora and output with lora disabled should give same results",
+ )
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
- multiple adapters and set differnt weights for different blocks (i.e. block lora)
+ multiple adapters and set different weights for different blocks (i.e. block lora)
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ output_no_lora = self.get_base_pipe_output()
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config, "adapter-1")
- denoiser.add_adapter(denoiser_lora_config, "adapter-2")
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+ denoiser.add_adapter(denoiser_lora_config, "adapter-2")
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
- scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
- scales_2 = {"unet": {"down": 5, "mid": 5}}
+ scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
+ scales_2 = {"unet": {"down": 5, "mid": 5}}
- pipe.set_adapters("adapter-1", scales_1)
- output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.set_adapters("adapter-1", scales_1)
+ output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- pipe.set_adapters("adapter-2", scales_2)
- output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.set_adapters("adapter-2", scales_2)
+ output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2])
- output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2])
+ output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
- # Fuse and unfuse should lead to the same results
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
- "Adapter 1 and 2 should give different results",
- )
+ # Fuse and unfuse should lead to the same results
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and 2 should give different results",
+ )
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 1 and mixed adapters should give different results",
- )
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and mixed adapters should give different results",
+ )
- self.assertFalse(
- np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 2 and mixed adapters should give different results",
- )
+ self.assertFalse(
+ np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 2 and mixed adapters should give different results",
+ )
- pipe.disable_lora()
- output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.disable_lora()
+ output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
- "output with no lora and output with lora disabled should give same results",
- )
+ self.assertTrue(
+ np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
+ "output with no lora and output with lora disabled should give same results",
+ )
- # a mismatching number of adapter_names and adapter_weights should raise an error
- with self.assertRaises(ValueError):
- pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1])
+ # a mismatching number of adapter_names and adapter_weights should raise an error
+ with self.assertRaises(ValueError):
+ pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1])
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
"""Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""
@@ -1397,170 +1215,164 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set/delete them
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ output_no_lora = self.get_base_pipe_output()
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config, "adapter-1")
- denoiser.add_adapter(denoiser_lora_config, "adapter-2")
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+ denoiser.add_adapter(denoiser_lora_config, "adapter-2")
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
- if self.has_two_text_encoders or self.has_three_text_encoders:
- lora_loadable_components = self.pipeline_class._lora_loadable_modules
- if "text_encoder_2" in lora_loadable_components:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ lora_loadable_components = self.pipeline_class._lora_loadable_modules
+ if "text_encoder_2" in lora_loadable_components:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
- pipe.set_adapters("adapter-1")
- output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.set_adapters("adapter-1")
+ output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- pipe.set_adapters("adapter-2")
- output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.set_adapters("adapter-2")
+ output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- pipe.set_adapters(["adapter-1", "adapter-2"])
- output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.set_adapters(["adapter-1", "adapter-2"])
+ output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
- "Adapter 1 and 2 should give different results",
- )
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and 2 should give different results",
+ )
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 1 and mixed adapters should give different results",
- )
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and mixed adapters should give different results",
+ )
- self.assertFalse(
- np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 2 and mixed adapters should give different results",
- )
+ self.assertFalse(
+ np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 2 and mixed adapters should give different results",
+ )
- pipe.delete_adapters("adapter-1")
- output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.delete_adapters("adapter-1")
+ output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
- "Adapter 1 and 2 should give different results",
- )
+ self.assertTrue(
+ np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and 2 should give different results",
+ )
- pipe.delete_adapters("adapter-2")
- output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.delete_adapters("adapter-2")
+ output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
- "output with no lora and output with lora disabled should give same results",
- )
+ self.assertTrue(
+ np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
+ "output with no lora and output with lora disabled should give same results",
+ )
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config, "adapter-1")
- denoiser.add_adapter(denoiser_lora_config, "adapter-2")
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+ denoiser.add_adapter(denoiser_lora_config, "adapter-2")
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
- pipe.set_adapters(["adapter-1", "adapter-2"])
- pipe.delete_adapters(["adapter-1", "adapter-2"])
+ pipe.set_adapters(["adapter-1", "adapter-2"])
+ pipe.delete_adapters(["adapter-1", "adapter-2"])
- output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
- "output with no lora and output with lora disabled should give same results",
- )
+ self.assertTrue(
+ np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
+ "output with no lora and output with lora disabled should give same results",
+ )
def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set them
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ output_no_lora = self.get_base_pipe_output()
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config, "adapter-1")
- denoiser.add_adapter(denoiser_lora_config, "adapter-2")
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+ denoiser.add_adapter(denoiser_lora_config, "adapter-2")
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
- if self.has_two_text_encoders or self.has_three_text_encoders:
- lora_loadable_components = self.pipeline_class._lora_loadable_modules
- if "text_encoder_2" in lora_loadable_components:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ lora_loadable_components = self.pipeline_class._lora_loadable_modules
+ if "text_encoder_2" in lora_loadable_components:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
- pipe.set_adapters("adapter-1")
- output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.set_adapters("adapter-1")
+ output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- pipe.set_adapters("adapter-2")
- output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.set_adapters("adapter-2")
+ output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- pipe.set_adapters(["adapter-1", "adapter-2"])
- output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.set_adapters(["adapter-1", "adapter-2"])
+ output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
- # Fuse and unfuse should lead to the same results
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
- "Adapter 1 and 2 should give different results",
- )
+ # Fuse and unfuse should lead to the same results
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and 2 should give different results",
+ )
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 1 and mixed adapters should give different results",
- )
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and mixed adapters should give different results",
+ )
- self.assertFalse(
- np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 2 and mixed adapters should give different results",
- )
+ self.assertFalse(
+ np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 2 and mixed adapters should give different results",
+ )
- pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6])
- output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6])
+ output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertFalse(
- np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Weighted adapter and mixed adapter should give different results",
- )
+ self.assertFalse(
+ np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Weighted adapter and mixed adapter should give different results",
+ )
- pipe.disable_lora()
- output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.disable_lora()
+ output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
- "output with no lora and output with lora disabled should give same results",
- )
+ self.assertTrue(
+ np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
+ "output with no lora and output with lora disabled should give same results",
+ )
@skip_mps
@pytest.mark.xfail(
@@ -1569,36 +1381,41 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
strict=False,
)
def test_lora_fuse_nan(self):
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config, "adapter-1")
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
- # corrupt one LoRA weight with `inf` values
- with torch.no_grad():
- if self.unet_kwargs:
- pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A[
- "adapter-1"
- ].weight += float("inf")
- else:
- named_modules = [name for name, _ in pipe.transformer.named_modules()]
- tower_name = (
- "transformer_blocks"
- if any(name == "transformer_blocks" for name in named_modules)
- else "blocks"
- )
+ # corrupt one LoRA weight with `inf` values
+ with torch.no_grad():
+ if self.unet_kwargs:
+ pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float(
+ "inf"
+ )
+ else:
+ named_modules = [name for name, _ in pipe.transformer.named_modules()]
+ possible_tower_names = [
+ "transformer_blocks",
+ "blocks",
+ "joint_transformer_blocks",
+ "single_transformer_blocks",
+ ]
+ filtered_tower_names = [
+ tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
+ ]
+ if len(filtered_tower_names) == 0:
+ reason = f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}."
+ raise ValueError(reason)
+ for tower_name in filtered_tower_names:
transformer_tower = getattr(pipe.transformer, tower_name)
has_attn1 = any("attn1" in name for name in named_modules)
if has_attn1:
@@ -1606,118 +1423,115 @@ def test_lora_fuse_nan(self):
else:
transformer_tower[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
- # with `safe_fusing=True` we should see an Error
- with self.assertRaises(ValueError):
- pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
+ # with `safe_fusing=True` we should see an Error
+ with self.assertRaises(ValueError):
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
- # without we should not see an error, but every image will be black
- pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
- out = pipe(**inputs)[0]
+ # without we should not see an error, but every image will be black
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
+ out = pipe(**inputs)[0]
- self.assertTrue(np.isnan(out).all())
+ self.assertTrue(np.isnan(out).all())
def test_get_adapters(self):
"""
Tests a simple usecase where we attach multiple adapters and check if the results
are the expected results
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config, "adapter-1")
- adapter_names = pipe.get_active_adapters()
- self.assertListEqual(adapter_names, ["adapter-1"])
+ adapter_names = pipe.get_active_adapters()
+ self.assertListEqual(adapter_names, ["adapter-1"])
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
- denoiser.add_adapter(denoiser_lora_config, "adapter-2")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ denoiser.add_adapter(denoiser_lora_config, "adapter-2")
- adapter_names = pipe.get_active_adapters()
- self.assertListEqual(adapter_names, ["adapter-2"])
+ adapter_names = pipe.get_active_adapters()
+ self.assertListEqual(adapter_names, ["adapter-2"])
- pipe.set_adapters(["adapter-1", "adapter-2"])
- self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"])
+ pipe.set_adapters(["adapter-1", "adapter-2"])
+ self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"])
def test_get_list_adapters(self):
"""
Tests a simple usecase where we attach multiple adapters and check if the results
are the expected results
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
- # 1.
- dicts_to_be_checked = {}
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- dicts_to_be_checked = {"text_encoder": ["adapter-1"]}
+ # 1.
+ dicts_to_be_checked = {}
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ dicts_to_be_checked = {"text_encoder": ["adapter-1"]}
- if self.unet_kwargs is not None:
- pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
- dicts_to_be_checked.update({"unet": ["adapter-1"]})
- else:
- pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
- dicts_to_be_checked.update({"transformer": ["adapter-1"]})
+ if self.unet_kwargs is not None:
+ pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+ dicts_to_be_checked.update({"unet": ["adapter-1"]})
+ else:
+ pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+ dicts_to_be_checked.update({"transformer": ["adapter-1"]})
- self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
+ self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
- # 2.
- dicts_to_be_checked = {}
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
- dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
+ # 2.
+ dicts_to_be_checked = {}
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
- if self.unet_kwargs is not None:
- pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
- dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
- else:
- pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
- dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})
+ if self.unet_kwargs is not None:
+ pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
+ dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
+ else:
+ pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
+ dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})
- self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
+ self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
- # 3.
- pipe.set_adapters(["adapter-1", "adapter-2"])
+ # 3.
+ pipe.set_adapters(["adapter-1", "adapter-2"])
- dicts_to_be_checked = {}
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
+ dicts_to_be_checked = {}
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
- if self.unet_kwargs is not None:
- dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
- else:
- dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})
+ if self.unet_kwargs is not None:
+ dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
+ else:
+ dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})
- self.assertDictEqual(
- pipe.get_list_adapters(),
- dicts_to_be_checked,
- )
+ self.assertDictEqual(
+ pipe.get_list_adapters(),
+ dicts_to_be_checked,
+ )
- # 4.
- dicts_to_be_checked = {}
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
+ # 4.
+ dicts_to_be_checked = {}
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
- if self.unet_kwargs is not None:
- pipe.unet.add_adapter(denoiser_lora_config, "adapter-3")
- dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2", "adapter-3"]})
- else:
- pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3")
- dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2", "adapter-3"]})
+ if self.unet_kwargs is not None:
+ pipe.unet.add_adapter(denoiser_lora_config, "adapter-3")
+ dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2", "adapter-3"]})
+ else:
+ pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3")
+ dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2", "adapter-3"]})
- self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
+ self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
- @require_peft_version_greater(peft_version="0.6.2")
def test_simple_inference_with_text_lora_denoiser_fused_multi(
self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
):
@@ -1725,111 +1539,149 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected - with unet and multi-adapter case
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ denoiser.add_adapter(denoiser_lora_config, "adapter-2")
+
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ lora_loadable_components = self.pipeline_class._lora_loadable_modules
+ if "text_encoder_2" in lora_loadable_components:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+ # set them to multi-adapter inference mode
+ pipe.set_adapters(["adapter-1", "adapter-2"])
+ outputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- # Attach a second adapter
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ pipe.set_adapters(["adapter-1"])
+ outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- denoiser.add_adapter(denoiser_lora_config, "adapter-2")
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])
+ self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ # Fusing should still keep the LoRA layers so output should remain the same
+ outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
- if self.has_two_text_encoders or self.has_three_text_encoders:
- lora_loadable_components = self.pipeline_class._lora_loadable_modules
- if "text_encoder_2" in lora_loadable_components:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ self.assertTrue(
+ np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
+ "Fused lora should not change the output",
+ )
- # set them to multi-adapter inference mode
- pipe.set_adapters(["adapter-1", "adapter-2"])
- outputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
+ self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
- pipe.set_adapters(["adapter-1"])
- outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
- pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")
- # Fusing should still keep the LoRA layers so outpout should remain the same
- outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
+ )
- self.assertTrue(
- np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
- "Fused lora should not change the output",
- )
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"])
+ self.assertTrue(pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
- pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
- pipe.fuse_lora(
- components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]
- )
+ # Fusing should still keep the LoRA layers
+ output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol),
+ "Fused lora should not change the output",
+ )
+ pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
+ self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
- # Fusing should still keep the LoRA layers
- output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol),
- "Fused lora should not change the output",
- )
+ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3):
+ attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
- @require_peft_version_greater(peft_version="0.9.0")
- def test_simple_inference_with_dora(self):
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
- scheduler_cls, use_dora=True
- )
+ for lora_scale in [1.0, 0.8]:
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_dora_lora.shape == self.output_shape)
+ output_no_lora = self.get_base_pipe_output()
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ )
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config)
+ denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders:
lora_loadable_components = self.pipeline_class._lora_loadable_modules
if "text_encoder_2" in lora_loadable_components:
- pipe.text_encoder_2.add_adapter(text_lora_config)
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ check_if_lora_correctly_set(pipe.text_encoder_2),
+ "Lora not correctly set in text encoder 2",
)
- output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.set_adapters(["adapter-1"])
+ attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
+ outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+
+ pipe.fuse_lora(
+ components=self.pipeline_class._lora_loadable_modules,
+ adapter_names=["adapter-1"],
+ lora_scale=lora_scale,
+ )
+ self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
+
+ outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
+ "Fused lora should not change the output",
+ )
self.assertFalse(
- np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3),
- "DoRA lora should change the output",
+ np.allclose(output_no_lora, outputs_lora_1, atol=expected_atol, rtol=expected_rtol),
+ "LoRA should change the output",
)
+ def test_simple_inference_with_dora(self):
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(output_no_dora_lora.shape == self.output_shape)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
+
+ output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertFalse(
+ np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3),
+ "DoRA lora should change the output",
+ )
+
def test_missing_keys_warning(self):
- scheduler_cls = self.scheduler_classes[0]
# Skip text encoder check for now as that is handled with `transformers`.
- components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -1853,7 +1705,7 @@ def test_missing_keys_warning(self):
missing_key = [k for k in state_dict if "lora_A" in k][0]
del state_dict[missing_key]
- logger = logging.get_logger("diffusers.loaders.peft")
+ logger = logging.get_logger("diffusers.utils.peft_utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(state_dict)
@@ -1864,9 +1716,8 @@ def test_missing_keys_warning(self):
self.assertTrue(missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", ""))
def test_unexpected_keys_warning(self):
- scheduler_cls = self.scheduler_classes[0]
# Skip text encoder check for now as that is handled with `transformers`.
- components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -1888,7 +1739,7 @@ def test_unexpected_keys_warning(self):
unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat"
state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device)
- logger = logging.get_logger("diffusers.loaders.peft")
+ logger = logging.get_logger("diffusers.utils.peft_utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(state_dict)
@@ -1901,34 +1752,21 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config)
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
- if self.has_two_text_encoders or self.has_three_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
- pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+ pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
- if self.has_two_text_encoders or self.has_three_text_encoders:
- pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)
- # Just makes sure it works..
- _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ # Just makes sure it works.
+ _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
def test_modify_padding_mode(self):
def set_pad_mode(network, mode="circular"):
@@ -1936,28 +1774,26 @@ def set_pad_mode(network, mode="circular"):
if isinstance(module, torch.nn.Conv2d):
module.padding_mode = mode
- for scheduler_cls in self.scheduler_classes:
- components, _, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _pad_mode = "circular"
- set_pad_mode(pipe.vae, _pad_mode)
- set_pad_mode(pipe.unet, _pad_mode)
+ components, _, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _pad_mode = "circular"
+ set_pad_mode(pipe.vae, _pad_mode)
+ set_pad_mode(pipe.unet, _pad_mode)
- _, _, inputs = self.get_dummy_inputs()
- _ = pipe(**inputs)[0]
+ _, _, inputs = self.get_dummy_inputs()
+ _ = pipe(**inputs)[0]
def test_logs_info_when_no_lora_keys_found(self):
- scheduler_cls = self.scheduler_classes[0]
# Skip text encoder check for now as that is handled with `transformers`.
- components, _, _ = self.get_dummy_components(scheduler_cls)
+ components, _, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ output_no_lora = self.get_base_pipe_output()
no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
logger = logging.get_logger("diffusers.loaders.peft")
@@ -1969,7 +1805,7 @@ def test_logs_info_when_no_lora_keys_found(self):
denoiser = getattr(pipe, "unet") if self.unet_kwargs is not None else getattr(pipe, "transformer")
self.assertTrue(cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}"))
- self.assertTrue(np.allclose(original_out, out_after_lora_attempt, atol=1e-5, rtol=1e-5))
+ self.assertTrue(np.allclose(output_no_lora, out_after_lora_attempt, atol=1e-5, rtol=1e-5))
# test only for text encoder
for lora_module in self.pipeline_class._lora_loadable_modules:
@@ -1994,94 +1830,70 @@ def test_logs_info_when_no_lora_keys_found(self):
def test_set_adapters_match_attention_kwargs(self):
"""Test to check if outputs after `set_adapters()` and attention kwargs match."""
- call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
- for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
- if possible_attention_kwargs in call_signature_keys:
- attention_kwargs_name = possible_attention_kwargs
- break
- assert attention_kwargs_name is not None
+ attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ output_no_lora = self.get_base_pipe_output()
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ lora_scale = 0.5
+ attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
+ output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+ self.assertFalse(
+ np.allclose(output_no_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
+ "Lora + scale should change the output",
+ )
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
+ pipe.set_adapters("default", lora_scale)
+ output_lora_scale_wo_kwargs = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ not np.allclose(output_no_lora, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
+ "Lora + scale should change the output",
+ )
+ self.assertTrue(
+ np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
+ "Lora + scale should match the output of `set_adapters()`.",
+ )
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config)
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+ )
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
- lora_scale = 0.5
- attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
- output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
- self.assertFalse(
- np.allclose(output_no_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
- "Lora + scale should change the output",
- )
+ for module_name, module in modules_to_save.items():
+ self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
- pipe.set_adapters("default", lora_scale)
- output_lora_scale_wo_kwargs = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
self.assertTrue(
- not np.allclose(output_no_lora, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
+ not np.allclose(output_no_lora, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Lora + scale should change the output",
)
self.assertTrue(
- np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
- "Lora + scale should match the output of `set_adapters()`.",
+ np.allclose(output_lora_scale, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints should give same results as attention_kwargs.",
+ )
+ self.assertTrue(
+ np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints should give same results as set_adapters().",
)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
- lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
- self.pipeline_class.save_lora_weights(
- save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
- )
-
- self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
-
- for module_name, module in modules_to_save.items():
- self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
-
- output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
- self.assertTrue(
- not np.allclose(output_no_lora, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
- "Lora + scale should change the output",
- )
- self.assertTrue(
- np.allclose(output_lora_scale, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
- "Loading from saved checkpoints should give same results as attention_kwargs.",
- )
- self.assertTrue(
- np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
- "Loading from saved checkpoints should give same results as set_adapters().",
- )
@require_peft_version_greater("0.13.2")
def test_lora_B_bias(self):
# Currently, this test is only relevant for Flux Control LoRA as we are not
# aware of any other LoRA checkpoint that has its `lora_B` biases trained.
- components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+ components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -2090,15 +1902,12 @@ def test_lora_B_bias(self):
bias_values = {}
denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
for name, module in denoiser.named_modules():
- if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
+ if any(k in name for k in self.denoiser_target_modules):
if module.bias is not None:
bias_values[name] = module.bias.data.clone()
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- logger = logging.get_logger("diffusers.loaders.lora_pipeline")
- logger.setLevel(logging.INFO)
-
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
denoiser_lora_config.lora_bias = False
@@ -2121,7 +1930,7 @@ def test_lora_B_bias(self):
self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
def test_correct_lora_configs_with_different_ranks(self):
- components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+ components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -2188,14 +1997,15 @@ def test_correct_lora_configs_with_different_ranks(self):
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
def test_layerwise_casting_inference_denoiser(self):
- from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
+ from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
+ from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN
def check_linear_dtype(module, storage_dtype, compute_dtype):
patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
if getattr(module, "_skip_layerwise_casting_patterns", None) is not None:
patterns_to_check += tuple(module._skip_layerwise_casting_patterns)
for name, submodule in module.named_modules():
- if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
+ if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
continue
dtype_to_check = storage_dtype
if "lora" in name or any(re.search(pattern, name) for pattern in patterns_to_check):
@@ -2206,27 +2016,12 @@ def check_linear_dtype(module, storage_dtype, compute_dtype):
self.assertEqual(submodule.bias.dtype, dtype_to_check)
def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device, dtype=compute_dtype)
pipe.set_progress_bar_config(disable=None)
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
-
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config)
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
if storage_dtype is not None:
denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
@@ -2261,10 +2056,10 @@ def test_layerwise_casting_peft_input_autocast_denoiser(self):
See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details.
"""
+ from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from diffusers.hooks.layerwise_casting import (
_PEFT_AUTOCAST_DISABLE_HOOK,
DEFAULT_SKIP_MODULES_PATTERN,
- SUPPORTED_PYTORCH_LAYERS,
apply_layerwise_casting,
)
@@ -2274,7 +2069,7 @@ def test_layerwise_casting_peft_input_autocast_denoiser(self):
def check_module(denoiser):
# This will also check if the peft layers are in torch.float8_e4m3fn dtype (unlike test_layerwise_casting_inference_denoiser)
for name, module in denoiser.named_modules():
- if not isinstance(module, SUPPORTED_PYTORCH_LAYERS):
+ if not isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
continue
dtype_to_check = storage_dtype
if any(re.search(pattern, name) for pattern in patterns_to_check):
@@ -2288,7 +2083,7 @@ def check_module(denoiser):
self.assertTrue(module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None)
# 1. Test forward with add_adapter
- components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+ components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device, dtype=compute_dtype)
pipe.set_progress_bar_config(disable=None)
@@ -2318,7 +2113,7 @@ def check_module(denoiser):
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
- components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
+ components, _, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device, dtype=compute_dtype)
pipe.set_progress_bar_config(disable=None)
@@ -2335,3 +2130,289 @@ def check_module(denoiser):
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ @parameterized.expand([4, 8, 16])
+ def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha):
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
+ pipe = self.pipeline_class(**components)
+
+ pipe, _ = self.add_adapters_to_pipeline(
+ pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+ )
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
+ self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
+ pipe.unload_lora_weights()
+
+ out = pipe.lora_state_dict(tmpdir, return_lora_metadata=True)
+ if len(out) == 3:
+ _, _, parsed_metadata = out
+ elif len(out) == 2:
+ _, parsed_metadata = out
+
+ denoiser_key = (
+ f"{self.pipeline_class.transformer_name}"
+ if self.transformer_kwargs is not None
+ else f"{self.pipeline_class.unet_name}"
+ )
+ self.assertTrue(any(k.startswith(f"{denoiser_key}.") for k in parsed_metadata))
+ check_module_lora_metadata(
+ parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key
+ )
+
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ text_encoder_key = self.pipeline_class.text_encoder_name
+ self.assertTrue(any(k.startswith(f"{text_encoder_key}.") for k in parsed_metadata))
+ check_module_lora_metadata(
+ parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key
+ )
+
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ text_encoder_2_key = "text_encoder_2"
+ self.assertTrue(any(k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata))
+ check_module_lora_metadata(
+ parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key
+ )
+
+ @parameterized.expand([4, 8, 16])
+ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
+ pipe = self.pipeline_class(**components).to(torch_device)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ pipe, _ = self.add_adapters_to_pipeline(
+ pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+ )
+ output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
+ self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
+ pipe.unload_lora_weights()
+ pipe.load_lora_weights(tmpdir)
+
+ output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertTrue(
+ np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
+ )
+
+ def test_lora_unload_add_adapter(self):
+ """Tests if `unload_lora_weights()` -> `add_adapter()` works."""
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components).to(torch_device)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ pipe, _ = self.add_adapters_to_pipeline(
+ pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+ )
+ _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ # unload and then add.
+ pipe.unload_lora_weights()
+ pipe, _ = self.add_adapters_to_pipeline(
+ pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+ )
+ _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ def test_inference_load_delete_load_adapters(self):
+ "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_lora = self.get_base_pipe_output()
+
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ lora_loadable_components = self.pipeline_class._lora_loadable_modules
+ if "text_encoder_2" in lora_loadable_components:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
+
+ output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+
+ # First, delete adapter and compare.
+ pipe.delete_adapters(pipe.get_active_adapters()[0])
+ output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertFalse(np.allclose(output_adapter_1, output_no_adapter, atol=1e-3, rtol=1e-3))
+ self.assertTrue(np.allclose(output_no_lora, output_no_adapter, atol=1e-3, rtol=1e-3))
+
+ # Then load adapter and compare.
+ pipe.load_lora_weights(tmpdirname)
+ output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
+
+ def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+ from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook
+
+ onload_device = torch_device
+ offload_device = torch.device("cpu")
+
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+ )
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+
+ components, _, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+ check_if_lora_correctly_set(denoiser)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ # Test group offloading with load_lora_weights
+ denoiser.enable_group_offload(
+ onload_device=onload_device,
+ offload_device=offload_device,
+ offload_type=offload_type,
+ num_blocks_per_group=1,
+ use_stream=use_stream,
+ )
+ # Place other model-level components on `torch_device`.
+ for _, component in pipe.components.items():
+ if isinstance(component, torch.nn.Module):
+ component.to(torch_device)
+ group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
+ self.assertTrue(group_offload_hook_1 is not None)
+ output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ # Test group offloading after removing the lora
+ pipe.unload_lora_weights()
+ group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser)
+ self.assertTrue(group_offload_hook_2 is not None)
+ output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] # noqa: F841
+
+ # Add the lora again and check if group offloading works
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+ check_if_lora_correctly_set(denoiser)
+ group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser)
+ self.assertTrue(group_offload_hook_3 is not None)
+ output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertTrue(np.allclose(output_1, output_3, atol=1e-3, rtol=1e-3))
+
+ @parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)])
+ @require_torch_accelerator
+ def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+ for cls in inspect.getmro(self.__class__):
+ if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests:
+ # Skip this test if it is overwritten by child class. We need to do this because parameterized
+ # materializes the test methods on invocation which cannot be overridden.
+ return
+ self._test_group_offloading_inference_denoiser(offload_type, use_stream)
+
+ @require_torch_accelerator
+ def test_lora_loading_model_cpu_offload(self):
+ components, _, denoiser_lora_config = self.get_dummy_components()
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+ output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+ )
+ # reinitialize the pipeline to mimic the inference workflow.
+ components, _, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.enable_model_cpu_offload(device=torch_device)
+ pipe.load_lora_weights(tmpdirname)
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+ output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(np.allclose(output_lora, output_lora_loaded, atol=1e-3, rtol=1e-3))
+
+ @require_torch_accelerator
+ def test_lora_group_offloading_delete_adapters(self):
+ components, _, denoiser_lora_config = self.get_dummy_components()
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+ try:
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+ )
+
+ components, _, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ pipe.to(torch_device)
+
+ # Enable Group Offloading (leaf_level for more granular testing)
+ apply_group_offloading(
+ denoiser,
+ onload_device=torch_device,
+ offload_device="cpu",
+ offload_type="leaf_level",
+ )
+
+ pipe.load_lora_weights(tmpdirname, adapter_name="default")
+
+ out_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ # Delete the adapter
+ pipe.delete_adapters("default")
+
+ out_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertFalse(np.allclose(out_lora, out_no_lora, atol=1e-3, rtol=1e-3))
+ finally:
+ # Clean up the hooks to prevent state leak
+ if hasattr(denoiser, "_diffusers_hook"):
+ denoiser._diffusers_hook.remove_hook(_GROUP_OFFLOADING, recurse=True)
diff --git a/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py b/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py
index 7efb390287ab..2476ab92f77a 100644
--- a/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py
+++ b/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,7 +21,9 @@
from diffusers import AsymmetricAutoencoderKL
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ Expectations,
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -33,14 +35,14 @@
torch_all_close,
torch_device,
)
-
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AsymmetricAutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AsymmetricAutoencoderKL
main_input_name = "sample"
base_precision = 1e-2
@@ -134,18 +136,32 @@ def get_generator(self, seed=0):
# fmt: off
[
33,
- [-0.0336, 0.3011, 0.1764, 0.0087, -0.3401, 0.3645, -0.1247, 0.1205],
- [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824],
+ Expectations(
+ {
+ ("xpu", 3): torch.tensor([-0.0343, 0.2873, 0.1680, -0.0140, -0.3459, 0.3522, -0.1336, 0.1075]),
+ ("cuda", 7): torch.tensor([-0.0336, 0.3011, 0.1764, 0.0087, -0.3401, 0.3645, -0.1247, 0.1205]),
+ ("mps", None): torch.tensor(
+ [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824]
+ ),
+ }
+ ),
],
[
47,
- [0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529],
- [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089],
+ Expectations(
+ {
+ ("xpu", 3): torch.tensor([0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529]),
+ ("cuda", 7): torch.tensor([0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529]),
+ ("mps", None): torch.tensor(
+ [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089]
+ ),
+ }
+ ),
],
# fmt: on
]
)
- def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
+ def test_stable_diffusion(self, seed, expected_slices):
model = self.get_sd_vae_model()
image = self.get_sd_image(seed)
generator = self.get_generator(seed)
@@ -156,9 +172,9 @@ def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
assert sample.shape == image.shape
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
- expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)
- assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
+ expected_slice = expected_slices.get_expectation()
+ assert torch_all_close(output_slice, expected_slice, atol=5e-3)
@parameterized.expand(
[
diff --git a/tests/models/autoencoders/test_models_autoencoder_cosmos.py b/tests/models/autoencoders/test_models_autoencoder_cosmos.py
new file mode 100644
index 000000000000..5898ae776a1b
--- /dev/null
+++ b/tests/models/autoencoders/test_models_autoencoder_cosmos.py
@@ -0,0 +1,83 @@
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from diffusers import AutoencoderKLCosmos
+
+from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
+
+
+enable_full_determinism()
+
+
+class AutoencoderKLCosmosTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
+ model_class = AutoencoderKLCosmos
+ main_input_name = "sample"
+ base_precision = 1e-2
+
+ def get_autoencoder_kl_cosmos_config(self):
+ return {
+ "in_channels": 3,
+ "out_channels": 3,
+ "latent_channels": 4,
+ "encoder_block_out_channels": (8, 8, 8, 8),
+ "decode_block_out_channels": (8, 8, 8, 8),
+ "attention_resolutions": (8,),
+ "resolution": 64,
+ "num_layers": 2,
+ "patch_size": 4,
+ "patch_type": "haar",
+ "scaling_factor": 1.0,
+ "spatial_compression_ratio": 4,
+ "temporal_compression_ratio": 4,
+ }
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_frames = 9
+ num_channels = 3
+ height = 32
+ width = 32
+
+ image = floats_tensor((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+
+ return {"sample": image}
+
+ @property
+ def input_shape(self):
+ return (3, 9, 32, 32)
+
+ @property
+ def output_shape(self):
+ return (3, 9, 32, 32)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = self.get_autoencoder_kl_cosmos_config()
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {
+ "CosmosEncoder3d",
+ "CosmosDecoder3d",
+ }
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+ @unittest.skip("Not sure why this test fails. Investigate later.")
+ def test_effective_gradient_checkpointing(self):
+ pass
diff --git a/tests/models/autoencoders/test_models_autoencoder_dc.py b/tests/models/autoencoders/test_models_autoencoder_dc.py
index 5f21593d8e04..b1b5531d0134 100644
--- a/tests/models/autoencoders/test_models_autoencoder_dc.py
+++ b/tests/models/autoencoders/test_models_autoencoder_dc.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,19 +16,16 @@
import unittest
from diffusers import AutoencoderDC
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- floats_tensor,
- torch_device,
-)
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, floats_tensor, torch_device
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderDCTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderDC
main_input_name = "sample"
base_precision = 1e-2
@@ -82,6 +79,10 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict
- @unittest.skip("AutoencoderDC does not support `norm_num_groups` because it does not use GroupNorm.")
- def test_forward_with_norm_groups(self):
- pass
+ @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
+ def test_layerwise_casting_inference(self):
+ super().test_layerwise_casting_inference()
+
+ @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
+ def test_layerwise_casting_memory(self):
+ super().test_layerwise_casting_memory()
diff --git a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
index 00d4b8ed2b5f..9813772a7c55 100644
--- a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
+++ b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,19 +19,16 @@
from diffusers import AutoencoderKLHunyuanVideo
from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- floats_tensor,
- torch_device,
-)
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLHunyuanVideo
main_input_name = "sample"
base_precision = 1e-2
@@ -87,68 +84,6 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict
- def test_enable_disable_tiling(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- torch.manual_seed(0)
- model = self.model_class(**init_dict).to(torch_device)
-
- inputs_dict.update({"return_dict": False})
-
- torch.manual_seed(0)
- output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- torch.manual_seed(0)
- model.enable_tiling()
- output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertLess(
- (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
- 0.5,
- "VAE tiling should not affect the inference results",
- )
-
- torch.manual_seed(0)
- model.disable_tiling()
- output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertEqual(
- output_without_tiling.detach().cpu().numpy().all(),
- output_without_tiling_2.detach().cpu().numpy().all(),
- "Without tiling outputs should match with the outputs when tiling is manually disabled.",
- )
-
- def test_enable_disable_slicing(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- torch.manual_seed(0)
- model = self.model_class(**init_dict).to(torch_device)
-
- inputs_dict.update({"return_dict": False})
-
- torch.manual_seed(0)
- output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- torch.manual_seed(0)
- model.enable_slicing()
- output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertLess(
- (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
- 0.5,
- "VAE slicing should not affect the inference results",
- )
-
- torch.manual_seed(0)
- model.disable_slicing()
- output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertEqual(
- output_without_slicing.detach().cpu().numpy().all(),
- output_without_slicing_2.detach().cpu().numpy().all(),
- "Without slicing outputs should match with the outputs when slicing is manually disabled.",
- )
-
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"HunyuanVideoDecoder3D",
diff --git a/tests/models/autoencoders/test_models_autoencoder_kl.py b/tests/models/autoencoders/test_models_autoencoder_kl.py
index 9126594000f6..5f11c6cb0ab3 100644
--- a/tests/models/autoencoders/test_models_autoencoder_kl.py
+++ b/tests/models/autoencoders/test_models_autoencoder_kl.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,7 +21,8 @@
from diffusers import AutoencoderKL
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -34,14 +35,14 @@
torch_all_close,
torch_device,
)
-
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKL
main_input_name = "sample"
base_precision = 1e-2
@@ -83,68 +84,6 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict
- def test_enable_disable_tiling(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- torch.manual_seed(0)
- model = self.model_class(**init_dict).to(torch_device)
-
- inputs_dict.update({"return_dict": False})
-
- torch.manual_seed(0)
- output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- torch.manual_seed(0)
- model.enable_tiling()
- output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertLess(
- (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
- 0.5,
- "VAE tiling should not affect the inference results",
- )
-
- torch.manual_seed(0)
- model.disable_tiling()
- output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertEqual(
- output_without_tiling.detach().cpu().numpy().all(),
- output_without_tiling_2.detach().cpu().numpy().all(),
- "Without tiling outputs should match with the outputs when tiling is manually disabled.",
- )
-
- def test_enable_disable_slicing(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- torch.manual_seed(0)
- model = self.model_class(**init_dict).to(torch_device)
-
- inputs_dict.update({"return_dict": False})
-
- torch.manual_seed(0)
- output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- torch.manual_seed(0)
- model.enable_slicing()
- output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertLess(
- (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
- 0.5,
- "VAE slicing should not affect the inference results",
- )
-
- torch.manual_seed(0)
- model.disable_slicing()
- output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertEqual(
- output_without_slicing.detach().cpu().numpy().all(),
- output_without_slicing_2.detach().cpu().numpy().all(),
- "Without slicing outputs should match with the outputs when slicing is manually disabled.",
- )
-
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py b/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py
index 7336bb3d3e97..b6d59489d9c6 100644
--- a/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py
+++ b/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,19 +18,20 @@
import torch
from diffusers import AutoencoderKLCogVideoX
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
-
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderKLCogVideoXTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLCogVideoXTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLCogVideoX
main_input_name = "sample"
base_precision = 1e-2
@@ -82,68 +83,6 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict
- def test_enable_disable_tiling(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- torch.manual_seed(0)
- model = self.model_class(**init_dict).to(torch_device)
-
- inputs_dict.update({"return_dict": False})
-
- torch.manual_seed(0)
- output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- torch.manual_seed(0)
- model.enable_tiling()
- output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertLess(
- (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
- 0.5,
- "VAE tiling should not affect the inference results",
- )
-
- torch.manual_seed(0)
- model.disable_tiling()
- output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertEqual(
- output_without_tiling.detach().cpu().numpy().all(),
- output_without_tiling_2.detach().cpu().numpy().all(),
- "Without tiling outputs should match with the outputs when tiling is manually disabled.",
- )
-
- def test_enable_disable_slicing(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- torch.manual_seed(0)
- model = self.model_class(**init_dict).to(torch_device)
-
- inputs_dict.update({"return_dict": False})
-
- torch.manual_seed(0)
- output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- torch.manual_seed(0)
- model.enable_slicing()
- output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertLess(
- (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
- 0.5,
- "VAE slicing should not affect the inference results",
- )
-
- torch.manual_seed(0)
- model.disable_slicing()
- output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertEqual(
- output_without_slicing.detach().cpu().numpy().all(),
- output_without_slicing_2.detach().cpu().numpy().all(),
- "Without slicing outputs should match with the outputs when slicing is manually disabled.",
- )
-
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"CogVideoXDownBlock3D",
diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py
index cf80ff50443e..93f40f44a919 100644
--- a/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py
+++ b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,19 +16,20 @@
import unittest
from diffusers import AutoencoderKLTemporalDecoder
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
-
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLTemporalDecoder
main_input_name = "sample"
base_precision = 1e-2
@@ -67,7 +68,3 @@ def prepare_init_args_and_inputs_for_common(self):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
-
- @unittest.skip("Test unsupported.")
- def test_forward_with_norm_groups(self):
- pass
diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py
index 66d170b28eee..527be1b4ecb5 100644
--- a/tests/models/autoencoders/test_models_autoencoder_ltx_video.py
+++ b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,19 +18,20 @@
import torch
from diffusers import AutoencoderKLLTXVideo
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
-
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTXVideo
main_input_name = "sample"
base_precision = 1e-2
@@ -99,7 +100,7 @@ def test_forward_with_norm_groups(self):
pass
-class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTXVideo
main_input_name = "sample"
base_precision = 1e-2
@@ -167,34 +168,3 @@ def test_outputs_equivalence(self):
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass
-
- def test_enable_disable_tiling(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- torch.manual_seed(0)
- model = self.model_class(**init_dict).to(torch_device)
-
- inputs_dict.update({"return_dict": False})
-
- torch.manual_seed(0)
- output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- torch.manual_seed(0)
- model.enable_tiling()
- output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertLess(
- (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
- 0.5,
- "VAE tiling should not affect the inference results",
- )
-
- torch.manual_seed(0)
- model.disable_tiling()
- output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertEqual(
- output_without_tiling.detach().cpu().numpy().all(),
- output_without_tiling_2.detach().cpu().numpy().all(),
- "Without tiling outputs should match with the outputs when tiling is manually disabled.",
- )
diff --git a/tests/models/autoencoders/test_models_autoencoder_magvit.py b/tests/models/autoencoders/test_models_autoencoder_magvit.py
index ee7e5bbdd485..f7304df14048 100644
--- a/tests/models/autoencoders/test_models_autoencoder_magvit.py
+++ b/tests/models/autoencoders/test_models_autoencoder_magvit.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,15 +16,16 @@
import unittest
from diffusers import AutoencoderKLMagvit
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderKLMagvitTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLMagvitTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLMagvit
main_input_name = "sample"
base_precision = 1e-2
@@ -88,3 +89,9 @@ def test_effective_gradient_checkpointing(self):
@unittest.skip("Unsupported test.")
def test_forward_with_norm_groups(self):
pass
+
+ @unittest.skip(
+ "Unsupported test. Error: RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 9 but got size 12 for tensor number 1 in the list."
+ )
+ def test_enable_disable_slicing(self):
+ pass
diff --git a/tests/models/autoencoders/test_models_autoencoder_mochi.py b/tests/models/autoencoders/test_models_autoencoder_mochi.py
new file mode 100755
index 000000000000..ab8d429a67f6
--- /dev/null
+++ b/tests/models/autoencoders/test_models_autoencoder_mochi.py
@@ -0,0 +1,100 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from diffusers import AutoencoderKLMochi
+
+from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
+
+
+enable_full_determinism()
+
+
+class AutoencoderKLMochiTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
+ model_class = AutoencoderKLMochi
+ main_input_name = "sample"
+ base_precision = 1e-2
+
+ def get_autoencoder_kl_mochi_config(self):
+ return {
+ "in_channels": 15,
+ "out_channels": 3,
+ "latent_channels": 4,
+ "encoder_block_out_channels": (32, 32, 32, 32),
+ "decoder_block_out_channels": (32, 32, 32, 32),
+ "layers_per_block": (1, 1, 1, 1, 1),
+ "act_fn": "silu",
+ "scaling_factor": 1,
+ }
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_frames = 7
+ num_channels = 3
+ sizes = (16, 16)
+
+ image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
+
+ return {"sample": image}
+
+ @property
+ def input_shape(self):
+ return (3, 7, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (3, 7, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = self.get_autoencoder_kl_mochi_config()
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {
+ "MochiDecoder3D",
+ "MochiDownBlock3D",
+ "MochiEncoder3D",
+ "MochiMidBlock3D",
+ "MochiUpBlock3D",
+ }
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+ @unittest.skip("Unsupported test.")
+ def test_model_parallelism(self):
+ """
+ tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_outputs_equivalence -
+ RuntimeError: values expected sparse tensor layout but got Strided
+ """
+ pass
+
+ @unittest.skip("Unsupported test.")
+ def test_outputs_equivalence(self):
+ """
+ tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_outputs_equivalence -
+ RuntimeError: values expected sparse tensor layout but got Strided
+ """
+ pass
+
+ @unittest.skip("Unsupported test.")
+ def test_sharded_checkpoints_device_map(self):
+ """
+ tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_sharded_checkpoints_device_map -
+ RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:5!
+ """
diff --git a/tests/models/autoencoders/test_models_autoencoder_oobleck.py b/tests/models/autoencoders/test_models_autoencoder_oobleck.py
index 2adea6bda439..d10e8ba33a12 100644
--- a/tests/models/autoencoders/test_models_autoencoder_oobleck.py
+++ b/tests/models/autoencoders/test_models_autoencoder_oobleck.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,7 +21,8 @@
from parameterized import parameterized
from diffusers import AutoencoderOobleck
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -29,14 +30,14 @@
torch_all_close,
torch_device,
)
-
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderOobleckTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderOobleck
main_input_name = "sample"
base_precision = 1e-2
@@ -106,10 +107,6 @@ def test_enable_disable_slicing(self):
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)
- @unittest.skip("Test unsupported.")
- def test_forward_with_norm_groups(self):
- pass
-
@unittest.skip("No attention module used in this model")
def test_set_attn_processor_for_determinism(self):
return
diff --git a/tests/models/autoencoders/test_models_autoencoder_tiny.py b/tests/models/autoencoders/test_models_autoencoder_tiny.py
index bfbfb7ab8593..68232aa12fdf 100644
--- a/tests/models/autoencoders/test_models_autoencoder_tiny.py
+++ b/tests/models/autoencoders/test_models_autoencoder_tiny.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,7 +21,8 @@
from parameterized import parameterized
from diffusers import AutoencoderTiny
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -30,14 +31,14 @@
torch_all_close,
torch_device,
)
-
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderTinyTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderTinyTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderTiny
main_input_name = "sample"
base_precision = 1e-2
@@ -81,37 +82,6 @@ def prepare_init_args_and_inputs_for_common(self):
def test_enable_disable_tiling(self):
pass
- def test_enable_disable_slicing(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- torch.manual_seed(0)
- model = self.model_class(**init_dict).to(torch_device)
-
- inputs_dict.update({"return_dict": False})
-
- torch.manual_seed(0)
- output_without_slicing = model(**inputs_dict)[0]
-
- torch.manual_seed(0)
- model.enable_slicing()
- output_with_slicing = model(**inputs_dict)[0]
-
- self.assertLess(
- (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
- 0.5,
- "VAE slicing should not affect the inference results",
- )
-
- torch.manual_seed(0)
- model.disable_slicing()
- output_without_slicing_2 = model(**inputs_dict)[0]
-
- self.assertEqual(
- output_without_slicing.detach().cpu().numpy().all(),
- output_without_slicing_2.detach().cpu().numpy().all(),
- "Without slicing outputs should match with the outputs when slicing is manually disabled.",
- )
-
@unittest.skip("Test not supported.")
def test_outputs_equivalence(self):
pass
diff --git a/tests/models/autoencoders/test_models_autoencoder_wan.py b/tests/models/autoencoders/test_models_autoencoder_wan.py
index ffc474039889..051098dc7aac 100644
--- a/tests/models/autoencoders/test_models_autoencoder_wan.py
+++ b/tests/models/autoencoders/test_models_autoencoder_wan.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,15 +16,16 @@
import unittest
from diffusers import AutoencoderKLWan
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderKLWanTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLWanTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLWan
main_input_name = "sample"
base_precision = 1e-2
@@ -44,9 +45,16 @@ def dummy_input(self):
num_frames = 9
num_channels = 3
sizes = (16, 16)
-
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
+ return {"sample": image}
+ @property
+ def dummy_input_tiling(self):
+ batch_size = 2
+ num_frames = 9
+ num_channels = 3
+ sizes = (128, 128)
+ image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
return {"sample": image}
@property
@@ -62,6 +70,11 @@ def prepare_init_args_and_inputs_for_common(self):
inputs_dict = self.dummy_input
return init_dict, inputs_dict
+ def prepare_init_args_and_inputs_for_tiling(self):
+ init_dict = self.get_autoencoder_kl_wan_config()
+ inputs_dict = self.dummy_input_tiling
+ return init_dict, inputs_dict
+
@unittest.skip("Gradient checkpointing has not been implemented yet")
def test_gradient_checkpointing_is_applied(self):
pass
diff --git a/tests/models/autoencoders/test_models_consistency_decoder_vae.py b/tests/models/autoencoders/test_models_consistency_decoder_vae.py
index 77977a78d83b..ef04d151ecd1 100644
--- a/tests/models/autoencoders/test_models_consistency_decoder_vae.py
+++ b/tests/models/autoencoders/test_models_consistency_decoder_vae.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,22 +20,24 @@
import torch
from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
load_image,
slow,
torch_all_close,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
+class ConsistencyDecoderVAETests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = ConsistencyDecoderVAE
main_input_name = "sample"
base_precision = 1e-2
@@ -91,70 +93,6 @@ def init_dict(self):
def prepare_init_args_and_inputs_for_common(self):
return self.init_dict, self.inputs_dict()
- def test_enable_disable_tiling(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- torch.manual_seed(0)
- model = self.model_class(**init_dict).to(torch_device)
-
- inputs_dict.update({"return_dict": False})
- _ = inputs_dict.pop("generator")
-
- torch.manual_seed(0)
- output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- torch.manual_seed(0)
- model.enable_tiling()
- output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertLess(
- (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
- 0.5,
- "VAE tiling should not affect the inference results",
- )
-
- torch.manual_seed(0)
- model.disable_tiling()
- output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertEqual(
- output_without_tiling.detach().cpu().numpy().all(),
- output_without_tiling_2.detach().cpu().numpy().all(),
- "Without tiling outputs should match with the outputs when tiling is manually disabled.",
- )
-
- def test_enable_disable_slicing(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- torch.manual_seed(0)
- model = self.model_class(**init_dict).to(torch_device)
-
- inputs_dict.update({"return_dict": False})
- _ = inputs_dict.pop("generator")
-
- torch.manual_seed(0)
- output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- torch.manual_seed(0)
- model.enable_slicing()
- output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertLess(
- (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
- 0.5,
- "VAE slicing should not affect the inference results",
- )
-
- torch.manual_seed(0)
- model.disable_slicing()
- output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertEqual(
- output_without_slicing.detach().cpu().numpy().all(),
- output_without_slicing_2.detach().cpu().numpy().all(),
- "Without slicing outputs should match with the outputs when slicing is manually disabled.",
- )
-
@slow
class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
@@ -162,13 +100,13 @@ def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
@torch.no_grad()
def test_encode_decode(self):
diff --git a/tests/models/autoencoders/test_models_vae_flax.py b/tests/models/autoencoders/test_models_vae_flax.py
deleted file mode 100644
index 8fedb85eccfc..000000000000
--- a/tests/models/autoencoders/test_models_vae_flax.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import unittest
-
-from diffusers import FlaxAutoencoderKL
-from diffusers.utils import is_flax_available
-from diffusers.utils.testing_utils import require_flax
-
-from ..test_modeling_common_flax import FlaxModelTesterMixin
-
-
-if is_flax_available():
- import jax
-
-
-@require_flax
-class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase):
- model_class = FlaxAutoencoderKL
-
- @property
- def dummy_input(self):
- batch_size = 4
- num_channels = 3
- sizes = (32, 32)
-
- prng_key = jax.random.PRNGKey(0)
- image = jax.random.uniform(prng_key, ((batch_size, num_channels) + sizes))
-
- return {"sample": image, "prng_key": prng_key}
-
- def prepare_init_args_and_inputs_for_common(self):
- init_dict = {
- "block_out_channels": [32, 64],
- "in_channels": 3,
- "out_channels": 3,
- "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
- "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
- "latent_channels": 4,
- }
- inputs_dict = self.dummy_input
- return init_dict, inputs_dict
diff --git a/tests/models/autoencoders/test_models_vq.py b/tests/models/autoencoders/test_models_vq.py
index 77abe139d785..b88d24d1f2d8 100644
--- a/tests/models/autoencoders/test_models_vq.py
+++ b/tests/models/autoencoders/test_models_vq.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,20 +18,16 @@
import torch
from diffusers import VQModel
-from diffusers.utils.testing_utils import (
- backend_manual_seed,
- enable_full_determinism,
- floats_tensor,
- torch_device,
-)
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ...testing_utils import backend_manual_seed, enable_full_determinism, floats_tensor, torch_device
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class VQModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class VQModelTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = VQModel
main_input_name = "sample"
diff --git a/tests/models/autoencoders/testing_utils.py b/tests/models/autoencoders/testing_utils.py
new file mode 100644
index 000000000000..8ae362ac2e94
--- /dev/null
+++ b/tests/models/autoencoders/testing_utils.py
@@ -0,0 +1,147 @@
+import inspect
+
+import numpy as np
+import pytest
+import torch
+
+from diffusers.models.autoencoders.vae import DecoderOutput
+from diffusers.utils.torch_utils import torch_device
+
+
+class AutoencoderTesterMixin:
+ """
+ Test mixin class specific to VAEs to test for slicing and tiling. Diffusion networks
+ usually don't do slicing and tiling.
+ """
+
+ @staticmethod
+ def _accepts_generator(model):
+ model_sig = inspect.signature(model.forward)
+ accepts_generator = "generator" in model_sig.parameters
+ return accepts_generator
+
+ @staticmethod
+ def _accepts_norm_num_groups(model_class):
+ model_sig = inspect.signature(model_class.__init__)
+ accepts_norm_groups = "norm_num_groups" in model_sig.parameters
+ return accepts_norm_groups
+
+ def test_forward_with_norm_groups(self):
+ if not self._accepts_norm_num_groups(self.model_class):
+ pytest.skip(f"Test not supported for {self.model_class.__name__}")
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ init_dict["norm_num_groups"] = 16
+ init_dict["block_out_channels"] = (16, 32)
+
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ output = model(**inputs_dict)
+
+ if isinstance(output, dict):
+ output = output.to_tuple()[0]
+
+ self.assertIsNotNone(output)
+ expected_shape = inputs_dict["sample"].shape
+ self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+ def test_enable_disable_tiling(self):
+ if not hasattr(self.model_class, "enable_tiling"):
+ pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
+
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict).to(torch_device)
+
+ if not hasattr(model, "use_tiling"):
+ pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
+
+ inputs_dict.update({"return_dict": False})
+ _ = inputs_dict.pop("generator", None)
+ accepts_generator = self._accepts_generator(model)
+
+ torch.manual_seed(0)
+ if accepts_generator:
+ inputs_dict["generator"] = torch.manual_seed(0)
+ output_without_tiling = model(**inputs_dict)[0]
+ # Mochi-1
+ if isinstance(output_without_tiling, DecoderOutput):
+ output_without_tiling = output_without_tiling.sample
+
+ torch.manual_seed(0)
+ model.enable_tiling()
+ if accepts_generator:
+ inputs_dict["generator"] = torch.manual_seed(0)
+ output_with_tiling = model(**inputs_dict)[0]
+ if isinstance(output_with_tiling, DecoderOutput):
+ output_with_tiling = output_with_tiling.sample
+
+ assert (
+ output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()
+ ).max() < 0.5, "VAE tiling should not affect the inference results"
+
+ torch.manual_seed(0)
+ model.disable_tiling()
+ if accepts_generator:
+ inputs_dict["generator"] = torch.manual_seed(0)
+ output_without_tiling_2 = model(**inputs_dict)[0]
+ if isinstance(output_without_tiling_2, DecoderOutput):
+ output_without_tiling_2 = output_without_tiling_2.sample
+
+ assert np.allclose(
+ output_without_tiling.detach().cpu().numpy().all(),
+ output_without_tiling_2.detach().cpu().numpy().all(),
+ ), "Without tiling outputs should match with the outputs when tiling is manually disabled."
+
+ def test_enable_disable_slicing(self):
+ if not hasattr(self.model_class, "enable_slicing"):
+ pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support slicing.")
+
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict).to(torch_device)
+ if not hasattr(model, "use_slicing"):
+ pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
+
+ inputs_dict.update({"return_dict": False})
+ _ = inputs_dict.pop("generator", None)
+ accepts_generator = self._accepts_generator(model)
+
+ if accepts_generator:
+ inputs_dict["generator"] = torch.manual_seed(0)
+
+ torch.manual_seed(0)
+ output_without_slicing = model(**inputs_dict)[0]
+ # Mochi-1
+ if isinstance(output_without_slicing, DecoderOutput):
+ output_without_slicing = output_without_slicing.sample
+
+ torch.manual_seed(0)
+ model.enable_slicing()
+ if accepts_generator:
+ inputs_dict["generator"] = torch.manual_seed(0)
+ output_with_slicing = model(**inputs_dict)[0]
+ if isinstance(output_with_slicing, DecoderOutput):
+ output_with_slicing = output_with_slicing.sample
+
+ assert (
+ output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()
+ ).max() < 0.5, "VAE slicing should not affect the inference results"
+
+ torch.manual_seed(0)
+ model.disable_slicing()
+ if accepts_generator:
+ inputs_dict["generator"] = torch.manual_seed(0)
+ output_without_slicing_2 = model(**inputs_dict)[0]
+ if isinstance(output_without_slicing_2, DecoderOutput):
+ output_without_slicing_2 = output_without_slicing_2.sample
+
+ assert np.allclose(
+ output_without_slicing.detach().cpu().numpy().all(),
+ output_without_slicing_2.detach().cpu().numpy().all(),
+ ), "Without slicing outputs should match with the outputs when slicing is manually disabled."
diff --git a/tests/models/test_attention_processor.py b/tests/models/test_attention_processor.py
index d070f6ea33e3..ccf36b092b46 100644
--- a/tests/models/test_attention_processor.py
+++ b/tests/models/test_attention_processor.py
@@ -7,7 +7,8 @@
from diffusers import DiffusionPipeline
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor
-from diffusers.utils.testing_utils import torch_device
+
+from ..testing_utils import torch_device
class AttnAddedKVProcessorTests(unittest.TestCase):
diff --git a/tests/models/test_layers_utils.py b/tests/models/test_layers_utils.py
index 415bb12b73c6..eaeffa699db2 100644
--- a/tests/models/test_layers_utils.py
+++ b/tests/models/test_layers_utils.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,7 +24,8 @@
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from diffusers.models.transformers.transformer_2d import Transformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_manual_seed,
require_torch_accelerator_with_fp64,
require_torch_version_greater_equal,
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index fc4a3128dd9f..b9dfe932335c 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,6 +15,7 @@
import copy
import gc
+import glob
import inspect
import json
import os
@@ -28,22 +29,25 @@
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
+import pytest
import requests_mock
+import safetensors.torch
import torch
import torch.nn as nn
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
-from huggingface_hub import ModelCard, delete_repo, snapshot_download
-from huggingface_hub.utils import is_jinja_available
+from huggingface_hub import ModelCard, delete_repo, snapshot_download, try_to_load_from_cache
+from huggingface_hub.utils import HfHubHTTPError, is_jinja_available
from parameterized import parameterized
-from requests.exceptions import HTTPError
-from diffusers.models import SD3Transformer2DModel, UNet2DConditionModel
+from diffusers.models import FluxTransformer2DModel, SD3Transformer2DModel, UNet2DConditionModel
from diffusers.models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
AttnProcessorNPU,
XFormersAttnProcessor,
)
+from diffusers.models.auto_model import AutoModel
+from diffusers.models.modeling_outputs import BaseOutput
from diffusers.training_utils import EMAModel
from diffusers.utils import (
SAFE_WEIGHTS_INDEX_NAME,
@@ -54,23 +58,32 @@
logging,
)
from diffusers.utils.hub_utils import _add_variant
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import get_torch_cuda_device_capability
+
+from ..others.test_utils import TOKEN, USER, is_staging_test
+from ..testing_utils import (
CaptureLogger,
+ _check_safetensors_serialization,
+ backend_empty_cache,
+ backend_max_memory_allocated,
+ backend_reset_peak_memory_stats,
+ backend_synchronize,
+ check_if_dicts_are_equal,
get_python_version,
is_torch_compile,
numpy_cosine_similarity_distance,
+ require_peft_backend,
+ require_peft_version_greater,
require_torch_2,
require_torch_accelerator,
require_torch_accelerator_with_training,
- require_torch_gpu,
require_torch_multi_accelerator,
+ require_torch_version_greater,
run_test_in_subprocess,
+ slow,
torch_all_close,
torch_device,
)
-from diffusers.utils.torch_utils import get_torch_cuda_device_capability
-
-from ..others.test_utils import TOKEN, USER, is_staging_test
if is_peft_available():
@@ -96,6 +109,11 @@ def check_if_lora_correctly_set(model) -> bool:
return False
+def normalize_output(out):
+ out0 = out[0] if isinstance(out, (BaseOutput, tuple)) else out
+ return torch.stack(out0) if isinstance(out0, list) else out0
+
+
# Will be run via run_test_in_subprocess
def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
error = None
@@ -230,8 +248,8 @@ def load_model(path):
else:
_ = load_model(repo_id)
- warning_message = str(warning.warnings[0].message)
- self.assertIn("This serialization format is now deprecated to standardize the serialization", warning_message)
+ warning_messages = " ".join(str(w.message) for w in warning.warnings)
+ self.assertIn("This serialization format is now deprecated to standardize the serialization", warning_messages)
# Local tests are already covered down below.
@parameterized.expand(
@@ -258,7 +276,7 @@ def test_cached_files_are_used_when_no_internet(self):
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = {}
- response_mock.raise_for_status.side_effect = HTTPError
+ response_mock.raise_for_status.side_effect = HfHubHTTPError("Server down", response=mock.Mock())
response_mock.json.return_value = {}
# Download this model to make sure it's in the cache.
@@ -277,6 +295,56 @@ def test_cached_files_are_used_when_no_internet(self):
if p1.data.ne(p2.data).sum() > 0:
assert False, "Parameters not the same!"
+ def test_local_files_only_with_sharded_checkpoint(self):
+ repo_id = "hf-internal-testing/tiny-flux-sharded"
+ error_response = mock.Mock(
+ status_code=500,
+ headers={},
+ raise_for_status=mock.Mock(side_effect=HfHubHTTPError("Server down", response=mock.Mock())),
+ json=mock.Mock(return_value={}),
+ )
+ client_mock = mock.Mock()
+ client_mock.get.return_value = error_response
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ model = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=tmpdir)
+
+ with mock.patch("huggingface_hub.hf_api.get_session", return_value=client_mock):
+ # Should fail with local_files_only=False (network required)
+ # We would make a network call with model_info
+ with self.assertRaises(OSError):
+ FluxTransformer2DModel.from_pretrained(
+ repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=False
+ )
+
+ # Should succeed with local_files_only=True (uses cache)
+ # model_info call skipped
+ local_model = FluxTransformer2DModel.from_pretrained(
+ repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
+ )
+
+ assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), (
+ "Model parameters don't match!"
+ )
+
+ # Remove a shard file
+ cached_shard_file = try_to_load_from_cache(
+ repo_id, filename="transformer/diffusion_pytorch_model-00001-of-00002.safetensors", cache_dir=tmpdir
+ )
+ os.remove(cached_shard_file)
+
+ # Attempting to load from cache should raise an error
+ with self.assertRaises(OSError) as context:
+ FluxTransformer2DModel.from_pretrained(
+ repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
+ )
+
+ # Verify error mentions the missing shard
+ error_msg = str(context.exception)
+ assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, (
+ f"Expected error about missing shard, got: {error_msg}"
+ )
+
@unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners")
@unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.")
def test_one_request_upon_cached(self):
@@ -292,9 +360,9 @@ def test_one_request_upon_cached(self):
)
download_requests = [r.method for r in m.request_history]
- assert (
- download_requests.count("HEAD") == 3
- ), "3 HEAD requests one for config, one for model, and one for shard index file."
+ assert download_requests.count("HEAD") == 3, (
+ "3 HEAD requests one for config, one for model, and one for shard index file."
+ )
assert download_requests.count("GET") == 2, "2 GET requests one for config, one for model"
with requests_mock.mock(real_http=True) as m:
@@ -306,9 +374,9 @@ def test_one_request_upon_cached(self):
)
cache_requests = [r.method for r in m.request_history]
- assert (
- "HEAD" == cache_requests[0] and len(cache_requests) == 2
- ), "We should call only `model_info` to check for commit hash and knowing if shard index is present."
+ assert "HEAD" == cache_requests[0] and len(cache_requests) == 2, (
+ "We should call only `model_info` to check for commit hash and knowing if shard index is present."
+ )
def test_weight_overwrite(self):
with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
@@ -334,7 +402,7 @@ def test_weight_overwrite(self):
assert model.config.in_channels == 9
- @require_torch_gpu
+ @require_torch_accelerator
def test_keep_modules_in_fp32(self):
r"""
A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32 when we load the model in fp16/bf16
@@ -388,7 +456,15 @@ def get_dummy_inputs():
class UNetTesterMixin:
+ @staticmethod
+ def _accepts_norm_num_groups(model_class):
+ model_sig = inspect.signature(model_class.__init__)
+ accepts_norm_groups = "norm_num_groups" in model_sig.parameters
+ return accepts_norm_groups
+
def test_forward_with_norm_groups(self):
+ if not self._accepts_norm_num_groups(self.model_class):
+ pytest.skip(f"Test not supported for {self.model_class.__name__}")
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["norm_num_groups"] = 16
@@ -466,6 +542,9 @@ def test_from_save_pretrained(self, expected_max_diff=5e-5):
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]
+ image = normalize_output(image)
+ new_image = normalize_output(new_image)
+
max_diff = (image - new_image).abs().max().item()
self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes")
@@ -710,6 +789,9 @@ def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]
+ image = normalize_output(image)
+ new_image = normalize_output(new_image)
+
max_diff = (image - new_image).abs().max().item()
self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes")
@@ -772,6 +854,9 @@ def test_determinism(self, expected_max_diff=1e-5):
if isinstance(second, dict):
second = second.to_tuple()[0]
+ first = normalize_output(first)
+ second = normalize_output(second)
+
out_1 = first.cpu().numpy()
out_2 = second.cpu().numpy()
out_1 = out_1[~np.isnan(out_1)]
@@ -790,11 +875,15 @@ def test_output(self, expected_output_shape=None):
if isinstance(output, dict):
output = output.to_tuple()[0]
+ if isinstance(output, list):
+ output = torch.stack(output)
self.assertIsNotNone(output)
# input & output have to have the same shape
input_tensor = inputs_dict[self.main_input_name]
+ if isinstance(input_tensor, list):
+ input_tensor = torch.stack(input_tensor)
if expected_output_shape is None:
expected_shape = input_tensor.shape
@@ -828,11 +917,15 @@ def test_model_from_pretrained(self):
if isinstance(output_1, dict):
output_1 = output_1.to_tuple()[0]
+ if isinstance(output_1, list):
+ output_1 = torch.stack(output_1)
output_2 = new_model(**inputs_dict)
if isinstance(output_2, dict):
output_2 = output_2.to_tuple()[0]
+ if isinstance(output_2, list):
+ output_2 = torch.stack(output_2)
self.assertEqual(output_1.shape, output_2.shape)
@@ -926,8 +1019,9 @@ def recursive_check(tuple_object, dict_object):
@require_torch_accelerator_with_training
def test_enable_disable_gradient_checkpointing(self):
+ # Skip test if model does not support gradient checkpointing
if not self.model_class._supports_gradient_checkpointing:
- return # Skip test if model does not support gradient checkpointing
+ pytest.skip("Gradient checkpointing is not supported.")
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
@@ -945,8 +1039,9 @@ def test_enable_disable_gradient_checkpointing(self):
@require_torch_accelerator_with_training
def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip: set[str] = {}):
+ # Skip test if model does not support gradient checkpointing
if not self.model_class._supports_gradient_checkpointing:
- return # Skip test if model does not support gradient checkpointing
+ pytest.skip("Gradient checkpointing is not supported.")
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1003,8 +1098,9 @@ def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_
def test_gradient_checkpointing_is_applied(
self, expected_set=None, attention_head_dim=None, num_attention_heads=None, block_out_channels=None
):
+ # Skip test if model does not support gradient checkpointing
if not self.model_class._supports_gradient_checkpointing:
- return # Skip test if model does not support gradient checkpointing
+ pytest.skip("Gradient checkpointing is not supported.")
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1048,11 +1144,10 @@ def test_deprecated_kwargs(self):
" from `_deprecated_kwargs = []`"
)
- @parameterized.expand([True, False])
+ @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
@torch.no_grad()
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
- def test_save_load_lora_adapter(self, use_dora=False):
- import safetensors
+ def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False):
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
@@ -1062,14 +1157,16 @@ def test_save_load_lora_adapter(self, use_dora=False):
model = self.model_class(**init_dict).to(torch_device)
if not issubclass(model.__class__, PeftAdapterMixin):
- return
+ pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
torch.manual_seed(0)
output_no_lora = model(**inputs_dict, return_dict=False)[0]
+ if isinstance(output_no_lora, list):
+ output_no_lora = torch.stack(output_no_lora)
denoiser_lora_config = LoraConfig(
- r=4,
- lora_alpha=4,
+ r=rank,
+ lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
@@ -1079,6 +1176,8 @@ def test_save_load_lora_adapter(self, use_dora=False):
torch.manual_seed(0)
outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
+ if isinstance(outputs_with_lora, list):
+ outputs_with_lora = torch.stack(outputs_with_lora)
self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4))
@@ -1103,12 +1202,14 @@ def test_save_load_lora_adapter(self, use_dora=False):
torch.manual_seed(0)
outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
+ if isinstance(outputs_with_lora_2, list):
+ outputs_with_lora_2 = torch.stack(outputs_with_lora_2)
self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
- def test_wrong_adapter_name_raises_error(self):
+ def test_lora_wrong_adapter_name_raises_error(self):
from peft import LoraConfig
from diffusers.loaders.peft import PeftAdapterMixin
@@ -1117,7 +1218,7 @@ def test_wrong_adapter_name_raises_error(self):
model = self.model_class(**init_dict).to(torch_device)
if not issubclass(model.__class__, PeftAdapterMixin):
- return
+ pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
denoiser_lora_config = LoraConfig(
r=4,
@@ -1136,47 +1237,136 @@ def test_wrong_adapter_name_raises_error(self):
self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))
+ @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
+ @torch.no_grad()
+ @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+ def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora):
+ from peft import LoraConfig
+
+ from diffusers.loaders.peft import PeftAdapterMixin
+
+ init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict).to(torch_device)
+
+ if not issubclass(model.__class__, PeftAdapterMixin):
+ pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
+
+ denoiser_lora_config = LoraConfig(
+ r=rank,
+ lora_alpha=lora_alpha,
+ target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+ init_lora_weights=False,
+ use_dora=use_dora,
+ )
+ model.add_adapter(denoiser_lora_config)
+ metadata = model.peft_config["default"].to_dict()
+ self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ model.save_lora_adapter(tmpdir)
+ model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+ self.assertTrue(os.path.isfile(model_file))
+
+ model.unload_lora()
+ self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+ model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
+ parsed_metadata = model.peft_config["default_0"].to_dict()
+ check_if_dicts_are_equal(metadata, parsed_metadata)
+
+ @torch.no_grad()
+ @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+ def test_lora_adapter_wrong_metadata_raises_error(self):
+ from peft import LoraConfig
+
+ from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+ from diffusers.loaders.peft import PeftAdapterMixin
+
+ init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict).to(torch_device)
+
+ if not issubclass(model.__class__, PeftAdapterMixin):
+ pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
+
+ denoiser_lora_config = LoraConfig(
+ r=4,
+ lora_alpha=4,
+ target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+ init_lora_weights=False,
+ use_dora=False,
+ )
+ model.add_adapter(denoiser_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ model.save_lora_adapter(tmpdir)
+ model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+ self.assertTrue(os.path.isfile(model_file))
+
+ # Perturb the metadata in the state dict.
+ loaded_state_dict = safetensors.torch.load_file(model_file)
+ metadata = {"format": "pt"}
+ lora_adapter_metadata = denoiser_lora_config.to_dict()
+ lora_adapter_metadata.update({"foo": 1, "bar": 2})
+ for key, value in lora_adapter_metadata.items():
+ if isinstance(value, set):
+ lora_adapter_metadata[key] = list(value)
+ metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
+ safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
+
+ model.unload_lora()
+ self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+ with self.assertRaises(TypeError) as err_context:
+ model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
+ self.assertTrue("`LoraConfig` class could not be instantiated" in str(err_context.exception))
+
@require_torch_accelerator
def test_cpu_offload(self):
+ if self.model_class._no_split_modules is None:
+ pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
+
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
- if model._no_split_modules is None:
- return
-
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
+ base_normalized_output = normalize_output(base_output)
model_size = compute_module_sizes(model)[""]
- # We test several splits of sizes to make sure it works.
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
+
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)
for max_size in max_gpu_sizes:
max_memory = {0: max_size, "cpu": model_size * 2}
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+
# Making sure part of the model will actually end up offloaded
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
+ new_normalized_output = normalize_output(new_output)
- self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+ self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5))
@require_torch_accelerator
def test_disk_offload_without_safetensors(self):
+ if self.model_class._no_split_modules is None:
+ pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
- if model._no_split_modules is None:
- return
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
+ base_normalized_output = normalize_output(base_output)
model_size = compute_module_sizes(model)[""]
max_size = int(self.model_split_percents[0] * model_size)
@@ -1196,20 +1386,21 @@ def test_disk_offload_without_safetensors(self):
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
-
- self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+ new_normalized_output = normalize_output(new_output)
+ self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5))
@require_torch_accelerator
def test_disk_offload_with_safetensors(self):
+ if self.model_class._no_split_modules is None:
+ pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
- if model._no_split_modules is None:
- return
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
+ base_normalized_output = normalize_output(base_output)
model_size = compute_module_sizes(model)[""]
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1224,15 +1415,16 @@ def test_disk_offload_with_safetensors(self):
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
+ new_normalized_output = normalize_output(new_output)
- self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+ self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5))
@require_torch_multi_accelerator
def test_model_parallelism(self):
+ if self.model_class._no_split_modules is None:
+ pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
- if model._no_split_modules is None:
- return
model = model.to(torch_device)
@@ -1250,7 +1442,6 @@ def test_model_parallelism(self):
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
- print(f" new_model.hf_device_map:{new_model.hf_device_map}")
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
@@ -1267,6 +1458,7 @@ def test_sharded_checkpoints(self):
model = model.to(torch_device)
base_output = model(**inputs_dict)
+ base_normalized_output = normalize_output(base_output)
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
@@ -1288,8 +1480,9 @@ def test_sharded_checkpoints(self):
if "generator" in inputs_dict:
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
new_output = new_model(**inputs_dict)
+ new_normalized_output = normalize_output(new_output)
- self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+ self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5))
@require_torch_accelerator
def test_sharded_checkpoints_with_variant(self):
@@ -1299,6 +1492,7 @@ def test_sharded_checkpoints_with_variant(self):
model = model.to(torch_device)
base_output = model(**inputs_dict)
+ base_normalized_output = normalize_output(base_output)
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
@@ -1326,19 +1520,59 @@ def test_sharded_checkpoints_with_variant(self):
if "generator" in inputs_dict:
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
new_output = new_model(**inputs_dict)
+ new_normalized_output = normalize_output(new_output)
+
+ self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5))
+
+ @require_torch_accelerator
+ def test_sharded_checkpoints_with_parallel_loading(self):
+ torch.manual_seed(0)
+ config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**config).eval()
+ model = model.to(torch_device)
+
+ base_output = model(**inputs_dict)
+ base_normalized_output = normalize_output(base_output)
+
+ model_size = compute_module_persistent_sizes(model)[""]
+ max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
+
+ # Now check if the right number of shards exists. First, let's get the number of shards.
+ # Since this number can be dependent on the model being tested, it's important that we calculate it
+ # instead of hardcoding it.
+ expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
+ actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
+ self.assertTrue(actual_num_shards == expected_num_shards)
+
+ # Load with parallel loading
+ os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes"
+ new_model = self.model_class.from_pretrained(tmp_dir).eval()
+ new_model = new_model.to(torch_device)
+
+ torch.manual_seed(0)
+ if "generator" in inputs_dict:
+ _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ new_output = new_model(**inputs_dict)
+ new_normalized_output = normalize_output(new_output)
- self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+ self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5))
+ # set to no.
+ os.environ["HF_ENABLE_PARALLEL_LOADING"] = "no"
@require_torch_accelerator
def test_sharded_checkpoints_device_map(self):
+ if self.model_class._no_split_modules is None:
+ pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
- if model._no_split_modules is None:
- return
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
+ base_normalized_output = normalize_output(base_output)
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
@@ -1359,7 +1593,9 @@ def test_sharded_checkpoints_device_map(self):
if "generator" in inputs_dict:
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
new_output = new_model(**inputs_dict)
- self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+ new_normalized_output = normalize_output(new_output)
+
+ self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5))
# This test is okay without a GPU because we're not running any execution. We're just serializing
# and check if the resultant files are following an expected format.
@@ -1402,7 +1638,7 @@ def test_variant_sharded_ckpt_right_format(self):
def test_layerwise_casting_training(self):
def test_fn(storage_dtype, compute_dtype):
if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16:
- return
+ pytest.skip("Skipping test because CPU doesn't go well with bfloat16.")
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
@@ -1429,21 +1665,26 @@ def test_fn(storage_dtype, compute_dtype):
test_fn(torch.float8_e5m2, torch.float32)
test_fn(torch.float8_e4m3fn, torch.bfloat16)
+ @torch.no_grad()
def test_layerwise_casting_inference(self):
- from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
+ from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
+ from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN
torch.manual_seed(0)
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
- model = self.model_class(**config).eval()
- model = model.to(torch_device)
- base_slice = model(**inputs_dict)[0].flatten().detach().cpu().numpy()
+ model = self.model_class(**config)
+ model.eval()
+ model.to(torch_device)
+ base_slice = model(**inputs_dict)[0]
+ base_slice = normalize_output(base_slice)
+ base_slice = base_slice.detach().flatten().cpu().numpy()
def check_linear_dtype(module, storage_dtype, compute_dtype):
patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
if getattr(module, "_skip_layerwise_casting_patterns", None) is not None:
patterns_to_check += tuple(module._skip_layerwise_casting_patterns)
for name, submodule in module.named_modules():
- if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
+ if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
continue
dtype_to_check = storage_dtype
if any(re.search(pattern, name) for pattern in patterns_to_check):
@@ -1462,7 +1703,9 @@ def test_layerwise_casting(storage_dtype, compute_dtype):
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
check_linear_dtype(model, storage_dtype, compute_dtype)
- output = model(**inputs_dict)[0].float().flatten().detach().cpu().numpy()
+ output = model(**inputs_dict)[0]
+ output = normalize_output(output)
+ output = output.float().flatten().detach().cpu().numpy()
# The precision test is not very important for fast tests. In most cases, the outputs will not be the same.
# We just want to make sure that the layerwise casting is working as expected.
@@ -1473,16 +1716,17 @@ def test_layerwise_casting(storage_dtype, compute_dtype):
test_layerwise_casting(torch.float8_e5m2, torch.float32)
test_layerwise_casting(torch.float8_e4m3fn, torch.bfloat16)
- @require_torch_gpu
+ @require_torch_accelerator
+ @torch.no_grad()
def test_layerwise_casting_memory(self):
MB_TOLERANCE = 0.2
LEAST_COMPUTE_CAPABILITY = 8.0
def reset_memory_stats():
gc.collect()
- torch.cuda.synchronize()
- torch.cuda.empty_cache()
- torch.cuda.reset_peak_memory_stats()
+ backend_synchronize(torch_device)
+ backend_empty_cache(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
def get_memory_usage(storage_dtype, compute_dtype):
torch.manual_seed(0)
@@ -1495,7 +1739,7 @@ def get_memory_usage(storage_dtype, compute_dtype):
reset_memory_stats()
model(**inputs_dict)
model_memory_footprint = model.get_memory_footprint()
- peak_inference_memory_allocated_mb = torch.cuda.max_memory_allocated() / 1024**2
+ peak_inference_memory_allocated_mb = backend_max_memory_allocated(torch_device) / 1024**2
return model_memory_footprint, peak_inference_memory_allocated_mb
@@ -1505,7 +1749,7 @@ def get_memory_usage(storage_dtype, compute_dtype):
torch.float8_e4m3fn, torch.bfloat16
)
- compute_capability = get_torch_cuda_device_capability()
+ compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint)
# NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
# On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
@@ -1519,8 +1763,18 @@ def get_memory_usage(storage_dtype, compute_dtype):
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
)
- @require_torch_gpu
- def test_group_offloading(self):
+ @parameterized.expand([False, True])
+ @require_torch_accelerator
+ def test_group_offloading(self, record_stream):
+ for cls in inspect.getmro(self.__class__):
+ if "test_group_offloading" in cls.__dict__ and cls is not ModelTesterMixin:
+ # Skip this test if it is overwritten by child class. We need to do this because parameterized
+ # materializes the test methods on invocation which cannot be overridden.
+ pytest.skip("Model does not support group offloading.")
+
+ if not self.model_class._supports_group_offloading:
+ pytest.skip("Model does not support group offloading.")
+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
@@ -1537,37 +1791,233 @@ def run_forward(model):
return model(**inputs_dict)[0]
model = self.model_class(**init_dict)
- if not getattr(model, "_supports_group_offloading", True):
- return
-
model.to(torch_device)
output_without_group_offloading = run_forward(model)
+ output_without_group_offloading = normalize_output(output_without_group_offloading)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
output_with_group_offloading1 = run_forward(model)
+ output_with_group_offloading1 = normalize_output(output_with_group_offloading1)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True)
output_with_group_offloading2 = run_forward(model)
+ output_with_group_offloading2 = normalize_output(output_with_group_offloading2)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="leaf_level")
output_with_group_offloading3 = run_forward(model)
+ output_with_group_offloading3 = normalize_output(output_with_group_offloading3)
torch.manual_seed(0)
model = self.model_class(**init_dict)
- model.enable_group_offload(torch_device, offload_type="leaf_level", use_stream=True)
+ model.enable_group_offload(
+ torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream
+ )
output_with_group_offloading4 = run_forward(model)
+ output_with_group_offloading4 = normalize_output(output_with_group_offloading4)
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
+ @parameterized.expand([(False, "block_level"), (True, "leaf_level")])
+ @require_torch_accelerator
+ @torch.no_grad()
+ def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
+ if not self.model_class._supports_group_offloading:
+ pytest.skip("Model does not support group offloading.")
+
+ torch.manual_seed(0)
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+
+ model.to(torch_device)
+ model.eval()
+ _ = model(**inputs_dict)[0]
+
+ torch.manual_seed(0)
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ storage_dtype, compute_dtype = torch.float16, torch.float32
+ inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
+ model = self.model_class(**init_dict)
+ model.eval()
+ additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
+ model.enable_group_offload(
+ torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs
+ )
+ model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
+ _ = model(**inputs_dict)[0]
+
+ @parameterized.expand([("block_level", False), ("leaf_level", True)])
+ @require_torch_accelerator
+ @torch.no_grad()
+ @torch.inference_mode()
+ def test_group_offloading_with_disk(self, offload_type, record_stream, atol=1e-5):
+ for cls in inspect.getmro(self.__class__):
+ if "test_group_offloading_with_disk" in cls.__dict__ and cls is not ModelTesterMixin:
+ # Skip this test if it is overwritten by child class. We need to do this because parameterized
+ # materializes the test methods on invocation which cannot be overridden.
+ pytest.skip("Model does not support group offloading with disk yet.")
+
+ if not self.model_class._supports_group_offloading:
+ pytest.skip("Model does not support group offloading.")
+
+ def _has_generator_arg(model):
+ sig = inspect.signature(model.forward)
+ params = sig.parameters
+ return "generator" in params
+
+ def _run_forward(model, inputs_dict):
+ accepts_generator = _has_generator_arg(model)
+ if accepts_generator:
+ inputs_dict["generator"] = torch.manual_seed(0)
+ torch.manual_seed(0)
+ return model(**inputs_dict)[0]
+
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict)
+
+ model.eval()
+ model.to(torch_device)
+ output_without_group_offloading = _run_forward(model, inputs_dict)
+ output_without_group_offloading = normalize_output(output_without_group_offloading)
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict)
+ model.eval()
+
+ num_blocks_per_group = None if offload_type == "leaf_level" else 1
+ additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group}
+ with tempfile.TemporaryDirectory() as tmpdir:
+ model.enable_group_offload(
+ torch_device,
+ offload_type=offload_type,
+ offload_to_disk_path=tmpdir,
+ use_stream=True,
+ record_stream=record_stream,
+ **additional_kwargs,
+ )
+ has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
+ self.assertTrue(has_safetensors, "No safetensors found in the directory.")
+
+ # For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
+ # in nature. So, skip it.
+ if offload_type != "leaf_level":
+ is_correct, extra_files, missing_files = _check_safetensors_serialization(
+ module=model,
+ offload_to_disk_path=tmpdir,
+ offload_type=offload_type,
+ num_blocks_per_group=num_blocks_per_group,
+ block_modules=model._group_offload_block_modules
+ if hasattr(model, "_group_offload_block_modules")
+ else None,
+ )
+ if not is_correct:
+ if extra_files:
+ raise ValueError(f"Found extra files: {', '.join(extra_files)}")
+ elif missing_files:
+ raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
+
+ output_with_group_offloading = _run_forward(model, inputs_dict)
+ output_with_group_offloading = normalize_output(output_with_group_offloading)
+ self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol))
+
+ def test_auto_model(self, expected_max_diff=5e-5):
+ if self.forward_requires_fresh_args:
+ model = self.model_class(**self.init_dict)
+ else:
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+
+ model = model.eval()
+ model = model.to(torch_device)
+
+ if hasattr(model, "set_default_attn_processor"):
+ model.set_default_attn_processor()
+
+ with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
+ model.save_pretrained(tmpdirname, safe_serialization=False)
+
+ auto_model = AutoModel.from_pretrained(tmpdirname)
+ if hasattr(auto_model, "set_default_attn_processor"):
+ auto_model.set_default_attn_processor()
+
+ auto_model = auto_model.eval()
+ auto_model = auto_model.to(torch_device)
+
+ with torch.no_grad():
+ if self.forward_requires_fresh_args:
+ output_original = model(**self.inputs_dict(0))
+ output_auto = auto_model(**self.inputs_dict(0))
+ else:
+ output_original = model(**inputs_dict)
+ output_auto = auto_model(**inputs_dict)
+
+ if isinstance(output_original, dict):
+ output_original = output_original.to_tuple()[0]
+ if isinstance(output_auto, dict):
+ output_auto = output_auto.to_tuple()[0]
+
+ if isinstance(output_original, list):
+ output_original = torch.stack(output_original)
+ if isinstance(output_auto, list):
+ output_auto = torch.stack(output_auto)
+
+ output_original, output_auto = output_original.float(), output_auto.float()
+
+ max_diff = (output_original - output_auto).abs().max().item()
+ self.assertLessEqual(
+ max_diff,
+ expected_max_diff,
+ f"AutoModel forward pass diff: {max_diff} exceeds threshold {expected_max_diff}",
+ )
+
+ @parameterized.expand(
+ [
+ (-1, "You can't pass device_map as a negative int"),
+ ("foo", "When passing device_map as a string, the value needs to be a device name"),
+ ]
+ )
+ def test_wrong_device_map_raises_error(self, device_map, msg_substring):
+ init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+ with tempfile.TemporaryDirectory() as tmpdir:
+ model.save_pretrained(tmpdir)
+ with self.assertRaises(ValueError) as err_ctx:
+ _ = self.model_class.from_pretrained(tmpdir, device_map=device_map)
+
+ assert msg_substring in str(err_ctx.exception)
+
+ @parameterized.expand([0, torch_device, torch.device(torch_device)])
+ @require_torch_accelerator
+ def test_passing_non_dict_device_map_works(self, device_map):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict).eval()
+ with tempfile.TemporaryDirectory() as tmpdir:
+ model.save_pretrained(tmpdir)
+ loaded_model = self.model_class.from_pretrained(tmpdir, device_map=device_map)
+ _ = loaded_model(**inputs_dict)
+
+ @parameterized.expand([("", torch_device), ("", torch.device(torch_device))])
+ @require_torch_accelerator
+ def test_passing_dict_device_map_works(self, name, device):
+ # There are other valid dict-based `device_map` values too. It's best to refer to
+ # the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap.
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict).eval()
+ device_map = {name: device}
+ with tempfile.TemporaryDirectory() as tmpdir:
+ model.save_pretrained(tmpdir)
+ loaded_model = self.model_class.from_pretrained(tmpdir, device_map=device_map)
+ _ = loaded_model(**inputs_dict)
+
@is_staging_test
class ModelPushToHubTester(unittest.TestCase):
@@ -1659,3 +2109,403 @@ def test_push_to_hub_library_name(self):
# Reset repo
delete_repo(self.repo_id, token=TOKEN)
+
+
+@require_torch_accelerator
+@require_torch_2
+@is_torch_compile
+@slow
+@require_torch_version_greater("2.7.1")
+class TorchCompileTesterMixin:
+ different_shapes_for_compilation = None
+
+ def setUp(self):
+ # clean up the VRAM before each test
+ super().setUp()
+ torch.compiler.reset()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ # clean up the VRAM after each test in case of CUDA runtime errors
+ super().tearDown()
+ torch.compiler.reset()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def test_torch_compile_recompilation_and_graph_break(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ model = self.model_class(**init_dict).to(torch_device)
+ model.eval()
+ model = torch.compile(model, fullgraph=True)
+
+ with (
+ torch._inductor.utils.fresh_inductor_cache(),
+ torch._dynamo.config.patch(error_on_recompile=True),
+ torch.no_grad(),
+ ):
+ _ = model(**inputs_dict)
+ _ = model(**inputs_dict)
+
+ def test_torch_compile_repeated_blocks(self):
+ if self.model_class._repeated_blocks is None:
+ pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.")
+
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ model = self.model_class(**init_dict).to(torch_device)
+ model.eval()
+ model.compile_repeated_blocks(fullgraph=True)
+
+ recompile_limit = 1
+ if self.model_class.__name__ == "UNet2DConditionModel":
+ recompile_limit = 2
+ elif self.model_class.__name__ == "ZImageTransformer2DModel":
+ recompile_limit = 3
+
+ with (
+ torch._inductor.utils.fresh_inductor_cache(),
+ torch._dynamo.config.patch(recompile_limit=recompile_limit),
+ torch.no_grad(),
+ ):
+ _ = model(**inputs_dict)
+ _ = model(**inputs_dict)
+
+ def test_compile_with_group_offloading(self):
+ if not self.model_class._supports_group_offloading:
+ pytest.skip("Model does not support group offloading.")
+
+ torch._dynamo.config.cache_size_limit = 10000
+
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+ model.eval()
+ # TODO: Can test for other group offloading kwargs later if needed.
+ group_offload_kwargs = {
+ "onload_device": torch_device,
+ "offload_device": "cpu",
+ "offload_type": "block_level",
+ "num_blocks_per_group": 1,
+ "use_stream": True,
+ "non_blocking": True,
+ }
+ model.enable_group_offload(**group_offload_kwargs)
+ model.compile()
+
+ with torch.no_grad():
+ _ = model(**inputs_dict)
+ _ = model(**inputs_dict)
+
+ def test_compile_on_different_shapes(self):
+ if self.different_shapes_for_compilation is None:
+ pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
+ torch.fx.experimental._config.use_duck_shape = False
+
+ init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict).to(torch_device)
+ model.eval()
+ model = torch.compile(model, fullgraph=True, dynamic=True)
+
+ for height, width in self.different_shapes_for_compilation:
+ with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
+ inputs_dict = self.prepare_dummy_input(height=height, width=width)
+ _ = model(**inputs_dict)
+
+ def test_compile_works_with_aot(self):
+ from torch._inductor.package import load_package
+
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ model = self.model_class(**init_dict).to(torch_device)
+ exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ package_path = os.path.join(tmpdir, f"{self.model_class.__name__}.pt2")
+ _ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
+ assert os.path.exists(package_path)
+ loaded_binary = load_package(package_path, run_single_threaded=True)
+
+ model.forward = loaded_binary
+
+ with torch.no_grad():
+ _ = model(**inputs_dict)
+ _ = model(**inputs_dict)
+
+
+@slow
+@require_torch_2
+@require_torch_accelerator
+@require_peft_backend
+@require_peft_version_greater("0.14.0")
+@require_torch_version_greater("2.7.1")
+@is_torch_compile
+class LoraHotSwappingForModelTesterMixin:
+ """Test that hotswapping does not result in recompilation on the model directly.
+
+ We're not extensively testing the hotswapping functionality since it is implemented in PEFT and is extensively
+ tested there. The goal of this test is specifically to ensure that hotswapping with diffusers does not require
+ recompilation.
+
+ See
+ https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252
+ for the analogous PEFT test.
+
+ """
+
+ different_shapes_for_compilation = None
+
+ def tearDown(self):
+ # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
+ # there will be recompilation errors, as torch caches the model when run in the same process.
+ super().tearDown()
+ torch.compiler.reset()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def get_lora_config(self, lora_rank, lora_alpha, target_modules):
+ from peft import LoraConfig
+
+ lora_config = LoraConfig(
+ r=lora_rank,
+ lora_alpha=lora_alpha,
+ target_modules=target_modules,
+ init_lora_weights=False,
+ use_dora=False,
+ )
+ return lora_config
+
+ def get_linear_module_name_other_than_attn(self, model):
+ linear_names = [
+ name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name
+ ]
+ return linear_names[0]
+
+ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
+ """
+ Check that hotswapping works on a small unet.
+
+ Steps:
+ - create 2 LoRA adapters and save them
+ - load the first adapter
+ - hotswap the second adapter
+ - check that the outputs are correct
+ - optionally compile the model
+ - optionally check if recompilations happen on different shapes
+
+ Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
+ fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
+ fine.
+ """
+ different_shapes = self.different_shapes_for_compilation
+ # create 2 adapters with different ranks and alphas
+ torch.manual_seed(0)
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict).to(torch_device)
+
+ alpha0, alpha1 = rank0, rank1
+ max_rank = max([rank0, rank1])
+ if target_modules1 is None:
+ target_modules1 = target_modules0[:]
+ lora_config0 = self.get_lora_config(rank0, alpha0, target_modules0)
+ lora_config1 = self.get_lora_config(rank1, alpha1, target_modules1)
+
+ model.add_adapter(lora_config0, adapter_name="adapter0")
+ with torch.inference_mode():
+ torch.manual_seed(0)
+ output0_before = model(**inputs_dict)["sample"]
+
+ model.add_adapter(lora_config1, adapter_name="adapter1")
+ model.set_adapter("adapter1")
+ with torch.inference_mode():
+ torch.manual_seed(0)
+ output1_before = model(**inputs_dict)["sample"]
+
+ # sanity checks:
+ tol = 5e-3
+ assert not torch.allclose(output0_before, output1_before, atol=tol, rtol=tol)
+ assert not (output0_before == 0).all()
+ assert not (output1_before == 0).all()
+
+ with tempfile.TemporaryDirectory() as tmp_dirname:
+ # save the adapter checkpoints
+ model.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
+ model.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
+ del model
+
+ # load the first adapter
+ torch.manual_seed(0)
+ init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict).to(torch_device)
+
+ if do_compile or (rank0 != rank1):
+ # no need to prepare if the model is not compiled or if the ranks are identical
+ model.enable_lora_hotswap(target_rank=max_rank)
+
+ file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors")
+ file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors")
+ model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
+
+ if do_compile:
+ model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)
+
+ with torch.inference_mode():
+ # additionally check if dynamic compilation works.
+ if different_shapes is not None:
+ for height, width in different_shapes:
+ new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
+ _ = model(**new_inputs_dict)
+ else:
+ output0_after = model(**inputs_dict)["sample"]
+ assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
+
+ # hotswap the 2nd adapter
+ model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
+
+ # we need to call forward to potentially trigger recompilation
+ with torch.inference_mode():
+ if different_shapes is not None:
+ for height, width in different_shapes:
+ new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
+ _ = model(**new_inputs_dict)
+ else:
+ output1_after = model(**inputs_dict)["sample"]
+ assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
+
+ # check error when not passing valid adapter name
+ name = "does-not-exist"
+ msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
+ with self.assertRaisesRegex(ValueError, msg):
+ model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
+
+ @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
+ def test_hotswapping_model(self, rank0, rank1):
+ self.check_model_hotswap(
+ do_compile=False, rank0=rank0, rank1=rank1, target_modules0=["to_q", "to_k", "to_v", "to_out.0"]
+ )
+
+ @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
+ def test_hotswapping_compiled_model_linear(self, rank0, rank1):
+ # It's important to add this context to raise an error on recompilation
+ target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
+ with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
+ self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
+
+ @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
+ def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
+ if "unet" not in self.model_class.__name__.lower():
+ pytest.skip("Test only applies to UNet.")
+
+ # It's important to add this context to raise an error on recompilation
+ target_modules = ["conv", "conv1", "conv2"]
+ with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
+ self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
+
+ @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
+ def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
+ if "unet" not in self.model_class.__name__.lower():
+ pytest.skip("Test only applies to UNet.")
+
+ # It's important to add this context to raise an error on recompilation
+ target_modules = ["to_q", "conv"]
+ with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
+ self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
+
+ @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
+ def test_hotswapping_compiled_model_both_linear_and_other(self, rank0, rank1):
+ # In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping
+ # with `torch.compile()` for models that have both linear and conv layers. In this test, we check
+ # if we can target a linear layer from the transformer blocks and another linear layer from non-attention
+ # block.
+ target_modules = ["to_q"]
+ init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+
+ target_modules.append(self.get_linear_module_name_other_than_attn(model))
+ del model
+
+ # It's important to add this context to raise an error on recompilation
+ with torch._dynamo.config.patch(error_on_recompile=True):
+ self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
+
+ def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
+ # ensure that enable_lora_hotswap is called before loading the first adapter
+ lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict).to(torch_device)
+ model.add_adapter(lora_config)
+
+ msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
+ with self.assertRaisesRegex(RuntimeError, msg):
+ model.enable_lora_hotswap(target_rank=32)
+
+ def test_enable_lora_hotswap_called_after_adapter_added_warning(self):
+ # ensure that enable_lora_hotswap is called before loading the first adapter
+ from diffusers.loaders.peft import logger
+
+ lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict).to(torch_device)
+ model.add_adapter(lora_config)
+ msg = (
+ "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
+ )
+ with self.assertLogs(logger=logger, level="WARNING") as cm:
+ model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
+ assert any(msg in log for log in cm.output)
+
+ def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
+ # check possibility to ignore the error/warning
+ from diffusers.loaders.peft import logger
+
+ lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict).to(torch_device)
+ model.add_adapter(lora_config)
+ # note: assertNoLogs requires Python 3.10+
+ with self.assertNoLogs(logger, level="WARNING"):
+ model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
+
+ def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
+ # check that wrong argument value raises an error
+ lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict).to(torch_device)
+ model.add_adapter(lora_config)
+ msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
+ with self.assertRaisesRegex(ValueError, msg):
+ model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
+
+ def test_hotswap_second_adapter_targets_more_layers_raises(self):
+ # check the error and log
+ from diffusers.loaders.peft import logger
+
+ # at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
+ target_modules0 = ["to_q"]
+ target_modules1 = ["to_q", "to_k"]
+ with self.assertRaises(RuntimeError): # peft raises RuntimeError
+ with self.assertLogs(logger=logger, level="ERROR") as cm:
+ self.check_model_hotswap(
+ do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1
+ )
+ assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output)
+
+ @parameterized.expand([(11, 11), (7, 13), (13, 7)])
+ @require_torch_version_greater("2.7.1")
+ def test_hotswapping_compile_on_different_shapes(self, rank0, rank1):
+ different_shapes_for_compilation = self.different_shapes_for_compilation
+ if different_shapes_for_compilation is None:
+ pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
+ # Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic
+ # variable to represent input sizes that are the same. For more details,
+ # check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
+ torch.fx.experimental._config.use_duck_shape = False
+
+ target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
+ with torch._dynamo.config.patch(error_on_recompile=True):
+ self.check_model_hotswap(
+ do_compile=True,
+ rank0=rank0,
+ rank1=rank1,
+ target_modules0=target_modules,
+ )
diff --git a/tests/models/test_modeling_common_flax.py b/tests/models/test_modeling_common_flax.py
deleted file mode 100644
index 8945aed7c93f..000000000000
--- a/tests/models/test_modeling_common_flax.py
+++ /dev/null
@@ -1,66 +0,0 @@
-import inspect
-
-from diffusers.utils import is_flax_available
-from diffusers.utils.testing_utils import require_flax
-
-
-if is_flax_available():
- import jax
-
-
-@require_flax
-class FlaxModelTesterMixin:
- def test_output(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- model = self.model_class(**init_dict)
- variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
- jax.lax.stop_gradient(variables)
-
- output = model.apply(variables, inputs_dict["sample"])
-
- if isinstance(output, dict):
- output = output.sample
-
- self.assertIsNotNone(output)
- expected_shape = inputs_dict["sample"].shape
- self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
-
- def test_forward_with_norm_groups(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- init_dict["norm_num_groups"] = 16
- init_dict["block_out_channels"] = (16, 32)
-
- model = self.model_class(**init_dict)
- variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
- jax.lax.stop_gradient(variables)
-
- output = model.apply(variables, inputs_dict["sample"])
-
- if isinstance(output, dict):
- output = output.sample
-
- self.assertIsNotNone(output)
- expected_shape = inputs_dict["sample"].shape
- self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
-
- def test_deprecated_kwargs(self):
- has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
- has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
-
- if has_kwarg_in_model_class and not has_deprecated_kwarg:
- raise ValueError(
- f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
- " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
- " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
- " []`"
- )
-
- if not has_kwarg_in_model_class and has_deprecated_kwarg:
- raise ValueError(
- f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
- " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
- f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
- " from `_deprecated_kwargs = []`"
- )
diff --git a/tests/models/test_models_auto.py b/tests/models/test_models_auto.py
new file mode 100644
index 000000000000..a70754343f30
--- /dev/null
+++ b/tests/models/test_models_auto.py
@@ -0,0 +1,32 @@
+import unittest
+from unittest.mock import patch
+
+from transformers import CLIPTextModel, LongformerModel
+
+from diffusers.models import AutoModel, UNet2DConditionModel
+
+
+class TestAutoModel(unittest.TestCase):
+ @patch(
+ "diffusers.models.AutoModel.load_config",
+ side_effect=[EnvironmentError("File not found"), {"_class_name": "UNet2DConditionModel"}],
+ )
+ def test_load_from_config_diffusers_with_subfolder(self, mock_load_config):
+ model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet")
+ assert isinstance(model, UNet2DConditionModel)
+
+ @patch(
+ "diffusers.models.AutoModel.load_config",
+ side_effect=[EnvironmentError("File not found"), {"model_type": "clip_text_model"}],
+ )
+ def test_load_from_config_transformers_with_subfolder(self, mock_load_config):
+ model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")
+ assert isinstance(model, CLIPTextModel)
+
+ def test_load_from_config_without_subfolder(self):
+ model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-longformer")
+ assert isinstance(model, LongformerModel)
+
+ def test_load_from_model_index(self):
+ model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")
+ assert isinstance(model, CLIPTextModel)
diff --git a/tests/models/transformers/test_models_dit_transformer2d.py b/tests/models/transformers/test_models_dit_transformer2d.py
index 5f4a2f587e92..473a87637578 100644
--- a/tests/models/transformers/test_models_dit_transformer2d.py
+++ b/tests/models/transformers/test_models_dit_transformer2d.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,13 +18,13 @@
import torch
from diffusers import DiTTransformer2DModel, Transformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
slow,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_pixart_transformer2d.py b/tests/models/transformers/test_models_pixart_transformer2d.py
index a544a3fc4607..17c400cf1911 100644
--- a/tests/models/transformers/test_models_pixart_transformer2d.py
+++ b/tests/models/transformers/test_models_pixart_transformer2d.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,13 +18,13 @@
import torch
from diffusers import PixArtTransformer2DModel, Transformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
slow,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_prior.py b/tests/models/transformers/test_models_prior.py
index 471c1084c00c..af5ac4bbbd76 100644
--- a/tests/models/transformers/test_models_prior.py
+++ b/tests/models/transformers/test_models_prior.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,7 +21,8 @@
from parameterized import parameterized
from diffusers import PriorTransformer
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -29,7 +30,6 @@
torch_all_close,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_allegro.py b/tests/models/transformers/test_models_transformer_allegro.py
index 3479803da61d..7c002f87819e 100644
--- a/tests/models/transformers/test_models_transformer_allegro.py
+++ b/tests/models/transformers/test_models_transformer_allegro.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,11 +17,11 @@
import torch
from diffusers import AllegroTransformer3DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_aura_flow.py b/tests/models/transformers/test_models_transformer_aura_flow.py
index d1ff7d2c96d3..ae8c3b7234a3 100644
--- a/tests/models/transformers/test_models_transformer_aura_flow.py
+++ b/tests/models/transformers/test_models_transformer_aura_flow.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,8 +18,8 @@
import torch
from diffusers import AuraFlowTransformer2DModel
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_bria.py b/tests/models/transformers/test_models_transformer_bria.py
new file mode 100644
index 000000000000..9056590edffe
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_bria.py
@@ -0,0 +1,181 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import BriaTransformer2DModel
+from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
+from diffusers.models.embeddings import ImageProjection
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
+
+
+enable_full_determinism()
+
+
+def create_bria_ip_adapter_state_dict(model):
+ # "ip_adapter" (cross-attention weights)
+ ip_cross_attn_state_dict = {}
+ key_id = 0
+
+ for name in model.attn_processors.keys():
+ if name.startswith("single_transformer_blocks"):
+ continue
+
+ joint_attention_dim = model.config["joint_attention_dim"]
+ hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
+ sd = FluxIPAdapterJointAttnProcessor2_0(
+ hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
+ ).state_dict()
+ ip_cross_attn_state_dict.update(
+ {
+ f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
+ f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
+ f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
+ f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
+ }
+ )
+
+ key_id += 1
+
+ # "image_proj" (ImageProjection layer weights)
+
+ image_projection = ImageProjection(
+ cross_attention_dim=model.config["joint_attention_dim"],
+ image_embed_dim=model.config["pooled_projection_dim"],
+ num_image_text_embeds=4,
+ )
+
+ ip_image_projection_state_dict = {}
+ sd = image_projection.state_dict()
+ ip_image_projection_state_dict.update(
+ {
+ "proj.weight": sd["image_embeds.weight"],
+ "proj.bias": sd["image_embeds.bias"],
+ "norm.weight": sd["norm.weight"],
+ "norm.bias": sd["norm.bias"],
+ }
+ )
+
+ del sd
+ ip_state_dict = {}
+ ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
+ return ip_state_dict
+
+
+class BriaTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = BriaTransformer2DModel
+ main_input_name = "hidden_states"
+ # We override the items here because the transformer under consideration is small.
+ model_split_percents = [0.8, 0.7, 0.7]
+
+ # Skip setting testing with default: AttnProcessor
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_latent_channels = 4
+ num_image_channels = 3
+ height = width = 4
+ sequence_length = 48
+ embedding_dim = 32
+
+ hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
+ image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
+ timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "img_ids": image_ids,
+ "txt_ids": text_ids,
+ "timestep": timestep,
+ }
+
+ @property
+ def input_shape(self):
+ return (16, 4)
+
+ @property
+ def output_shape(self):
+ return (16, 4)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "patch_size": 1,
+ "in_channels": 4,
+ "num_layers": 1,
+ "num_single_layers": 1,
+ "attention_head_dim": 8,
+ "num_attention_heads": 2,
+ "joint_attention_dim": 32,
+ "pooled_projection_dim": None,
+ "axes_dims_rope": [0, 4, 4],
+ }
+
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_deprecated_inputs_img_txt_ids_3d(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ output_1 = model(**inputs_dict).to_tuple()[0]
+
+ # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
+ text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
+ image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
+
+ assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
+ assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"
+
+ inputs_dict["txt_ids"] = text_ids_3d
+ inputs_dict["img_ids"] = image_ids_3d
+
+ with torch.no_grad():
+ output_2 = model(**inputs_dict).to_tuple()[0]
+
+ self.assertEqual(output_1.shape, output_2.shape)
+ self.assertTrue(
+ torch.allclose(output_1, output_2, atol=1e-5),
+ msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
+ )
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"BriaTransformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class BriaTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = BriaTransformer2DModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return BriaTransformerTests().prepare_init_args_and_inputs_for_common()
+
+
+class BriaTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
+ model_class = BriaTransformer2DModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return BriaTransformerTests().prepare_init_args_and_inputs_for_common()
diff --git a/tests/models/transformers/test_models_transformer_bria_fibo.py b/tests/models/transformers/test_models_transformer_bria_fibo.py
new file mode 100644
index 000000000000..f859f4608bd5
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_bria_fibo.py
@@ -0,0 +1,89 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import BriaFiboTransformer2DModel
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class BriaFiboTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = BriaFiboTransformer2DModel
+ main_input_name = "hidden_states"
+ # We override the items here because the transformer under consideration is small.
+ model_split_percents = [0.8, 0.7, 0.7]
+
+ # Skip setting testing with default: AttnProcessor
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_latent_channels = 48
+ num_image_channels = 3
+ height = width = 16
+ sequence_length = 32
+ embedding_dim = 64
+
+ hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
+ image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
+ timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "img_ids": image_ids,
+ "txt_ids": text_ids,
+ "timestep": timestep,
+ "text_encoder_layers": [encoder_hidden_states[:, :, :32], encoder_hidden_states[:, :, :32]],
+ }
+
+ @property
+ def input_shape(self):
+ return (16, 16)
+
+ @property
+ def output_shape(self):
+ return (256, 48)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "patch_size": 1,
+ "in_channels": 48,
+ "num_layers": 1,
+ "num_single_layers": 1,
+ "attention_head_dim": 8,
+ "num_attention_heads": 2,
+ "joint_attention_dim": 64,
+ "text_encoder_dim": 32,
+ "pooled_projection_dim": None,
+ "axes_dims_rope": [0, 4, 4],
+ }
+
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"BriaFiboTransformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/transformers/test_models_transformer_chroma.py b/tests/models/transformers/test_models_transformer_chroma.py
new file mode 100644
index 000000000000..92ac8198ed06
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_chroma.py
@@ -0,0 +1,183 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import ChromaTransformer2DModel
+from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
+from diffusers.models.embeddings import ImageProjection
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
+
+
+enable_full_determinism()
+
+
+def create_chroma_ip_adapter_state_dict(model):
+ # "ip_adapter" (cross-attention weights)
+ ip_cross_attn_state_dict = {}
+ key_id = 0
+
+ for name in model.attn_processors.keys():
+ if name.startswith("single_transformer_blocks"):
+ continue
+
+ joint_attention_dim = model.config["joint_attention_dim"]
+ hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
+ sd = FluxIPAdapterJointAttnProcessor2_0(
+ hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
+ ).state_dict()
+ ip_cross_attn_state_dict.update(
+ {
+ f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
+ f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
+ f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
+ f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
+ }
+ )
+
+ key_id += 1
+
+ # "image_proj" (ImageProjection layer weights)
+
+ image_projection = ImageProjection(
+ cross_attention_dim=model.config["joint_attention_dim"],
+ image_embed_dim=model.config["pooled_projection_dim"],
+ num_image_text_embeds=4,
+ )
+
+ ip_image_projection_state_dict = {}
+ sd = image_projection.state_dict()
+ ip_image_projection_state_dict.update(
+ {
+ "proj.weight": sd["image_embeds.weight"],
+ "proj.bias": sd["image_embeds.bias"],
+ "norm.weight": sd["norm.weight"],
+ "norm.bias": sd["norm.bias"],
+ }
+ )
+
+ del sd
+ ip_state_dict = {}
+ ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
+ return ip_state_dict
+
+
+class ChromaTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = ChromaTransformer2DModel
+ main_input_name = "hidden_states"
+ # We override the items here because the transformer under consideration is small.
+ model_split_percents = [0.8, 0.7, 0.7]
+
+ # Skip setting testing with default: AttnProcessor
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_latent_channels = 4
+ num_image_channels = 3
+ height = width = 4
+ sequence_length = 48
+ embedding_dim = 32
+
+ hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
+ image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
+ timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "img_ids": image_ids,
+ "txt_ids": text_ids,
+ "timestep": timestep,
+ }
+
+ @property
+ def input_shape(self):
+ return (16, 4)
+
+ @property
+ def output_shape(self):
+ return (16, 4)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "patch_size": 1,
+ "in_channels": 4,
+ "num_layers": 1,
+ "num_single_layers": 1,
+ "attention_head_dim": 16,
+ "num_attention_heads": 2,
+ "joint_attention_dim": 32,
+ "axes_dims_rope": [4, 4, 8],
+ "approximator_num_channels": 8,
+ "approximator_hidden_dim": 16,
+ "approximator_layers": 1,
+ }
+
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_deprecated_inputs_img_txt_ids_3d(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ output_1 = model(**inputs_dict).to_tuple()[0]
+
+ # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
+ text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
+ image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
+
+ assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
+ assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"
+
+ inputs_dict["txt_ids"] = text_ids_3d
+ inputs_dict["img_ids"] = image_ids_3d
+
+ with torch.no_grad():
+ output_2 = model(**inputs_dict).to_tuple()[0]
+
+ self.assertEqual(output_1.shape, output_2.shape)
+ self.assertTrue(
+ torch.allclose(output_1, output_2, atol=1e-5),
+ msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
+ )
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"ChromaTransformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class ChromaTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = ChromaTransformer2DModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return ChromaTransformerTests().prepare_init_args_and_inputs_for_common()
+
+
+class ChromaTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
+ model_class = ChromaTransformer2DModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return ChromaTransformerTests().prepare_init_args_and_inputs_for_common()
diff --git a/tests/models/transformers/test_models_transformer_cogvideox.py b/tests/models/transformers/test_models_transformer_cogvideox.py
index 2b3cca883d17..f632add7e5a7 100644
--- a/tests/models/transformers/test_models_transformer_cogvideox.py
+++ b/tests/models/transformers/test_models_transformer_cogvideox.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,11 +18,11 @@
import torch
from diffusers import CogVideoXTransformer3DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_cogview3plus.py b/tests/models/transformers/test_models_transformer_cogview3plus.py
index 91c7c35fbd07..d38d77531d4c 100644
--- a/tests/models/transformers/test_models_transformer_cogview3plus.py
+++ b/tests/models/transformers/test_models_transformer_cogview3plus.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,11 +18,11 @@
import torch
from diffusers import CogView3PlusTransformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_cogview4.py b/tests/models/transformers/test_models_transformer_cogview4.py
index e311ce77ea50..084c3b7cea41 100644
--- a/tests/models/transformers/test_models_transformer_cogview4.py
+++ b/tests/models/transformers/test_models_transformer_cogview4.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,8 +17,8 @@
import torch
from diffusers import CogView4Transformer2DModel
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_consisid.py b/tests/models/transformers/test_models_transformer_consisid.py
index b848ed014074..77fc172d078a 100644
--- a/tests/models/transformers/test_models_transformer_consisid.py
+++ b/tests/models/transformers/test_models_transformer_consisid.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,11 +18,11 @@
import torch
from diffusers import ConsisIDTransformer3DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_cosmos.py b/tests/models/transformers/test_models_transformer_cosmos.py
new file mode 100644
index 000000000000..d7390e105c45
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_cosmos.py
@@ -0,0 +1,153 @@
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import CosmosTransformer3DModel
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class CosmosTransformer3DModelTests(ModelTesterMixin, unittest.TestCase):
+ model_class = CosmosTransformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_channels = 4
+ num_frames = 1
+ height = 16
+ width = 16
+ text_embed_dim = 16
+ sequence_length = 12
+ fps = 30
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
+ attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
+ padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "timestep": timestep,
+ "encoder_hidden_states": encoder_hidden_states,
+ "attention_mask": attention_mask,
+ "fps": fps,
+ "padding_mask": padding_mask,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 1, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (4, 1, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "in_channels": 4,
+ "out_channels": 4,
+ "num_attention_heads": 2,
+ "attention_head_dim": 12,
+ "num_layers": 2,
+ "mlp_ratio": 2,
+ "text_embed_dim": 16,
+ "adaln_lora_dim": 4,
+ "max_size": (4, 32, 32),
+ "patch_size": (1, 2, 2),
+ "rope_scale": (2.0, 1.0, 1.0),
+ "concat_padding_mask": True,
+ "extra_pos_embed_type": "learnable",
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"CosmosTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class CosmosTransformer3DModelVideoToWorldTests(ModelTesterMixin, unittest.TestCase):
+ model_class = CosmosTransformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_channels = 4
+ num_frames = 1
+ height = 16
+ width = 16
+ text_embed_dim = 16
+ sequence_length = 12
+ fps = 30
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
+ attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
+ condition_mask = torch.ones(batch_size, 1, num_frames, height, width).to(torch_device)
+ padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "timestep": timestep,
+ "encoder_hidden_states": encoder_hidden_states,
+ "attention_mask": attention_mask,
+ "fps": fps,
+ "condition_mask": condition_mask,
+ "padding_mask": padding_mask,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 1, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (4, 1, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "in_channels": 4 + 1,
+ "out_channels": 4,
+ "num_attention_heads": 2,
+ "attention_head_dim": 12,
+ "num_layers": 2,
+ "mlp_ratio": 2,
+ "text_embed_dim": 16,
+ "adaln_lora_dim": 4,
+ "max_size": (4, 32, 32),
+ "patch_size": (1, 2, 2),
+ "rope_scale": (2.0, 1.0, 1.0),
+ "concat_padding_mask": True,
+ "extra_pos_embed_type": "learnable",
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"CosmosTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/transformers/test_models_transformer_easyanimate.py b/tests/models/transformers/test_models_transformer_easyanimate.py
index 9f10a7da0a76..d7b90a47d974 100644
--- a/tests/models/transformers/test_models_transformer_easyanimate.py
+++ b/tests/models/transformers/test_models_transformer_easyanimate.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,8 +18,8 @@
import torch
from diffusers import EasyAnimateTransformer3DModel
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py
index c88b3dac8216..3ab02f797b5b 100644
--- a/tests/models/transformers/test_models_transformer_flux.py
+++ b/tests/models/transformers/test_models_transformer_flux.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,9 +20,9 @@
from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
from diffusers.models.embeddings import ImageProjection
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
-from ..test_modeling_common import ModelTesterMixin
+from ...testing_utils import enable_full_determinism, is_peft_available, torch_device
+from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
@@ -57,7 +57,9 @@ def create_flux_ip_adapter_state_dict(model):
image_projection = ImageProjection(
cross_attention_dim=model.config["joint_attention_dim"],
- image_embed_dim=model.config["pooled_projection_dim"],
+ image_embed_dim=(
+ model.config["pooled_projection_dim"] if "pooled_projection_dim" in model.config.keys() else 768
+ ),
num_image_text_embeds=4,
)
@@ -89,10 +91,20 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
@property
def dummy_input(self):
+ return self.prepare_dummy_input()
+
+ @property
+ def input_shape(self):
+ return (16, 4)
+
+ @property
+ def output_shape(self):
+ return (16, 4)
+
+ def prepare_dummy_input(self, height=4, width=4):
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
- height = width = 4
sequence_length = 48
embedding_dim = 32
@@ -112,14 +124,6 @@ def dummy_input(self):
"timestep": timestep,
}
- @property
- def input_shape(self):
- return (16, 4)
-
- @property
- def output_shape(self):
- return (16, 4)
-
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 1,
@@ -167,3 +171,54 @@ def test_deprecated_inputs_img_txt_ids_3d(self):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"FluxTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+ # The test exists for cases like
+ # https://github.com/huggingface/diffusers/issues/11874
+ @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+ def test_lora_exclude_modules(self):
+ from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
+
+ lora_rank = 4
+ target_module = "single_transformer_blocks.0.proj_out"
+ adapter_name = "foo"
+ init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict).to(torch_device)
+
+ state_dict = model.state_dict()
+ target_mod_shape = state_dict[f"{target_module}.weight"].shape
+ lora_state_dict = {
+ f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22,
+ f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33,
+ }
+ # Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter).
+ config = LoraConfig(
+ r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
+ )
+ inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
+ set_peft_model_state_dict(model, lora_state_dict, adapter_name)
+ retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
+ assert len(retrieved_lora_state_dict) == len(lora_state_dict)
+ assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
+ assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
+
+
+class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = FluxTransformer2DModel
+ different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
+
+ def prepare_dummy_input(self, height, width):
+ return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
+
+
+class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
+ model_class = FluxTransformer2DModel
+ different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
+
+ def prepare_dummy_input(self, height, width):
+ return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
diff --git a/tests/models/transformers/test_models_transformer_flux2.py b/tests/models/transformers/test_models_transformer_flux2.py
new file mode 100644
index 000000000000..316d5fa770bb
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_flux2.py
@@ -0,0 +1,162 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import Flux2Transformer2DModel, attention_backend
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
+
+
+enable_full_determinism()
+
+
+class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = Flux2Transformer2DModel
+ main_input_name = "hidden_states"
+ # We override the items here because the transformer under consideration is small.
+ model_split_percents = [0.7, 0.6, 0.6]
+
+ # Skip setting testing with default: AttnProcessor
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ return self.prepare_dummy_input()
+
+ @property
+ def input_shape(self):
+ return (16, 4)
+
+ @property
+ def output_shape(self):
+ return (16, 4)
+
+ def prepare_dummy_input(self, height=4, width=4):
+ batch_size = 1
+ num_latent_channels = 4
+ sequence_length = 48
+ embedding_dim = 32
+
+ hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+
+ t_coords = torch.arange(1)
+ h_coords = torch.arange(height)
+ w_coords = torch.arange(width)
+ l_coords = torch.arange(1)
+ image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords) # [height * width, 4]
+ image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
+
+ text_t_coords = torch.arange(1)
+ text_h_coords = torch.arange(1)
+ text_w_coords = torch.arange(1)
+ text_l_coords = torch.arange(sequence_length)
+ text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
+ text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
+
+ timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+ guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "img_ids": image_ids,
+ "txt_ids": text_ids,
+ "timestep": timestep,
+ "guidance": guidance,
+ }
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "patch_size": 1,
+ "in_channels": 4,
+ "num_layers": 1,
+ "num_single_layers": 1,
+ "attention_head_dim": 16,
+ "num_attention_heads": 2,
+ "joint_attention_dim": 32,
+ "timestep_guidance_channels": 256, # Hardcoded in original code
+ "axes_dims_rope": [4, 4, 4, 4],
+ }
+
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ # TODO (Daniel, Sayak): We can remove this test.
+ def test_flux2_consistency(self, seed=0):
+ torch.manual_seed(seed)
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ torch.manual_seed(seed)
+ model = self.model_class(**init_dict)
+ # state_dict = model.state_dict()
+ # for key, param in state_dict.items():
+ # print(f"{key} | {param.shape}")
+ # torch.save(state_dict, "/raid/daniel_gu/test_flux2_params/diffusers.pt")
+ model.to(torch_device)
+ model.eval()
+
+ with attention_backend("native"):
+ with torch.no_grad():
+ output = model(**inputs_dict)
+
+ if isinstance(output, dict):
+ output = output.to_tuple()[0]
+
+ self.assertIsNotNone(output)
+
+ # input & output have to have the same shape
+ input_tensor = inputs_dict[self.main_input_name]
+ expected_shape = input_tensor.shape
+ self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+ # Check against expected slice
+ # fmt: off
+ expected_slice = torch.tensor([-0.3662, 0.4844, 0.6334, -0.3497, 0.2162, 0.0188, 0.0521, -0.2061, -0.2041, -0.0342, -0.7107, 0.4797, -0.3280, 0.7059, -0.0849, 0.4416])
+ # fmt: on
+
+ flat_output = output.cpu().flatten()
+ generated_slice = torch.cat([flat_output[:8], flat_output[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-4))
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"Flux2Transformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = Flux2Transformer2DModel
+ different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()
+
+ def prepare_dummy_input(self, height, width):
+ return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)
+
+
+class Flux2TransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
+ model_class = Flux2Transformer2DModel
+ different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()
+
+ def prepare_dummy_input(self, height, width):
+ return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)
diff --git a/tests/models/transformers/test_models_transformer_hidream.py b/tests/models/transformers/test_models_transformer_hidream.py
new file mode 100644
index 000000000000..fdd5f8c7fd07
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_hidream.py
@@ -0,0 +1,96 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import HiDreamImageTransformer2DModel
+
+from ...testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class HiDreamTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = HiDreamImageTransformer2DModel
+ main_input_name = "hidden_states"
+ model_split_percents = [0.8, 0.8, 0.9]
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_channels = 4
+ height = width = 32
+ embedding_dim_t5, embedding_dim_llama, embedding_dim_pooled = 8, 4, 8
+ sequence_length = 8
+
+ hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
+ encoder_hidden_states_t5 = torch.randn((batch_size, sequence_length, embedding_dim_t5)).to(torch_device)
+ encoder_hidden_states_llama3 = torch.randn((batch_size, batch_size, sequence_length, embedding_dim_llama)).to(
+ torch_device
+ )
+ pooled_embeds = torch.randn((batch_size, embedding_dim_pooled)).to(torch_device)
+ timesteps = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states_t5": encoder_hidden_states_t5,
+ "encoder_hidden_states_llama3": encoder_hidden_states_llama3,
+ "pooled_embeds": pooled_embeds,
+ "timesteps": timesteps,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 32, 32)
+
+ @property
+ def output_shape(self):
+ return (4, 32, 32)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "patch_size": 2,
+ "in_channels": 4,
+ "out_channels": 4,
+ "num_layers": 1,
+ "num_single_layers": 1,
+ "attention_head_dim": 8,
+ "num_attention_heads": 4,
+ "caption_channels": [8, 4],
+ "text_emb_dim": 8,
+ "num_routed_experts": 2,
+ "num_activated_experts": 2,
+ "axes_dims_rope": (4, 2, 2),
+ "max_resolution": (32, 32),
+ "llama_layers": (0, 1),
+ "force_inference_output": True, # TODO: as we don't implement MoE loss in training tests.
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ @unittest.skip("HiDreamImageTransformer2DModel uses a dedicated attention processor. This test doesn't apply")
+ def test_set_attn_processor_for_determinism(self):
+ pass
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"HiDreamImageTransformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/transformers/test_models_transformer_hunyuan_1_5.py b/tests/models/transformers/test_models_transformer_hunyuan_1_5.py
new file mode 100644
index 000000000000..57080bc5b0b4
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_hunyuan_1_5.py
@@ -0,0 +1,101 @@
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import HunyuanVideo15Transformer3DModel
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class HunyuanVideo15Transformer3DTests(ModelTesterMixin, unittest.TestCase):
+ model_class = HunyuanVideo15Transformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+ model_split_percents = [0.99, 0.99, 0.99]
+
+ text_embed_dim = 16
+ text_embed_2_dim = 8
+ image_embed_dim = 12
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_channels = 4
+ num_frames = 1
+ height = 8
+ width = 8
+ sequence_length = 6
+ sequence_length_2 = 4
+ image_sequence_length = 3
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ timestep = torch.tensor([1.0]).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, self.text_embed_dim), device=torch_device)
+ encoder_hidden_states_2 = torch.randn(
+ (batch_size, sequence_length_2, self.text_embed_2_dim), device=torch_device
+ )
+ encoder_attention_mask = torch.ones((batch_size, sequence_length), device=torch_device)
+ encoder_attention_mask_2 = torch.ones((batch_size, sequence_length_2), device=torch_device)
+ # All zeros for inducing T2V path in the model.
+ image_embeds = torch.zeros((batch_size, image_sequence_length, self.image_embed_dim), device=torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "timestep": timestep,
+ "encoder_hidden_states": encoder_hidden_states,
+ "encoder_attention_mask": encoder_attention_mask,
+ "encoder_hidden_states_2": encoder_hidden_states_2,
+ "encoder_attention_mask_2": encoder_attention_mask_2,
+ "image_embeds": image_embeds,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 1, 8, 8)
+
+ @property
+ def output_shape(self):
+ return (4, 1, 8, 8)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "in_channels": 4,
+ "out_channels": 4,
+ "num_attention_heads": 2,
+ "attention_head_dim": 8,
+ "num_layers": 2,
+ "num_refiner_layers": 1,
+ "mlp_ratio": 2.0,
+ "patch_size": 1,
+ "patch_size_t": 1,
+ "text_embed_dim": self.text_embed_dim,
+ "text_embed_2_dim": self.text_embed_2_dim,
+ "image_embed_dim": self.image_embed_dim,
+ "rope_axes_dim": (2, 2, 4),
+ "target_size": 16,
+ "task_type": "t2v",
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"HunyuanVideo15Transformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/transformers/test_models_transformer_hunyuan_dit.py b/tests/models/transformers/test_models_transformer_hunyuan_dit.py
index ea05abed38d9..d82a62d58ec3 100644
--- a/tests/models/transformers/test_models_transformer_hunyuan_dit.py
+++ b/tests/models/transformers/test_models_transformer_hunyuan_dit.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,11 +18,11 @@
import torch
from diffusers import HunyuanDiT2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_hunyuan_video.py b/tests/models/transformers/test_models_transformer_hunyuan_video.py
index 495131ad6fd8..385a5eefd58b 100644
--- a/tests/models/transformers/test_models_transformer_hunyuan_video.py
+++ b/tests/models/transformers/test_models_transformer_hunyuan_video.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,9 +17,12 @@
import torch
from diffusers import HunyuanVideoTransformer3DModel
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
-from ..test_modeling_common import ModelTesterMixin
+from ...testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
@@ -90,6 +93,13 @@ def test_gradient_checkpointing_is_applied(self):
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+class HunyuanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = HunyuanVideoTransformer3DModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return HunyuanVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
+
+
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
@@ -158,6 +168,13 @@ def test_gradient_checkpointing_is_applied(self):
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+class HunyuanSkyreelsImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = HunyuanVideoTransformer3DModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return HunyuanSkyreelsImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
+
+
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
@@ -224,6 +241,13 @@ def test_gradient_checkpointing_is_applied(self):
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+class HunyuanImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = HunyuanVideoTransformer3DModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return HunyuanVideoImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
+
+
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
@@ -290,3 +314,10 @@ def test_output(self):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class HunyuanVideoTokenReplaceCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = HunyuanVideoTransformer3DModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return HunyuanVideoTokenReplaceImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
diff --git a/tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py b/tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py
new file mode 100644
index 000000000000..00a2b27e02b6
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py
@@ -0,0 +1,116 @@
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import HunyuanVideoFramepackTransformer3DModel
+
+from ...testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+ model_class = HunyuanVideoFramepackTransformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+ model_split_percents = [0.5, 0.7, 0.9]
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_channels = 4
+ num_frames = 3
+ height = 4
+ width = 4
+ text_encoder_embedding_dim = 16
+ image_encoder_embedding_dim = 16
+ pooled_projection_dim = 8
+ sequence_length = 12
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+ pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
+ encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
+ image_embeds = torch.randn((batch_size, sequence_length, image_encoder_embedding_dim)).to(torch_device)
+ indices_latents = torch.ones((3,)).to(torch_device)
+ latents_clean = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device)
+ indices_latents_clean = torch.ones((num_frames - 1,)).to(torch_device)
+ latents_history_2x = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device)
+ indices_latents_history_2x = torch.ones((num_frames - 1,)).to(torch_device)
+ latents_history_4x = torch.randn((batch_size, num_channels, (num_frames - 1) * 4, height, width)).to(
+ torch_device
+ )
+ indices_latents_history_4x = torch.ones(((num_frames - 1) * 4,)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+ guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "timestep": timestep,
+ "encoder_hidden_states": encoder_hidden_states,
+ "pooled_projections": pooled_projections,
+ "encoder_attention_mask": encoder_attention_mask,
+ "guidance": guidance,
+ "image_embeds": image_embeds,
+ "indices_latents": indices_latents,
+ "latents_clean": latents_clean,
+ "indices_latents_clean": indices_latents_clean,
+ "latents_history_2x": latents_history_2x,
+ "indices_latents_history_2x": indices_latents_history_2x,
+ "latents_history_4x": latents_history_4x,
+ "indices_latents_history_4x": indices_latents_history_4x,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 3, 4, 4)
+
+ @property
+ def output_shape(self):
+ return (4, 3, 4, 4)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "in_channels": 4,
+ "out_channels": 4,
+ "num_attention_heads": 2,
+ "attention_head_dim": 10,
+ "num_layers": 1,
+ "num_single_layers": 1,
+ "num_refiner_layers": 1,
+ "patch_size": 2,
+ "patch_size_t": 1,
+ "guidance_embeds": True,
+ "text_embed_dim": 16,
+ "pooled_projection_dim": 8,
+ "rope_axes_dim": (2, 4, 4),
+ "image_condition_type": None,
+ "has_image_proj": True,
+ "image_proj_dim": 16,
+ "has_clean_x_embedder": True,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"HunyuanVideoFramepackTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/transformers/test_models_transformer_latte.py b/tests/models/transformers/test_models_transformer_latte.py
index 0cb9094f5165..7bf2c52e6269 100644
--- a/tests/models/transformers/test_models_transformer_latte.py
+++ b/tests/models/transformers/test_models_transformer_latte.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,11 +18,11 @@
import torch
from diffusers import LatteTransformer3DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_ltx.py b/tests/models/transformers/test_models_transformer_ltx.py
index 128bf04155e7..e912463bbf6a 100644
--- a/tests/models/transformers/test_models_transformer_ltx.py
+++ b/tests/models/transformers/test_models_transformer_ltx.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,9 +18,9 @@
import torch
from diffusers import LTXVideoTransformer3DModel
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
-from ..test_modeling_common import ModelTesterMixin
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
@@ -81,3 +81,10 @@ def prepare_init_args_and_inputs_for_common(self):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"LTXVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class LTXTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = LTXVideoTransformer3DModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return LTXTransformerTests().prepare_init_args_and_inputs_for_common()
diff --git a/tests/models/transformers/test_models_transformer_lumina.py b/tests/models/transformers/test_models_transformer_lumina.py
index 6744fb8ac84b..0024aa106c6d 100644
--- a/tests/models/transformers/test_models_transformer_lumina.py
+++ b/tests/models/transformers/test_models_transformer_lumina.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,11 +18,11 @@
import torch
from diffusers import LuminaNextDiT2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_lumina2.py b/tests/models/transformers/test_models_transformer_lumina2.py
index 4db3ae68aa94..4efae3d4b713 100644
--- a/tests/models/transformers/test_models_transformer_lumina2.py
+++ b/tests/models/transformers/test_models_transformer_lumina2.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,11 +18,11 @@
import torch
from diffusers import Lumina2Transformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_mochi.py b/tests/models/transformers/test_models_transformer_mochi.py
index d284ab942949..931b5874ee78 100644
--- a/tests/models/transformers/test_models_transformer_mochi.py
+++ b/tests/models/transformers/test_models_transformer_mochi.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,8 +18,8 @@
import torch
from diffusers import MochiTransformer3DModel
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_omnigen.py b/tests/models/transformers/test_models_transformer_omnigen.py
index 1bdcc68b0378..f1963ddb7709 100644
--- a/tests/models/transformers/test_models_transformer_omnigen.py
+++ b/tests/models/transformers/test_models_transformer_omnigen.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,8 +18,8 @@
import torch
from diffusers import OmniGenTransformer2DModel
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_prx.py b/tests/models/transformers/test_models_transformer_prx.py
new file mode 100644
index 000000000000..1387625d5ea0
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_prx.py
@@ -0,0 +1,83 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class PRXTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = PRXTransformer2DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ return self.prepare_dummy_input()
+
+ @property
+ def input_shape(self):
+ return (16, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (16, 16, 16)
+
+ def prepare_dummy_input(self, height=16, width=16):
+ batch_size = 1
+ num_latent_channels = 16
+ sequence_length = 16
+ embedding_dim = 1792
+
+ hidden_states = torch.randn((batch_size, num_latent_channels, height, width)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+
+ return {
+ "hidden_states": hidden_states,
+ "timestep": timestep,
+ "encoder_hidden_states": encoder_hidden_states,
+ }
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "in_channels": 16,
+ "patch_size": 2,
+ "context_in_dim": 1792,
+ "hidden_size": 1792,
+ "mlp_ratio": 3.5,
+ "num_heads": 28,
+ "depth": 4, # Smaller depth for testing
+ "axes_dim": [32, 32],
+ "theta": 10_000,
+ }
+ inputs_dict = self.prepare_dummy_input()
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"PRXTransformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py
new file mode 100644
index 000000000000..b24fa90503ef
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_qwenimage.py
@@ -0,0 +1,106 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import pytest
+import torch
+
+from diffusers import QwenImageTransformer2DModel
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+
+
+enable_full_determinism()
+
+
+class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = QwenImageTransformer2DModel
+ main_input_name = "hidden_states"
+ # We override the items here because the transformer under consideration is small.
+ model_split_percents = [0.7, 0.6, 0.6]
+
+ # Skip setting testing with default: AttnProcessor
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ return self.prepare_dummy_input()
+
+ @property
+ def input_shape(self):
+ return (16, 16)
+
+ @property
+ def output_shape(self):
+ return (16, 16)
+
+ def prepare_dummy_input(self, height=4, width=4):
+ batch_size = 1
+ num_latent_channels = embedding_dim = 16
+ sequence_length = 7
+ vae_scale_factor = 4
+
+ hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
+ timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+ orig_height = height * 2 * vae_scale_factor
+ orig_width = width * 2 * vae_scale_factor
+ img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "encoder_hidden_states_mask": encoder_hidden_states_mask,
+ "timestep": timestep,
+ "img_shapes": img_shapes,
+ "txt_seq_lens": encoder_hidden_states_mask.sum(dim=1).tolist(),
+ }
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "patch_size": 2,
+ "in_channels": 16,
+ "out_channels": 4,
+ "num_layers": 2,
+ "attention_head_dim": 16,
+ "num_attention_heads": 3,
+ "joint_attention_dim": 16,
+ "guidance_embeds": False,
+ "axes_dims_rope": (8, 4, 4),
+ }
+
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"QwenImageTransformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = QwenImageTransformer2DModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common()
+
+ def prepare_dummy_input(self, height, width):
+ return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
+
+ @pytest.mark.xfail(condition=True, reason="RoPE needs to be revisited.", strict=True)
+ def test_torch_compile_recompilation_and_graph_break(self):
+ super().test_torch_compile_recompilation_and_graph_break()
diff --git a/tests/models/transformers/test_models_transformer_sana.py b/tests/models/transformers/test_models_transformer_sana.py
index d4dc30f5d7a8..2e316c3aedc1 100644
--- a/tests/models/transformers/test_models_transformer_sana.py
+++ b/tests/models/transformers/test_models_transformer_sana.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,11 +17,11 @@
import torch
from diffusers import SanaTransformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_sana_video.py b/tests/models/transformers/test_models_transformer_sana_video.py
new file mode 100644
index 000000000000..ff564ed8918d
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_sana_video.py
@@ -0,0 +1,97 @@
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import SanaVideoTransformer3DModel
+
+from ...testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+
+
+enable_full_determinism()
+
+
+class SanaVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+ model_class = SanaVideoTransformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_channels = 16
+ num_frames = 2
+ height = 16
+ width = 16
+ text_encoder_embedding_dim = 16
+ sequence_length = 12
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "timestep": timestep,
+ }
+
+ @property
+ def input_shape(self):
+ return (16, 2, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (16, 2, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "in_channels": 16,
+ "out_channels": 16,
+ "num_attention_heads": 2,
+ "attention_head_dim": 12,
+ "num_layers": 2,
+ "num_cross_attention_heads": 2,
+ "cross_attention_head_dim": 12,
+ "cross_attention_dim": 24,
+ "caption_channels": 16,
+ "mlp_ratio": 2.5,
+ "dropout": 0.0,
+ "attention_bias": False,
+ "sample_size": 8,
+ "patch_size": (1, 2, 2),
+ "norm_elementwise_affine": False,
+ "norm_eps": 1e-6,
+ "qk_norm": "rms_norm_across_heads",
+ "rope_max_seq_len": 32,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"SanaVideoTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class SanaVideoTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = SanaVideoTransformer3DModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return SanaVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
diff --git a/tests/models/transformers/test_models_transformer_sd3.py b/tests/models/transformers/test_models_transformer_sd3.py
index 659d9a82fd76..c4ee7017a380 100644
--- a/tests/models/transformers/test_models_transformer_sd3.py
+++ b/tests/models/transformers/test_models_transformer_sd3.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,11 +19,11 @@
from diffusers import SD3Transformer2DModel
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
@@ -92,9 +92,9 @@ def test_xformers_enable_works(self):
model.enable_xformers_memory_efficient_attention()
- assert (
- model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor"
- ), "xformers is not enabled"
+ assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", (
+ "xformers is not enabled"
+ )
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
def test_set_attn_processor_for_determinism(self):
@@ -167,9 +167,9 @@ def test_xformers_enable_works(self):
model.enable_xformers_memory_efficient_attention()
- assert (
- model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor"
- ), "xformers is not enabled"
+ assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", (
+ "xformers is not enabled"
+ )
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
def test_set_attn_processor_for_determinism(self):
diff --git a/tests/models/transformers/test_models_transformer_skyreels_v2.py b/tests/models/transformers/test_models_transformer_skyreels_v2.py
new file mode 100644
index 000000000000..8c36d8256ee9
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_skyreels_v2.py
@@ -0,0 +1,84 @@
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import SkyReelsV2Transformer3DModel
+
+from ...testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+
+
+enable_full_determinism()
+
+
+class SkyReelsV2Transformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
+ model_class = SkyReelsV2Transformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_channels = 4
+ num_frames = 2
+ height = 16
+ width = 16
+ text_encoder_embedding_dim = 16
+ sequence_length = 12
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "timestep": timestep,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 1, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (4, 1, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "patch_size": (1, 2, 2),
+ "num_attention_heads": 2,
+ "attention_head_dim": 12,
+ "in_channels": 4,
+ "out_channels": 4,
+ "text_dim": 16,
+ "freq_dim": 256,
+ "ffn_dim": 32,
+ "num_layers": 2,
+ "cross_attn_norm": True,
+ "qk_norm": "rms_norm_across_heads",
+ "rope_max_seq_len": 32,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"SkyReelsV2Transformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/transformers/test_models_transformer_temporal.py b/tests/models/transformers/test_models_transformer_temporal.py
index 7b689447cf29..aff83be51124 100644
--- a/tests/models/transformers/test_models_transformer_temporal.py
+++ b/tests/models/transformers/test_models_transformer_temporal.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,11 +18,11 @@
import torch
from diffusers.models.transformers import TransformerTemporalModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_wan.py b/tests/models/transformers/test_models_transformer_wan.py
index 3ac64c628988..9f248f990c8a 100644
--- a/tests/models/transformers/test_models_transformer_wan.py
+++ b/tests/models/transformers/test_models_transformer_wan.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,9 +17,12 @@
import torch
from diffusers import WanTransformer3DModel
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
-from ..test_modeling_common import ModelTesterMixin
+from ...testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
@@ -79,3 +82,10 @@ def prepare_init_args_and_inputs_for_common(self):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"WanTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class WanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = WanTransformer3DModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return WanTransformer3DTests().prepare_init_args_and_inputs_for_common()
diff --git a/tests/models/transformers/test_models_transformer_wan_animate.py b/tests/models/transformers/test_models_transformer_wan_animate.py
new file mode 100644
index 000000000000..5d571b8c2e7d
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_wan_animate.py
@@ -0,0 +1,126 @@
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import WanAnimateTransformer3DModel
+
+from ...testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+
+
+enable_full_determinism()
+
+
+class WanAnimateTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+ model_class = WanAnimateTransformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_channels = 4
+ num_frames = 20 # To make the shapes work out; for complicated reasons we want 21 to divide num_frames + 1
+ height = 16
+ width = 16
+ text_encoder_embedding_dim = 16
+ sequence_length = 12
+
+ clip_seq_len = 12
+ clip_dim = 16
+
+ inference_segment_length = 77 # The inference segment length in the full Wan2.2-Animate-14B model
+ face_height = 16 # Should be square and match `motion_encoder_size` below
+ face_width = 16
+
+ hidden_states = torch.randn((batch_size, 2 * num_channels + 4, num_frames + 1, height, width)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+ clip_ref_features = torch.randn((batch_size, clip_seq_len, clip_dim)).to(torch_device)
+ pose_latents = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ face_pixel_values = torch.randn((batch_size, 3, inference_segment_length, face_height, face_width)).to(
+ torch_device
+ )
+
+ return {
+ "hidden_states": hidden_states,
+ "timestep": timestep,
+ "encoder_hidden_states": encoder_hidden_states,
+ "encoder_hidden_states_image": clip_ref_features,
+ "pose_hidden_states": pose_latents,
+ "face_pixel_values": face_pixel_values,
+ }
+
+ @property
+ def input_shape(self):
+ return (12, 1, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (4, 1, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ # Use custom channel sizes since the default Wan Animate channel sizes will cause the motion encoder to
+ # contain the vast majority of the parameters in the test model
+ channel_sizes = {"4": 16, "8": 16, "16": 16}
+
+ init_dict = {
+ "patch_size": (1, 2, 2),
+ "num_attention_heads": 2,
+ "attention_head_dim": 12,
+ "in_channels": 12, # 2 * C + 4 = 2 * 4 + 4 = 12
+ "latent_channels": 4,
+ "out_channels": 4,
+ "text_dim": 16,
+ "freq_dim": 256,
+ "ffn_dim": 32,
+ "num_layers": 2,
+ "cross_attn_norm": True,
+ "qk_norm": "rms_norm_across_heads",
+ "image_dim": 16,
+ "rope_max_seq_len": 32,
+ "motion_encoder_channel_sizes": channel_sizes, # Start of Wan Animate-specific config
+ "motion_encoder_size": 16, # Ensures that there will be 2 motion encoder resblocks
+ "motion_style_dim": 8,
+ "motion_dim": 4,
+ "motion_encoder_dim": 16,
+ "face_encoder_hidden_dim": 16,
+ "face_encoder_num_heads": 2,
+ "inject_face_latents_blocks": 2,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"WanAnimateTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+ # Override test_output because the transformer output is expected to have less channels than the main transformer
+ # input.
+ def test_output(self):
+ expected_output_shape = (1, 4, 21, 16, 16)
+ super().test_output(expected_output_shape=expected_output_shape)
+
+
+class WanAnimateTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = WanAnimateTransformer3DModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return WanAnimateTransformer3DTests().prepare_init_args_and_inputs_for_common()
diff --git a/tests/models/transformers/test_models_transformer_z_image.py b/tests/models/transformers/test_models_transformer_z_image.py
new file mode 100644
index 000000000000..79054019f2d2
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_z_image.py
@@ -0,0 +1,171 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import os
+import unittest
+
+import torch
+
+from diffusers import ZImageTransformer2DModel
+
+from ...testing_utils import IS_GITHUB_ACTIONS, torch_device
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+
+
+# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations
+# Cannot use enable_full_determinism() which sets it to True
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+torch.use_deterministic_algorithms(False)
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+if hasattr(torch.backends, "cuda"):
+ torch.backends.cuda.matmul.allow_tf32 = False
+
+
+@unittest.skipIf(
+ IS_GITHUB_ACTIONS,
+ reason="Skipping test-suite inside the CI because the model has `torch.empty()` inside of it during init and we don't have a clear way to override it in the modeling tests.",
+)
+class ZImageTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = ZImageTransformer2DModel
+ main_input_name = "x"
+ # We override the items here because the transformer under consideration is small.
+ model_split_percents = [0.9, 0.9, 0.9]
+
+ def prepare_dummy_input(self, height=16, width=16):
+ batch_size = 1
+ num_channels = 16
+ embedding_dim = 16
+ sequence_length = 16
+
+ hidden_states = [torch.randn((num_channels, 1, height, width)).to(torch_device) for _ in range(batch_size)]
+ encoder_hidden_states = [
+ torch.randn((sequence_length, embedding_dim)).to(torch_device) for _ in range(batch_size)
+ ]
+ timestep = torch.tensor([0.0]).to(torch_device)
+
+ return {"x": hidden_states, "cap_feats": encoder_hidden_states, "t": timestep}
+
+ @property
+ def dummy_input(self):
+ return self.prepare_dummy_input()
+
+ @property
+ def input_shape(self):
+ return (4, 32, 32)
+
+ @property
+ def output_shape(self):
+ return (4, 32, 32)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "all_patch_size": (2,),
+ "all_f_patch_size": (1,),
+ "in_channels": 16,
+ "dim": 16,
+ "n_layers": 1,
+ "n_refiner_layers": 1,
+ "n_heads": 1,
+ "n_kv_heads": 2,
+ "qk_norm": True,
+ "cap_feat_dim": 16,
+ "rope_theta": 256.0,
+ "t_scale": 1000.0,
+ "axes_dims": [8, 4, 4],
+ "axes_lens": [256, 32, 32],
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def setUp(self):
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+ torch.manual_seed(0)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(0)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+ torch.manual_seed(0)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(0)
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"ZImageTransformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+ @unittest.skip("Test is not supported for handling main inputs that are lists.")
+ def test_training(self):
+ super().test_training()
+
+ @unittest.skip("Test is not supported for handling main inputs that are lists.")
+ def test_ema_training(self):
+ super().test_ema_training()
+
+ @unittest.skip("Test is not supported for handling main inputs that are lists.")
+ def test_effective_gradient_checkpointing(self):
+ super().test_effective_gradient_checkpointing()
+
+ @unittest.skip(
+ "Test needs to be revisited. But we need to ensure `x_pad_token` and `cap_pad_token` are cast to the same dtype as the destination tensor before they are assigned to the padding indices."
+ )
+ def test_layerwise_casting_training(self):
+ super().test_layerwise_casting_training()
+
+ @unittest.skip("Test is not supported for handling main inputs that are lists.")
+ def test_outputs_equivalence(self):
+ super().test_outputs_equivalence()
+
+ @unittest.skip("Test will pass if we change to deterministic values instead of empty in the DiT.")
+ def test_group_offloading(self):
+ super().test_group_offloading()
+
+ @unittest.skip("Test will pass if we change to deterministic values instead of empty in the DiT.")
+ def test_group_offloading_with_disk(self):
+ super().test_group_offloading_with_disk()
+
+
+class ZImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = ZImageTransformer2DModel
+ different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return ZImageTransformerTests().prepare_init_args_and_inputs_for_common()
+
+ def prepare_dummy_input(self, height, width):
+ return ZImageTransformerTests().prepare_dummy_input(height=height, width=width)
+
+ @unittest.skip(
+ "The repeated block in this model is ZImageTransformerBlock, which is used for noise_refiner, context_refiner, and layers. As a consequence of this, the inputs recorded for the block would vary during compilation and full compilation with fullgraph=True would trigger recompilation at least thrice."
+ )
+ def test_torch_compile_recompilation_and_graph_break(self):
+ super().test_torch_compile_recompilation_and_graph_break()
+
+ @unittest.skip("Fullgraph AoT is broken")
+ def test_compile_works_with_aot(self):
+ super().test_compile_works_with_aot()
+
+ @unittest.skip("Fullgraph is broken")
+ def test_compile_on_different_shapes(self):
+ super().test_compile_on_different_shapes()
diff --git a/tests/models/unets/test_models_unet_1d.py b/tests/models/unets/test_models_unet_1d.py
index 7e160f9c128b..bac017e7e7d3 100644
--- a/tests/models/unets/test_models_unet_1d.py
+++ b/tests/models/unets/test_models_unet_1d.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,13 +19,13 @@
import torch
from diffusers import UNet1DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_manual_seed,
floats_tensor,
slow,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/unets/test_models_unet_2d.py b/tests/models/unets/test_models_unet_2d.py
index 0e5fdc4bba2e..e289f44303f2 100644
--- a/tests/models/unets/test_models_unet_2d.py
+++ b/tests/models/unets/test_models_unet_2d.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,7 +21,9 @@
from diffusers import UNet2DModel
from diffusers.utils import logging
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
require_torch_accelerator,
@@ -29,7 +31,6 @@
torch_all_close,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
@@ -229,7 +230,7 @@ def test_from_pretrained_accelerate_wont_change_results(self):
# two models don't need to stay in the device at the same time
del model_accelerate
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
gc.collect()
model_normal_load, _ = UNet2DModel.from_pretrained(
diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py
index 8e1187f11468..4dbb8ca7c075 100644
--- a/tests/models/unets/test_models_unet_2d_condition.py
+++ b/tests/models/unets/test_models_unet_2d_condition.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -34,7 +34,8 @@
from diffusers.models.embeddings import ImageProjection, IPAdapterFaceIDImageProjection, IPAdapterPlusImageProjection
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
@@ -46,14 +47,17 @@
require_peft_backend,
require_torch_accelerator,
require_torch_accelerator_with_fp16,
- require_torch_gpu,
skip_mps,
slow,
torch_all_close,
torch_device,
)
-
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import (
+ LoraHotSwappingForModelTesterMixin,
+ ModelTesterMixin,
+ TorchCompileTesterMixin,
+ UNetTesterMixin,
+)
if is_peft_available():
@@ -354,7 +358,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
model_class = UNet2DConditionModel
main_input_name = "sample"
# We override the items here because the unet under consideration is small.
- model_split_percents = [0.5, 0.3, 0.4]
+ model_split_percents = [0.5, 0.34, 0.4]
@property
def dummy_input(self):
@@ -654,22 +658,22 @@ def test_model_xattn_mask(self, mask_dtype):
keepall_mask = torch.ones(*cond.shape[:-1], device=cond.device, dtype=mask_dtype)
full_cond_keepallmask_out = model(**{**inputs_dict, "encoder_attention_mask": keepall_mask}).sample
- assert full_cond_keepallmask_out.allclose(
- full_cond_out, rtol=1e-05, atol=1e-05
- ), "a 'keep all' mask should give the same result as no mask"
+ assert full_cond_keepallmask_out.allclose(full_cond_out, rtol=1e-05, atol=1e-05), (
+ "a 'keep all' mask should give the same result as no mask"
+ )
trunc_cond = cond[:, :-1, :]
trunc_cond_out = model(**{**inputs_dict, "encoder_hidden_states": trunc_cond}).sample
- assert not trunc_cond_out.allclose(
- full_cond_out, rtol=1e-05, atol=1e-05
- ), "discarding the last token from our cond should change the result"
+ assert not trunc_cond_out.allclose(full_cond_out, rtol=1e-05, atol=1e-05), (
+ "discarding the last token from our cond should change the result"
+ )
batch, tokens, _ = cond.shape
mask_last = (torch.arange(tokens) < tokens - 1).expand(batch, -1).to(cond.device, mask_dtype)
masked_cond_out = model(**{**inputs_dict, "encoder_attention_mask": mask_last}).sample
- assert masked_cond_out.allclose(
- trunc_cond_out, rtol=1e-05, atol=1e-05
- ), "masking the last token from our cond should be equivalent to truncating that token out of the condition"
+ assert masked_cond_out.allclose(trunc_cond_out, rtol=1e-05, atol=1e-05), (
+ "masking the last token from our cond should be equivalent to truncating that token out of the condition"
+ )
# see diffusers.models.attention_processor::Attention#prepare_attention_mask
# note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
@@ -697,9 +701,9 @@ def test_model_xattn_padding(self):
trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
- assert trunc_mask_out.allclose(
- keeplast_out
- ), "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
+ assert trunc_mask_out.allclose(keeplast_out), (
+ "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
+ )
def test_custom_diffusion_processors(self):
# enable deterministic behavior for gradient checkpointing
@@ -973,13 +977,13 @@ def test_ip_adapter_plus(self):
assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
- @require_torch_gpu
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
]
)
+ @require_torch_accelerator
def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
@@ -989,13 +993,13 @@ def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
- @require_torch_gpu
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
]
)
+ @require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)
@@ -1114,12 +1118,12 @@ def test_load_attn_procs_raise_warning(self):
with torch.no_grad():
lora_sample_2 = model(**inputs_dict).sample
- assert not torch.allclose(
- non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4
- ), "LoRA injected UNet should produce different results."
- assert torch.allclose(
- lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4
- ), "Loading from a saved checkpoint should produce identical results."
+ assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
+ "LoRA injected UNet should produce different results."
+ )
+ assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
+ "Loading from a saved checkpoint should produce identical results."
+ )
@require_peft_backend
def test_save_attn_procs_raise_warning(self):
@@ -1140,6 +1144,20 @@ def test_save_attn_procs_raise_warning(self):
assert "Using the `save_attn_procs()` method has been deprecated" in warning_message
+class UNet2DConditionModelCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = UNet2DConditionModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
+
+
+class UNet2DConditionModelLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
+ model_class = UNet2DConditionModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
+
+
@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
def get_file_format(self, seed, shape):
diff --git a/tests/models/unets/test_models_unet_2d_flax.py b/tests/models/unets/test_models_unet_2d_flax.py
deleted file mode 100644
index 69a0704dca9d..000000000000
--- a/tests/models/unets/test_models_unet_2d_flax.py
+++ /dev/null
@@ -1,104 +0,0 @@
-import gc
-import unittest
-
-from parameterized import parameterized
-
-from diffusers import FlaxUNet2DConditionModel
-from diffusers.utils import is_flax_available
-from diffusers.utils.testing_utils import load_hf_numpy, require_flax, slow
-
-
-if is_flax_available():
- import jax
- import jax.numpy as jnp
-
-
-@slow
-@require_flax
-class FlaxUNet2DConditionModelIntegrationTests(unittest.TestCase):
- def get_file_format(self, seed, shape):
- return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
-
- def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
- dtype = jnp.bfloat16 if fp16 else jnp.float32
- image = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
- return image
-
- def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
- dtype = jnp.bfloat16 if fp16 else jnp.float32
- revision = "bf16" if fp16 else None
-
- model, params = FlaxUNet2DConditionModel.from_pretrained(
- model_id, subfolder="unet", dtype=dtype, revision=revision
- )
- return model, params
-
- def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
- dtype = jnp.bfloat16 if fp16 else jnp.float32
- hidden_states = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
- return hidden_states
-
- @parameterized.expand(
- [
- # fmt: off
- [83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]],
- [17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]],
- [8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]],
- [3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]],
- # fmt: on
- ]
- )
- def test_compvis_sd_v1_4_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
- model, params = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True)
- latents = self.get_latents(seed, fp16=True)
- encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)
-
- sample = model.apply(
- {"params": params},
- latents,
- jnp.array(timestep, dtype=jnp.int32),
- encoder_hidden_states=encoder_hidden_states,
- ).sample
-
- assert sample.shape == latents.shape
-
- output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32)
- expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)
-
- # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, in the same hardware
- assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)
-
- @parameterized.expand(
- [
- # fmt: off
- [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]],
- [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]],
- [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]],
- [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]],
- # fmt: on
- ]
- )
- def test_stabilityai_sd_v2_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
- model, params = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
- latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
- encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True)
-
- sample = model.apply(
- {"params": params},
- latents,
- jnp.array(timestep, dtype=jnp.int32),
- encoder_hidden_states=encoder_hidden_states,
- ).sample
-
- assert sample.shape == latents.shape
-
- output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32)
- expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)
-
- # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware
- assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)
diff --git a/tests/models/unets/test_models_unet_3d_condition.py b/tests/models/unets/test_models_unet_3d_condition.py
index e798586b6965..f73e3461c38e 100644
--- a/tests/models/unets/test_models_unet_3d_condition.py
+++ b/tests/models/unets/test_models_unet_3d_condition.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,8 +21,8 @@
from diffusers.models import ModelMixin, UNet3DConditionModel
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, skip_mps, torch_device
+from ...testing_utils import enable_full_determinism, floats_tensor, skip_mps, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py
index 9431e810280f..40773536df70 100644
--- a/tests/models/unets/test_models_unet_controlnetxs.py
+++ b/tests/models/unets/test_models_unet_controlnetxs.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,8 +21,8 @@
from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
from diffusers.utils import logging
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, is_flaky, torch_device
+from ...testing_utils import enable_full_determinism, floats_tensor, is_flaky, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/unets/test_models_unet_motion.py b/tests/models/unets/test_models_unet_motion.py
index 209806a5fe26..d931b345fd09 100644
--- a/tests/models/unets/test_models_unet_motion.py
+++ b/tests/models/unets/test_models_unet_motion.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,12 +24,12 @@
from diffusers import MotionAdapter, UNet2DConditionModel, UNetMotionModel
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/unets/test_models_unet_spatiotemporal.py b/tests/models/unets/test_models_unet_spatiotemporal.py
index 0d7dc823b026..7df868c9e95b 100644
--- a/tests/models/unets/test_models_unet_spatiotemporal.py
+++ b/tests/models/unets/test_models_unet_spatiotemporal.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,13 +21,13 @@
from diffusers import UNetSpatioTemporalConditionModel
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
skip_mps,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/unets/test_unet_2d_blocks.py b/tests/models/unets/test_unet_2d_blocks.py
index e37199170214..5c006963e30c 100644
--- a/tests/models/unets/test_unet_2d_blocks.py
+++ b/tests/models/unets/test_unet_2d_blocks.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,8 +15,8 @@
import unittest
from diffusers.models.unets.unet_2d_blocks import * # noqa F403
-from diffusers.utils.testing_utils import torch_device
+from ...testing_utils import torch_device
from .test_unet_blocks_common import UNetBlockTesterMixin
diff --git a/tests/models/unets/test_unet_blocks_common.py b/tests/models/unets/test_unet_blocks_common.py
index dce28c77cbb7..85f9bf8353bf 100644
--- a/tests/models/unets/test_unet_blocks_common.py
+++ b/tests/models/unets/test_unet_blocks_common.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,14 +16,15 @@
import torch
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
floats_tensor,
require_torch,
require_torch_accelerator_with_training,
torch_all_close,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
@require_torch
diff --git a/tests/pipelines/controlnet_xs/__init__.py b/tests/modular_pipelines/__init__.py
similarity index 100%
rename from tests/pipelines/controlnet_xs/__init__.py
rename to tests/modular_pipelines/__init__.py
diff --git a/tests/pipelines/dance_diffusion/__init__.py b/tests/modular_pipelines/flux/__init__.py
similarity index 100%
rename from tests/pipelines/dance_diffusion/__init__.py
rename to tests/modular_pipelines/flux/__init__.py
diff --git a/tests/modular_pipelines/flux/test_modular_pipeline_flux.py b/tests/modular_pipelines/flux/test_modular_pipeline_flux.py
new file mode 100644
index 000000000000..854b5218c617
--- /dev/null
+++ b/tests/modular_pipelines/flux/test_modular_pipeline_flux.py
@@ -0,0 +1,181 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import tempfile
+
+import numpy as np
+import PIL
+import torch
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.modular_pipelines import (
+ FluxAutoBlocks,
+ FluxKontextAutoBlocks,
+ FluxKontextModularPipeline,
+ FluxModularPipeline,
+ ModularPipeline,
+)
+
+from ...testing_utils import floats_tensor, torch_device
+from ..test_modular_pipelines_common import ModularPipelineTesterMixin
+
+
+class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
+ pipeline_class = FluxModularPipeline
+ pipeline_blocks_class = FluxAutoBlocks
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
+
+ params = frozenset(["prompt", "height", "width", "guidance_scale"])
+ batch_params = frozenset(["prompt"])
+
+ def get_dummy_inputs(self, seed=0):
+ generator = self.get_generator(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 8,
+ "width": 8,
+ "max_sequence_length": 48,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_float16_inference(self):
+ super().test_float16_inference(9e-2)
+
+
+class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
+ pipeline_class = FluxModularPipeline
+ pipeline_blocks_class = FluxAutoBlocks
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
+
+ params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
+ batch_params = frozenset(["prompt", "image"])
+
+ def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
+ pipeline = super().get_pipeline(components_manager, torch_dtype)
+
+ # Override `vae_scale_factor` here as currently, `image_processor` is initialized with
+ # fixed constants instead of
+ # https://github.com/huggingface/diffusers/blob/d54622c2679d700b425ad61abce9b80fc36212c0/src/diffusers/pipelines/flux/pipeline_flux_img2img.py#L230C9-L232C10
+ pipeline.image_processor = VaeImageProcessor(vae_scale_factor=2)
+ return pipeline
+
+ def get_dummy_inputs(self, seed=0):
+ generator = self.get_generator(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 4,
+ "guidance_scale": 5.0,
+ "height": 8,
+ "width": 8,
+ "max_sequence_length": 48,
+ "output_type": "pt",
+ }
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(torch_device)
+ image = image.cpu().permute(0, 2, 3, 1)[0]
+ init_image = PIL.Image.fromarray(np.uint8(image)).convert("RGB")
+
+ inputs["image"] = init_image
+ inputs["strength"] = 0.5
+
+ return inputs
+
+ def test_save_from_pretrained(self):
+ pipes = []
+ base_pipe = self.get_pipeline().to(torch_device)
+ pipes.append(base_pipe)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ base_pipe.save_pretrained(tmpdirname)
+
+ pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
+ pipe.load_components(torch_dtype=torch.float32)
+ pipe.to(torch_device)
+ pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
+
+ pipes.append(pipe)
+
+ image_slices = []
+ for pipe in pipes:
+ inputs = self.get_dummy_inputs()
+ image = pipe(**inputs, output="images")
+
+ image_slices.append(image[0, -3:, -3:, -1].flatten())
+
+ assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+
+ def test_float16_inference(self):
+ super().test_float16_inference(8e-2)
+
+
+class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
+ pipeline_class = FluxKontextModularPipeline
+ pipeline_blocks_class = FluxKontextAutoBlocks
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-kontext-pipe"
+
+ params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
+ batch_params = frozenset(["prompt", "image"])
+
+ def get_dummy_inputs(self, seed=0):
+ generator = self.get_generator(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 8,
+ "width": 8,
+ "max_sequence_length": 48,
+ "output_type": "pt",
+ }
+ image = PIL.Image.new("RGB", (32, 32), 0)
+
+ inputs["image"] = image
+ inputs["max_area"] = inputs["height"] * inputs["width"]
+ inputs["_auto_resize"] = False
+
+ return inputs
+
+ def test_save_from_pretrained(self):
+ pipes = []
+ base_pipe = self.get_pipeline().to(torch_device)
+ pipes.append(base_pipe)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ base_pipe.save_pretrained(tmpdirname)
+
+ pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
+ pipe.load_components(torch_dtype=torch.float32)
+ pipe.to(torch_device)
+ pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
+
+ pipes.append(pipe)
+
+ image_slices = []
+ for pipe in pipes:
+ inputs = self.get_dummy_inputs()
+ image = pipe(**inputs, output="images")
+
+ image_slices.append(image[0, -3:, -3:, -1].flatten())
+
+ assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+
+ def test_float16_inference(self):
+ super().test_float16_inference(9e-2)
diff --git a/tests/pipelines/i2vgen_xl/__init__.py b/tests/modular_pipelines/qwen/__init__.py
similarity index 100%
rename from tests/pipelines/i2vgen_xl/__init__.py
rename to tests/modular_pipelines/qwen/__init__.py
diff --git a/tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py b/tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py
new file mode 100644
index 000000000000..8d7600781b24
--- /dev/null
+++ b/tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py
@@ -0,0 +1,120 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import PIL
+import pytest
+
+from diffusers.modular_pipelines import (
+ QwenImageAutoBlocks,
+ QwenImageEditAutoBlocks,
+ QwenImageEditModularPipeline,
+ QwenImageEditPlusAutoBlocks,
+ QwenImageEditPlusModularPipeline,
+ QwenImageModularPipeline,
+)
+
+from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPipelineTesterMixin
+
+
+class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
+ pipeline_class = QwenImageModularPipeline
+ pipeline_blocks_class = QwenImageAutoBlocks
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-modular"
+
+ params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
+ batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
+
+ def get_dummy_inputs(self):
+ generator = self.get_generator()
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "bad quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference_batch_single_identical(self):
+ super().test_inference_batch_single_identical(expected_max_diff=5e-4)
+
+
+class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
+ pipeline_class = QwenImageEditModularPipeline
+ pipeline_blocks_class = QwenImageEditAutoBlocks
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-modular"
+
+ params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
+ batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
+
+ def get_dummy_inputs(self):
+ generator = self.get_generator()
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "bad quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "height": 32,
+ "width": 32,
+ "output_type": "pt",
+ }
+ inputs["image"] = PIL.Image.new("RGB", (32, 32), 0)
+ return inputs
+
+ def test_guider_cfg(self):
+ super().test_guider_cfg(7e-5)
+
+
+class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
+ pipeline_class = QwenImageEditPlusModularPipeline
+ pipeline_blocks_class = QwenImageEditPlusAutoBlocks
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-plus-modular"
+
+ # No `mask_image` yet.
+ params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image"])
+ batch_params = frozenset(["prompt", "negative_prompt", "image"])
+
+ def get_dummy_inputs(self):
+ generator = self.get_generator()
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "bad quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "height": 32,
+ "width": 32,
+ "output_type": "pt",
+ }
+ inputs["image"] = PIL.Image.new("RGB", (32, 32), 0)
+ return inputs
+
+ @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
+ def test_num_images_per_prompt(self):
+ super().test_num_images_per_prompt()
+
+ @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
+ def test_inference_batch_consistent():
+ super().test_inference_batch_consistent()
+
+ @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
+ def test_inference_batch_single_identical():
+ super().test_inference_batch_single_identical()
+
+ def test_guider_cfg(self):
+ super().test_guider_cfg(1e-3)
diff --git a/tests/pipelines/musicldm/__init__.py b/tests/modular_pipelines/stable_diffusion_xl/__init__.py
similarity index 100%
rename from tests/pipelines/musicldm/__init__.py
rename to tests/modular_pipelines/stable_diffusion_xl/__init__.py
diff --git a/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py
new file mode 100644
index 000000000000..7b55933e4caf
--- /dev/null
+++ b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py
@@ -0,0 +1,439 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+from typing import Any, Dict
+
+import numpy as np
+import torch
+from PIL import Image
+
+from diffusers import ClassifierFreeGuidance, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
+from diffusers.loaders import ModularIPAdapterMixin
+
+from ...models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict
+from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
+from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class SDXLModularTesterMixin:
+ """
+ This mixin defines method to create pipeline, base input and base test across all SDXL modular tests.
+ """
+
+ def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, expected_max_diff=1e-2):
+ sd_pipe = self.get_pipeline().to(torch_device)
+
+ inputs = self.get_dummy_inputs()
+ image = sd_pipe(**inputs, output="images")
+ image_slice = image[0, -3:, -3:, -1].cpu()
+
+ assert image.shape == expected_image_shape
+ max_diff = torch.abs(image_slice.flatten() - expected_slice).max()
+ assert max_diff < expected_max_diff, f"Image slice does not match expected slice. Max Difference: {max_diff}"
+
+
+class SDXLModularIPAdapterTesterMixin:
+ """
+ This mixin is designed to test IP Adapter.
+ """
+
+ def test_pipeline_inputs_and_blocks(self):
+ blocks = self.pipeline_blocks_class()
+ parameters = blocks.input_names
+
+ assert issubclass(self.pipeline_class, ModularIPAdapterMixin)
+ assert "ip_adapter_image" in parameters, (
+ "`ip_adapter_image` argument must be supported by the `__call__` method"
+ )
+ assert "ip_adapter" in blocks.sub_blocks, "pipeline must contain an IPAdapter block"
+
+ _ = blocks.sub_blocks.pop("ip_adapter")
+ parameters = blocks.input_names
+ assert "ip_adapter_image" not in parameters, (
+ "`ip_adapter_image` argument must be removed from the `__call__` method"
+ )
+
+ def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
+ return torch.randn((1, 1, cross_attention_dim), device=torch_device)
+
+ def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32):
+ return torch.randn((1, 1, 1, cross_attention_dim), device=torch_device)
+
+ def _get_dummy_masks(self, input_size: int = 64):
+ _masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
+ _masks[0, :, :, : int(input_size / 2)] = 1
+ return _masks
+
+ def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
+ blocks = self.pipeline_blocks_class()
+ _ = blocks.sub_blocks.pop("ip_adapter")
+ parameters = blocks.input_names
+ if "image" in parameters and "strength" in parameters:
+ inputs["num_inference_steps"] = 4
+
+ inputs["output_type"] = "pt"
+ return inputs
+
+ def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
+ r"""Tests for IP-Adapter.
+
+ The following scenarios are tested:
+ - Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
+ - Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter.
+ - Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
+ - Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
+ """
+ # Raising the tolerance for this test when it's run on a CPU because we
+ # compare against static slices and that can be shaky (with a VVVV low probability).
+ expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
+
+ blocks = self.pipeline_blocks_class()
+ _ = blocks.sub_blocks.pop("ip_adapter")
+ pipe = blocks.init_pipeline(self.pretrained_model_name_or_path)
+ pipe.load_components(torch_dtype=torch.float32)
+ pipe = pipe.to(torch_device)
+
+ cross_attention_dim = pipe.unet.config.get("cross_attention_dim")
+
+ # forward pass without ip adapter
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs())
+ if expected_pipe_slice is None:
+ output_without_adapter = pipe(**inputs, output="images")
+ else:
+ output_without_adapter = expected_pipe_slice
+
+ # 1. Single IP-Adapter test cases
+ adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
+ pipe.unet._load_ip_adapter_weights(adapter_state_dict)
+
+ # forward pass with single ip adapter, but scale=0 which should have no effect
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs())
+ inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
+ inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
+ pipe.set_ip_adapter_scale(0.0)
+ output_without_adapter_scale = pipe(**inputs, output="images")
+ if expected_pipe_slice is not None:
+ output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
+
+ # forward pass with single ip adapter, but with scale of adapter weights
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs())
+ inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
+ inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
+ pipe.set_ip_adapter_scale(42.0)
+ output_with_adapter_scale = pipe(**inputs, output="images")
+ if expected_pipe_slice is not None:
+ output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
+
+ max_diff_without_adapter_scale = torch.abs(output_without_adapter_scale - output_without_adapter).max()
+ max_diff_with_adapter_scale = torch.abs(output_with_adapter_scale - output_without_adapter).max()
+
+ assert max_diff_without_adapter_scale < expected_max_diff, (
+ "Output without ip-adapter must be same as normal inference"
+ )
+ assert max_diff_with_adapter_scale > 1e-2, "Output with ip-adapter must be different from normal inference"
+
+ # 2. Multi IP-Adapter test cases
+ adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet)
+ adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet)
+ pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
+
+ # forward pass with multi ip adapter, but scale=0 which should have no effect
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs())
+ inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
+ inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
+ pipe.set_ip_adapter_scale([0.0, 0.0])
+ output_without_multi_adapter_scale = pipe(**inputs, output="images")
+ if expected_pipe_slice is not None:
+ output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten()
+
+ # forward pass with multi ip adapter, but with scale of adapter weights
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs())
+ inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
+ inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
+ pipe.set_ip_adapter_scale([42.0, 42.0])
+ output_with_multi_adapter_scale = pipe(**inputs, output="images")
+ if expected_pipe_slice is not None:
+ output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten()
+
+ max_diff_without_multi_adapter_scale = torch.abs(
+ output_without_multi_adapter_scale - output_without_adapter
+ ).max()
+ max_diff_with_multi_adapter_scale = torch.abs(output_with_multi_adapter_scale - output_without_adapter).max()
+ assert max_diff_without_multi_adapter_scale < expected_max_diff, (
+ "Output without multi-ip-adapter must be same as normal inference"
+ )
+ assert max_diff_with_multi_adapter_scale > 1e-2, (
+ "Output with multi-ip-adapter scale must be different from normal inference"
+ )
+
+
+class SDXLModularControlNetTesterMixin:
+ """
+ This mixin is designed to test ControlNet.
+ """
+
+ def test_pipeline_inputs(self):
+ blocks = self.pipeline_blocks_class()
+ parameters = blocks.input_names
+
+ assert "control_image" in parameters, "`control_image` argument must be supported by the `__call__` method"
+ assert "controlnet_conditioning_scale" in parameters, (
+ "`controlnet_conditioning_scale` argument must be supported by the `__call__` method"
+ )
+
+ def _modify_inputs_for_controlnet_test(self, inputs: Dict[str, Any]):
+ controlnet_embedder_scale_factor = 2
+ image = torch.randn(
+ (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
+ device=torch_device,
+ )
+ inputs["control_image"] = image
+ return inputs
+
+ def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
+ r"""Tests for ControlNet.
+
+ The following scenarios are tested:
+ - Single ControlNet with scale=0 should produce same output as no ControlNet.
+ - Single ControlNet with scale!=0 should produce different output compared to no ControlNet.
+ """
+ # Raising the tolerance for this test when it's run on a CPU because we
+ # compare against static slices and that can be shaky (with a VVVV low probability).
+ expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
+
+ pipe = self.get_pipeline().to(torch_device)
+
+ # forward pass without controlnet
+ inputs = self.get_dummy_inputs()
+ output_without_controlnet = pipe(**inputs, output="images")
+ output_without_controlnet = output_without_controlnet[0, -3:, -3:, -1].flatten()
+
+ # forward pass with single controlnet, but scale=0 which should have no effect
+ inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs())
+ inputs["controlnet_conditioning_scale"] = 0.0
+ output_without_controlnet_scale = pipe(**inputs, output="images")
+ output_without_controlnet_scale = output_without_controlnet_scale[0, -3:, -3:, -1].flatten()
+
+ # forward pass with single controlnet, but with scale of adapter weights
+ inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs())
+ inputs["controlnet_conditioning_scale"] = 42.0
+ output_with_controlnet_scale = pipe(**inputs, output="images")
+ output_with_controlnet_scale = output_with_controlnet_scale[0, -3:, -3:, -1].flatten()
+
+ max_diff_without_controlnet_scale = torch.abs(
+ output_without_controlnet_scale - output_without_controlnet
+ ).max()
+ max_diff_with_controlnet_scale = torch.abs(output_with_controlnet_scale - output_without_controlnet).max()
+
+ assert max_diff_without_controlnet_scale < expected_max_diff, (
+ "Output without controlnet must be same as normal inference"
+ )
+ assert max_diff_with_controlnet_scale > 1e-2, "Output with controlnet must be different from normal inference"
+
+ def test_controlnet_cfg(self):
+ pipe = self.get_pipeline().to(torch_device)
+
+ # forward pass with CFG not applied
+ guider = ClassifierFreeGuidance(guidance_scale=1.0)
+ pipe.update_components(guider=guider)
+
+ inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs())
+ out_no_cfg = pipe(**inputs, output="images")
+
+ # forward pass with CFG applied
+ guider = ClassifierFreeGuidance(guidance_scale=7.5)
+ pipe.update_components(guider=guider)
+ inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs())
+ out_cfg = pipe(**inputs, output="images")
+
+ assert out_cfg.shape == out_no_cfg.shape
+ max_diff = torch.abs(out_cfg - out_no_cfg).max()
+ assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
+
+
+class TestSDXLModularPipelineFast(
+ SDXLModularTesterMixin,
+ SDXLModularIPAdapterTesterMixin,
+ SDXLModularControlNetTesterMixin,
+ ModularGuiderTesterMixin,
+ ModularPipelineTesterMixin,
+):
+ """Test cases for Stable Diffusion XL modular pipeline fast tests."""
+
+ pipeline_class = StableDiffusionXLModularPipeline
+ pipeline_blocks_class = StableDiffusionXLAutoBlocks
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
+ params = frozenset(
+ [
+ "prompt",
+ "height",
+ "width",
+ "negative_prompt",
+ "cross_attention_kwargs",
+ ]
+ )
+ batch_params = frozenset(["prompt", "negative_prompt"])
+ expected_image_output_shape = (1, 3, 64, 64)
+
+ def get_dummy_inputs(self, seed=0):
+ generator = self.get_generator(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_stable_diffusion_xl_euler(self):
+ self._test_stable_diffusion_xl_euler(
+ expected_image_shape=self.expected_image_output_shape,
+ expected_slice=torch.tensor(
+ [0.3886, 0.4685, 0.4953, 0.4217, 0.4317, 0.3945, 0.4847, 0.4704, 0.4731],
+ ),
+ expected_max_diff=1e-2,
+ )
+
+ def test_inference_batch_single_identical(self):
+ super().test_inference_batch_single_identical(expected_max_diff=3e-3)
+
+
+class TestSDXLImg2ImgModularPipelineFast(
+ SDXLModularTesterMixin,
+ SDXLModularIPAdapterTesterMixin,
+ SDXLModularControlNetTesterMixin,
+ ModularGuiderTesterMixin,
+ ModularPipelineTesterMixin,
+):
+ """Test cases for Stable Diffusion XL image-to-image modular pipeline fast tests."""
+
+ pipeline_class = StableDiffusionXLModularPipeline
+ pipeline_blocks_class = StableDiffusionXLAutoBlocks
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
+ params = frozenset(
+ [
+ "prompt",
+ "height",
+ "width",
+ "negative_prompt",
+ "cross_attention_kwargs",
+ "image",
+ ]
+ )
+ batch_params = frozenset(["prompt", "negative_prompt", "image"])
+ expected_image_output_shape = (1, 3, 64, 64)
+
+ def get_dummy_inputs(self, seed=0):
+ generator = self.get_generator(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 4,
+ "output_type": "pt",
+ }
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(torch_device)
+ image = image.cpu().permute(0, 2, 3, 1)[0]
+ init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
+
+ inputs["image"] = init_image
+ inputs["strength"] = 0.5
+
+ return inputs
+
+ def test_stable_diffusion_xl_euler(self):
+ self._test_stable_diffusion_xl_euler(
+ expected_image_shape=self.expected_image_output_shape,
+ expected_slice=torch.tensor([0.5246, 0.4466, 0.444, 0.3246, 0.4443, 0.5108, 0.5225, 0.559, 0.5147]),
+ expected_max_diff=1e-2,
+ )
+
+ def test_inference_batch_single_identical(self):
+ super().test_inference_batch_single_identical(expected_max_diff=3e-3)
+
+
+class SDXLInpaintingModularPipelineFastTests(
+ SDXLModularTesterMixin,
+ SDXLModularIPAdapterTesterMixin,
+ SDXLModularControlNetTesterMixin,
+ ModularGuiderTesterMixin,
+ ModularPipelineTesterMixin,
+):
+ """Test cases for Stable Diffusion XL inpainting modular pipeline fast tests."""
+
+ pipeline_class = StableDiffusionXLModularPipeline
+ pipeline_blocks_class = StableDiffusionXLAutoBlocks
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
+ params = frozenset(
+ [
+ "prompt",
+ "height",
+ "width",
+ "negative_prompt",
+ "cross_attention_kwargs",
+ "image",
+ "mask_image",
+ ]
+ )
+ batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
+ expected_image_output_shape = (1, 3, 64, 64)
+
+ def get_dummy_inputs(self, device, seed=0):
+ generator = self.get_generator(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 4,
+ "output_type": "pt",
+ }
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+ image = image.cpu().permute(0, 2, 3, 1)[0]
+ init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
+
+ # create mask
+ image[8:, 8:, :] = 255
+ mask_image = Image.fromarray(np.uint8(image)).convert("L").resize((64, 64))
+
+ inputs["image"] = init_image
+ inputs["mask_image"] = mask_image
+ inputs["strength"] = 1.0
+
+ return inputs
+
+ def test_stable_diffusion_xl_euler(self):
+ self._test_stable_diffusion_xl_euler(
+ expected_image_shape=self.expected_image_output_shape,
+ expected_slice=torch.tensor(
+ [
+ 0.40872607,
+ 0.38842705,
+ 0.34893104,
+ 0.47837183,
+ 0.43792963,
+ 0.5332134,
+ 0.3716843,
+ 0.47274873,
+ 0.45000193,
+ ],
+ device=torch_device,
+ ),
+ expected_max_diff=1e-2,
+ )
+
+ def test_inference_batch_single_identical(self):
+ super().test_inference_batch_single_identical(expected_max_diff=3e-3)
diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py
new file mode 100644
index 000000000000..a33951dac538
--- /dev/null
+++ b/tests/modular_pipelines/test_modular_pipelines_common.py
@@ -0,0 +1,338 @@
+import gc
+import tempfile
+from typing import Callable, Union
+
+import pytest
+import torch
+
+import diffusers
+from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
+from diffusers.guiders import ClassifierFreeGuidance
+from diffusers.utils import logging
+
+from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, torch_device
+
+
+class ModularPipelineTesterMixin:
+ """
+ It provides a set of common tests for each modular pipeline,
+ including:
+ - test_pipeline_call_signature: check if the pipeline's __call__ method has all required parameters
+ - test_inference_batch_consistent: check if the pipeline's __call__ method can handle batch inputs
+ - test_inference_batch_single_identical: check if the pipeline's __call__ method can handle single input
+ - test_float16_inference: check if the pipeline's __call__ method can handle float16 inputs
+ - test_to_device: check if the pipeline's __call__ method can handle different devices
+ """
+
+ # Canonical parameters that are passed to `__call__` regardless
+ # of the type of pipeline. They are always optional and have common
+ # sense default values.
+ optional_params = frozenset(["num_inference_steps", "num_images_per_prompt", "latents", "output_type"])
+ # this is modular specific: generator needs to be a intermediate input because it's mutable
+ intermediate_params = frozenset(["generator"])
+
+ def get_generator(self, seed=0):
+ generator = torch.Generator("cpu").manual_seed(seed)
+ return generator
+
+ @property
+ def pipeline_class(self) -> Union[Callable, ModularPipeline]:
+ raise NotImplementedError(
+ "You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. "
+ "See existing pipeline tests for reference."
+ )
+
+ @property
+ def pretrained_model_name_or_path(self) -> str:
+ raise NotImplementedError(
+ "You need to set the attribute `pretrained_model_name_or_path` in the child test class. See existing pipeline tests for reference."
+ )
+
+ @property
+ def pipeline_blocks_class(self) -> Union[Callable, ModularPipelineBlocks]:
+ raise NotImplementedError(
+ "You need to set the attribute `pipeline_blocks_class = ClassNameOfPipelineBlocks` in the child test class. "
+ "See existing pipeline tests for reference."
+ )
+
+ def get_dummy_inputs(self, seed=0):
+ raise NotImplementedError(
+ "You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
+ "See existing pipeline tests for reference."
+ )
+
+ @property
+ def params(self) -> frozenset:
+ raise NotImplementedError(
+ "You need to set the attribute `params` in the child test class. "
+ "`params` are checked for if all values are present in `__call__`'s signature."
+ " You can set `params` using one of the common set of parameters defined in `pipeline_params.py`"
+ " e.g., `TEXT_TO_IMAGE_PARAMS` defines the common parameters used in text to "
+ "image pipelines, including prompts and prompt embedding overrides."
+ "If your pipeline's set of arguments has minor changes from one of the common sets of arguments, "
+ "do not make modifications to the existing common sets of arguments. I.e. a text to image pipeline "
+ "with non-configurable height and width arguments should set the attribute as "
+ "`params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. "
+ "See existing pipeline tests for reference."
+ )
+
+ @property
+ def batch_params(self) -> frozenset:
+ raise NotImplementedError(
+ "You need to set the attribute `batch_params` in the child test class. "
+ "`batch_params` are the parameters required to be batched when passed to the pipeline's "
+ "`__call__` method. `pipeline_params.py` provides some common sets of parameters such as "
+ "`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc... If your pipeline's "
+ "set of batch arguments has minor changes from one of the common sets of batch arguments, "
+ "do not make modifications to the existing common sets of batch arguments. I.e. a text to "
+ "image pipeline `negative_prompt` is not batched should set the attribute as "
+ "`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. "
+ "See existing pipeline tests for reference."
+ )
+
+ def setup_method(self):
+ # clean up the VRAM before each test
+ torch.compiler.reset()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def teardown_method(self):
+ # clean up the VRAM after each test in case of CUDA runtime errors
+ torch.compiler.reset()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
+ pipeline = self.pipeline_blocks_class().init_pipeline(
+ self.pretrained_model_name_or_path, components_manager=components_manager
+ )
+ pipeline.load_components(torch_dtype=torch_dtype)
+ pipeline.set_progress_bar_config(disable=None)
+ return pipeline
+
+ def test_pipeline_call_signature(self):
+ pipe = self.get_pipeline()
+ input_parameters = pipe.blocks.input_names
+ optional_parameters = pipe.default_call_parameters
+
+ def _check_for_parameters(parameters, expected_parameters, param_type):
+ remaining_parameters = {param for param in parameters if param not in expected_parameters}
+ assert len(remaining_parameters) == 0, (
+ f"Required {param_type} parameters not present: {remaining_parameters}"
+ )
+
+ _check_for_parameters(self.params, input_parameters, "input")
+ _check_for_parameters(self.optional_params, optional_parameters, "optional")
+
+ def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
+ pipe = self.get_pipeline().to(torch_device)
+
+ inputs = self.get_dummy_inputs()
+ inputs["generator"] = self.get_generator(0)
+
+ logger = logging.get_logger(pipe.__module__)
+ logger.setLevel(level=diffusers.logging.FATAL)
+
+ # prepare batched inputs
+ batched_inputs = []
+ for batch_size in batch_sizes:
+ batched_input = {}
+ batched_input.update(inputs)
+
+ for name in self.batch_params:
+ if name not in inputs:
+ continue
+
+ value = inputs[name]
+ batched_input[name] = batch_size * [value]
+
+ if batch_generator and "generator" in inputs:
+ batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)]
+
+ if "batch_size" in inputs:
+ batched_input["batch_size"] = batch_size
+
+ batched_inputs.append(batched_input)
+
+ logger.setLevel(level=diffusers.logging.WARNING)
+ for batch_size, batched_input in zip(batch_sizes, batched_inputs):
+ output = pipe(**batched_input, output="images")
+ assert len(output) == batch_size, "Output is different from expected batch size"
+
+ def test_inference_batch_single_identical(
+ self,
+ batch_size=2,
+ expected_max_diff=1e-4,
+ ):
+ pipe = self.get_pipeline().to(torch_device)
+
+ inputs = self.get_dummy_inputs()
+
+ # Reset generator in case it is has been used in self.get_dummy_inputs
+ inputs["generator"] = self.get_generator(0)
+
+ logger = logging.get_logger(pipe.__module__)
+ logger.setLevel(level=diffusers.logging.FATAL)
+
+ # batchify inputs
+ batched_inputs = {}
+ batched_inputs.update(inputs)
+
+ for name in self.batch_params:
+ if name not in inputs:
+ continue
+
+ value = inputs[name]
+ batched_inputs[name] = batch_size * [value]
+
+ if "generator" in inputs:
+ batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
+
+ if "batch_size" in inputs:
+ batched_inputs["batch_size"] = batch_size
+
+ output = pipe(**inputs, output="images")
+ output_batch = pipe(**batched_inputs, output="images")
+
+ assert output_batch.shape[0] == batch_size
+
+ max_diff = torch.abs(output_batch[0] - output[0]).max()
+ assert max_diff < expected_max_diff, "Batch inference results different from single inference results"
+
+ @require_accelerator
+ def test_float16_inference(self, expected_max_diff=5e-2):
+ pipe = self.get_pipeline()
+ pipe.to(torch_device, torch.float32)
+
+ pipe_fp16 = self.get_pipeline()
+ pipe_fp16.to(torch_device, torch.float16)
+
+ inputs = self.get_dummy_inputs()
+ # Reset generator in case it is used inside dummy inputs
+ if "generator" in inputs:
+ inputs["generator"] = self.get_generator(0)
+ output = pipe(**inputs, output="images")
+
+ fp16_inputs = self.get_dummy_inputs()
+ # Reset generator in case it is used inside dummy inputs
+ if "generator" in fp16_inputs:
+ fp16_inputs["generator"] = self.get_generator(0)
+ output_fp16 = pipe_fp16(**fp16_inputs, output="images")
+
+ output = output.cpu()
+ output_fp16 = output_fp16.cpu()
+
+ max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
+ assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference"
+
+ @require_accelerator
+ def test_to_device(self):
+ pipe = self.get_pipeline().to("cpu")
+
+ model_devices = [
+ component.device.type for component in pipe.components.values() if hasattr(component, "device")
+ ]
+ assert all(device == "cpu" for device in model_devices), "All pipeline components are not on CPU"
+
+ pipe.to(torch_device)
+ model_devices = [
+ component.device.type for component in pipe.components.values() if hasattr(component, "device")
+ ]
+ assert all(device == torch_device for device in model_devices), (
+ "All pipeline components are not on accelerator device"
+ )
+
+ def test_inference_is_not_nan_cpu(self):
+ pipe = self.get_pipeline().to("cpu")
+
+ output = pipe(**self.get_dummy_inputs(), output="images")
+ assert torch.isnan(output).sum() == 0, "CPU Inference returns NaN"
+
+ @require_accelerator
+ def test_inference_is_not_nan(self):
+ pipe = self.get_pipeline().to(torch_device)
+
+ output = pipe(**self.get_dummy_inputs(), output="images")
+ assert torch.isnan(output).sum() == 0, "Accelerator Inference returns NaN"
+
+ def test_num_images_per_prompt(self):
+ pipe = self.get_pipeline().to(torch_device)
+
+ if "num_images_per_prompt" not in pipe.blocks.input_names:
+ pytest.mark.skip("Skipping test as `num_images_per_prompt` is not present in input names.")
+
+ batch_sizes = [1, 2]
+ num_images_per_prompts = [1, 2]
+
+ for batch_size in batch_sizes:
+ for num_images_per_prompt in num_images_per_prompts:
+ inputs = self.get_dummy_inputs()
+
+ for key in inputs.keys():
+ if key in self.batch_params:
+ inputs[key] = batch_size * [inputs[key]]
+
+ images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")
+
+ assert images.shape[0] == batch_size * num_images_per_prompt
+
+ @require_accelerator
+ def test_components_auto_cpu_offload_inference_consistent(self):
+ base_pipe = self.get_pipeline().to(torch_device)
+
+ cm = ComponentsManager()
+ cm.enable_auto_cpu_offload(device=torch_device)
+ offload_pipe = self.get_pipeline(components_manager=cm)
+
+ image_slices = []
+ for pipe in [base_pipe, offload_pipe]:
+ inputs = self.get_dummy_inputs()
+ image = pipe(**inputs, output="images")
+
+ image_slices.append(image[0, -3:, -3:, -1].flatten())
+
+ assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+
+ def test_save_from_pretrained(self):
+ pipes = []
+ base_pipe = self.get_pipeline().to(torch_device)
+ pipes.append(base_pipe)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ base_pipe.save_pretrained(tmpdirname)
+ pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
+ pipe.load_components(torch_dtype=torch.float32)
+ pipe.to(torch_device)
+
+ pipes.append(pipe)
+
+ image_slices = []
+ for pipe in pipes:
+ inputs = self.get_dummy_inputs()
+ image = pipe(**inputs, output="images")
+
+ image_slices.append(image[0, -3:, -3:, -1].flatten())
+
+ assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+
+
+class ModularGuiderTesterMixin:
+ def test_guider_cfg(self, expected_max_diff=1e-2):
+ pipe = self.get_pipeline().to(torch_device)
+
+ # forward pass with CFG not applied
+ guider = ClassifierFreeGuidance(guidance_scale=1.0)
+ pipe.update_components(guider=guider)
+
+ inputs = self.get_dummy_inputs()
+ out_no_cfg = pipe(**inputs, output="images")
+
+ # forward pass with CFG applied
+ guider = ClassifierFreeGuidance(guidance_scale=7.5)
+ pipe.update_components(guider=guider)
+ inputs = self.get_dummy_inputs()
+ out_cfg = pipe(**inputs, output="images")
+
+ assert out_cfg.shape == out_no_cfg.shape
+ max_diff = torch.abs(out_cfg - out_no_cfg).max()
+ assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
diff --git a/tests/pipelines/paint_by_example/__init__.py b/tests/others/__init__.py
similarity index 100%
rename from tests/pipelines/paint_by_example/__init__.py
rename to tests/others/__init__.py
diff --git a/tests/others/test_attention_backends.py b/tests/others/test_attention_backends.py
new file mode 100644
index 000000000000..01f4521c5adc
--- /dev/null
+++ b/tests/others/test_attention_backends.py
@@ -0,0 +1,163 @@
+"""
+This test suite exists for the maintainers currently. It's not run in our CI at the moment.
+
+Once attention backends become more mature, we can consider including this in our CI.
+
+To run this test suite:
+
+```bash
+export RUN_ATTENTION_BACKEND_TESTS=yes
+
+pytest tests/others/test_attention_backends.py
+```
+
+Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
+"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
+
+Tests for aiter backend were conducted and slices for the aiter backend tests collected on a MI355X
+with torch 2025-09-25 nightly version (ad2f7315ca66b42497047bb7951f696b50f1e81b) and
+aiter 0.1.5.post4.dev20+ga25e55e79.
+"""
+
+import os
+
+import pytest
+import torch
+
+
+pytestmark = pytest.mark.skipif(
+ os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "false", reason="Feature not mature enough."
+)
+from diffusers import FluxPipeline # noqa: E402
+from diffusers.utils import is_torch_version # noqa: E402
+
+
+# fmt: off
+FORWARD_CASES = [
+ (
+ "flash_hub",
+ torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16)
+ ),
+ (
+ "_flash_3_hub",
+ torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
+ ),
+ (
+ "native",
+ torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16)
+ ),
+ (
+ "_native_cudnn",
+ torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
+ ),
+ (
+ "aiter",
+ torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16),
+ )
+]
+
+COMPILE_CASES = [
+ (
+ "flash_hub",
+ torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
+ True
+ ),
+ (
+ "_flash_3_hub",
+ torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
+ True,
+ ),
+ (
+ "native",
+ torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066], dtype=torch.bfloat16),
+ True,
+ ),
+ (
+ "_native_cudnn",
+ torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
+ True,
+ ),
+ (
+ "aiter",
+ torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16),
+ True,
+ )
+]
+# fmt: on
+
+INFER_KW = {
+ "prompt": "dance doggo dance",
+ "height": 256,
+ "width": 256,
+ "num_inference_steps": 2,
+ "guidance_scale": 3.5,
+ "max_sequence_length": 128,
+ "output_type": "pt",
+}
+
+
+def _backend_is_probably_supported(pipe, name: str):
+ try:
+ pipe.transformer.set_attention_backend(name)
+ return pipe, True
+ except Exception:
+ return False
+
+
+def _check_if_slices_match(output, expected_slice):
+ img = output.images.detach().cpu()
+ generated_slice = img.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
+
+
+@pytest.fixture(scope="session")
+def device():
+ if not torch.cuda.is_available():
+ pytest.skip("CUDA is required for these tests.")
+ return torch.device("cuda:0")
+
+
+@pytest.fixture(scope="session")
+def pipe(device):
+ repo_id = "black-forest-labs/FLUX.1-dev"
+ pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16).to(device)
+ pipe.set_progress_bar_config(disable=True)
+ return pipe
+
+
+@pytest.mark.parametrize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
+def test_forward(pipe, backend_name, expected_slice):
+ out = _backend_is_probably_supported(pipe, backend_name)
+ if isinstance(out, bool):
+ pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
+
+ modified_pipe = out[0]
+ out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
+ _check_if_slices_match(out, expected_slice)
+
+
+@pytest.mark.parametrize(
+ "backend_name,expected_slice,error_on_recompile",
+ COMPILE_CASES,
+ ids=[c[0] for c in COMPILE_CASES],
+)
+def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
+ if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"):
+ pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.")
+
+ out = _backend_is_probably_supported(pipe, backend_name)
+ if isinstance(out, bool):
+ pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
+
+ modified_pipe = out[0]
+ modified_pipe.transformer.compile(fullgraph=True)
+
+ torch.compiler.reset()
+ with (
+ torch._inductor.utils.fresh_inductor_cache(),
+ torch._dynamo.config.patch(error_on_recompile=error_on_recompile),
+ ):
+ out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
+
+ _check_if_slices_match(out, expected_slice)
diff --git a/tests/others/test_check_copies.py b/tests/others/test_check_copies.py
index 5835712343c7..4b6fa28eb9ac 100644
--- a/tests/others/test_check_copies.py
+++ b/tests/others/test_check_copies.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/others/test_check_dummies.py b/tests/others/test_check_dummies.py
index 1890ffaecd8d..b7c544370ca8 100644
--- a/tests/others/test_check_dummies.py
+++ b/tests/others/test_check_dummies.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/others/test_config.py b/tests/others/test_config.py
index 664c36ac33d6..232bf9d473b8 100644
--- a/tests/others/test_config.py
+++ b/tests/others/test_config.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -28,7 +28,8 @@
logging,
)
from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.utils.testing_utils import CaptureLogger
+
+from ..testing_utils import CaptureLogger
class SampleObject(ConfigMixin):
diff --git a/tests/others/test_dependencies.py b/tests/others/test_dependencies.py
index c0839ef0236b..db22f10c4b3c 100644
--- a/tests/others/test_dependencies.py
+++ b/tests/others/test_dependencies.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -37,6 +37,10 @@ def test_backend_registration(self):
backend = "k-diffusion"
elif backend == "invisible_watermark":
backend = "invisible-watermark"
+ elif backend == "opencv":
+ backend = "opencv-python"
+ elif backend == "nvidia_modelopt":
+ backend = "nvidia_modelopt[hf]"
assert backend in deps, f"{backend} is not in the deps table!"
def test_pipeline_imports(self):
diff --git a/tests/others/test_ema.py b/tests/others/test_ema.py
index 7cf8f30ecc44..436bbe1d53ff 100644
--- a/tests/others/test_ema.py
+++ b/tests/others/test_ema.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,7 +20,8 @@
from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel
-from diffusers.utils.testing_utils import enable_full_determinism, skip_mps, torch_device
+
+from ..testing_utils import enable_full_determinism, skip_mps, torch_device
enable_full_determinism()
diff --git a/tests/others/test_hub_utils.py b/tests/others/test_hub_utils.py
index 7a0c29dcb0f3..0a6b8ef2bd9f 100644
--- a/tests/others/test_hub_utils.py
+++ b/tests/others/test_hub_utils.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/others/test_image_processor.py b/tests/others/test_image_processor.py
index 3397ca9e394a..e9e5c0670676 100644
--- a/tests/others/test_image_processor.py
+++ b/tests/others/test_image_processor.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -65,9 +65,9 @@ def test_vae_image_processor_pt(self):
)
out_np = self.to_np(out)
in_np = (input_np * 255).round() if output_type == "pil" else input_np
- assert (
- np.abs(in_np - out_np).max() < 1e-6
- ), f"decoded output does not match input for output_type {output_type}"
+ assert np.abs(in_np - out_np).max() < 1e-6, (
+ f"decoded output does not match input for output_type {output_type}"
+ )
def test_vae_image_processor_np(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
@@ -78,9 +78,9 @@ def test_vae_image_processor_np(self):
out_np = self.to_np(out)
in_np = (input_np * 255).round() if output_type == "pil" else input_np
- assert (
- np.abs(in_np - out_np).max() < 1e-6
- ), f"decoded output does not match input for output_type {output_type}"
+ assert np.abs(in_np - out_np).max() < 1e-6, (
+ f"decoded output does not match input for output_type {output_type}"
+ )
def test_vae_image_processor_pil(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
@@ -93,9 +93,9 @@ def test_vae_image_processor_pil(self):
for i, o in zip(input_pil, out):
in_np = np.array(i)
out_np = self.to_np(out) if output_type == "pil" else (self.to_np(out) * 255).round()
- assert (
- np.abs(in_np - out_np).max() < 1e-6
- ), f"decoded output does not match input for output_type {output_type}"
+ assert np.abs(in_np - out_np).max() < 1e-6, (
+ f"decoded output does not match input for output_type {output_type}"
+ )
def test_preprocess_input_3d(self):
image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
@@ -293,9 +293,9 @@ def test_vae_image_processor_resize_pt(self):
scale = 2
out_pt = image_processor.resize(image=input_pt, height=h // scale, width=w // scale)
exp_pt_shape = (b, c, h // scale, w // scale)
- assert (
- out_pt.shape == exp_pt_shape
- ), f"resized image output shape '{out_pt.shape}' didn't match expected shape '{exp_pt_shape}'."
+ assert out_pt.shape == exp_pt_shape, (
+ f"resized image output shape '{out_pt.shape}' didn't match expected shape '{exp_pt_shape}'."
+ )
def test_vae_image_processor_resize_np(self):
image_processor = VaeImageProcessor(do_resize=True, vae_scale_factor=1)
@@ -305,6 +305,6 @@ def test_vae_image_processor_resize_np(self):
input_np = self.to_np(input_pt)
out_np = image_processor.resize(image=input_np, height=h // scale, width=w // scale)
exp_np_shape = (b, h // scale, w // scale, c)
- assert (
- out_np.shape == exp_np_shape
- ), f"resized image output shape '{out_np.shape}' didn't match expected shape '{exp_np_shape}'."
+ assert out_np.shape == exp_np_shape, (
+ f"resized image output shape '{out_np.shape}' didn't match expected shape '{exp_np_shape}'."
+ )
diff --git a/tests/others/test_outputs.py b/tests/others/test_outputs.py
index cf709d93f709..c8069e6916ed 100644
--- a/tests/others/test_outputs.py
+++ b/tests/others/test_outputs.py
@@ -7,7 +7,8 @@
import PIL.Image
from diffusers.utils.outputs import BaseOutput
-from diffusers.utils.testing_utils import require_torch
+
+from ..testing_utils import require_torch
@dataclass
diff --git a/tests/others/test_training.py b/tests/others/test_training.py
index 863ba6e10798..2038a98a813e 100644
--- a/tests/others/test_training.py
+++ b/tests/others/test_training.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,7 +19,8 @@
from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel
from diffusers.training_utils import set_seed
-from diffusers.utils.testing_utils import slow
+
+from ..testing_utils import slow
torch.backends.cuda.matmul.allow_tf32 = False
diff --git a/tests/others/test_utils.py b/tests/others/test_utils.py
index 65c0b3ece4d2..747b8d584058 100755
--- a/tests/others/test_utils.py
+++ b/tests/others/test_utils.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,7 +20,8 @@
from diffusers import __version__
from diffusers.utils import deprecate
-from diffusers.utils.testing_utils import str_to_bool
+
+from ..testing_utils import Expectations, str_to_bool
# Used to test the hub
@@ -182,6 +183,38 @@ def test_deprecate_stacklevel(self):
assert "diffusers/tests/others/test_utils.py" in warning.filename
+# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py
+class ExpectationsTester(unittest.TestCase):
+ def test_expectations(self):
+ expectations = Expectations(
+ {
+ (None, None): 1,
+ ("cuda", 8): 2,
+ ("cuda", 7): 3,
+ ("rocm", 8): 4,
+ ("rocm", None): 5,
+ ("cpu", None): 6,
+ ("xpu", 3): 7,
+ }
+ )
+
+ def check(value, key):
+ assert expectations.find_expectation(key) == value
+
+ # npu has no matches so should find default expectation
+ check(1, ("npu", None))
+ check(7, ("xpu", 3))
+ check(2, ("cuda", 8))
+ check(3, ("cuda", 7))
+ check(4, ("rocm", 9))
+ check(4, ("rocm", None))
+ check(2, ("cuda", 2))
+
+ expectations = Expectations({("cuda", 8): 1})
+ with self.assertRaises(ValueError):
+ expectations.find_expectation(("xpu", None))
+
+
def parse_flag_from_env(key, default=False):
try:
value = os.environ[key]
diff --git a/tests/others/test_video_processor.py b/tests/others/test_video_processor.py
index a2fc87717f9d..35c9f99c37ae 100644
--- a/tests/others/test_video_processor.py
+++ b/tests/others/test_video_processor.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/pipelines/allegro/test_allegro.py b/tests/pipelines/allegro/test_allegro.py
index 30fdd68cfd36..b2e588de0647 100644
--- a/tests/pipelines/allegro/test_allegro.py
+++ b/tests/pipelines/allegro/test_allegro.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
+# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -23,7 +23,9 @@
from transformers import AutoTokenizer, T5Config, T5EncoderModel
from diffusers import AllegroPipeline, AllegroTransformer3DModel, AutoencoderKLAllegro, DDIMScheduler
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
require_hf_hub_version_greater,
@@ -32,7 +34,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
@@ -341,12 +342,12 @@ class AllegroPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_allegro(self):
generator = torch.Generator("cpu").manual_seed(0)
diff --git a/tests/pipelines/amused/test_amused.py b/tests/pipelines/amused/test_amused.py
deleted file mode 100644
index a0fbc5df1c28..000000000000
--- a/tests/pipelines/amused/test_amused.py
+++ /dev/null
@@ -1,172 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
-
-from diffusers import AmusedPipeline, AmusedScheduler, UVit2DModel, VQModel
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- require_torch_accelerator,
- slow,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class AmusedPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = AmusedPipeline
- params = TEXT_TO_IMAGE_PARAMS | {"encoder_hidden_states", "negative_encoder_hidden_states"}
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- test_layerwise_casting = True
- test_group_offloading = True
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- transformer = UVit2DModel(
- hidden_size=8,
- use_bias=False,
- hidden_dropout=0.0,
- cond_embed_dim=8,
- micro_cond_encode_dim=2,
- micro_cond_embed_dim=10,
- encoder_hidden_size=8,
- vocab_size=32,
- codebook_size=8,
- in_channels=8,
- block_out_channels=8,
- num_res_blocks=1,
- downsample=True,
- upsample=True,
- block_num_heads=1,
- num_hidden_layers=1,
- num_attention_heads=1,
- attention_dropout=0.0,
- intermediate_size=8,
- layer_norm_eps=1e-06,
- ln_elementwise_affine=True,
- )
- scheduler = AmusedScheduler(mask_token_id=31)
- torch.manual_seed(0)
- vqvae = VQModel(
- act_fn="silu",
- block_out_channels=[8],
- down_block_types=["DownEncoderBlock2D"],
- in_channels=3,
- latent_channels=8,
- layers_per_block=1,
- norm_num_groups=8,
- num_vq_embeddings=8,
- out_channels=3,
- sample_size=8,
- up_block_types=["UpDecoderBlock2D"],
- mid_block_add_attention=False,
- lookup_from_codebook=True,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=8,
- intermediate_size=8,
- layer_norm_eps=1e-05,
- num_attention_heads=1,
- num_hidden_layers=1,
- pad_token_id=1,
- vocab_size=1000,
- projection_dim=8,
- )
- text_encoder = CLIPTextModelWithProjection(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- components = {
- "transformer": transformer,
- "scheduler": scheduler,
- "vqvae": vqvae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "output_type": "np",
- "height": 4,
- "width": 4,
- }
- return inputs
-
- def test_inference_batch_consistent(self, batch_sizes=[2]):
- self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
-
- @unittest.skip("aMUSEd does not support lists of generators")
- def test_inference_batch_single_identical(self):
- ...
-
-
-@slow
-@require_torch_accelerator
-class AmusedPipelineSlowTests(unittest.TestCase):
- def test_amused_256(self):
- pipe = AmusedPipeline.from_pretrained("amused/amused-256")
- pipe.to(torch_device)
- image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 256, 256, 3)
- expected_slice = np.array([0.4011, 0.3992, 0.379, 0.3856, 0.3772, 0.3711, 0.3919, 0.385, 0.3625])
- assert np.abs(image_slice - expected_slice).max() < 0.003
-
- def test_amused_256_fp16(self):
- pipe = AmusedPipeline.from_pretrained("amused/amused-256", variant="fp16", torch_dtype=torch.float16)
- pipe.to(torch_device)
- image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 256, 256, 3)
- expected_slice = np.array([0.0554, 0.05129, 0.0344, 0.0452, 0.0476, 0.0271, 0.0495, 0.0527, 0.0158])
- assert np.abs(image_slice - expected_slice).max() < 0.007
-
- def test_amused_512(self):
- pipe = AmusedPipeline.from_pretrained("amused/amused-512")
- pipe.to(torch_device)
- image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images
- image_slice = image[0, -3:, -3:, -1].flatten()
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.1199, 0.1171, 0.1229, 0.1188, 0.1210, 0.1147, 0.1260, 0.1346, 0.1152])
- assert np.abs(image_slice - expected_slice).max() < 0.003
-
- def test_amused_512_fp16(self):
- pipe = AmusedPipeline.from_pretrained("amused/amused-512", variant="fp16", torch_dtype=torch.float16)
- pipe.to(torch_device)
- image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images
- image_slice = image[0, -3:, -3:, -1].flatten()
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.1509, 0.1492, 0.1531, 0.1485, 0.1501, 0.1465, 0.1581, 0.1690, 0.1499])
- assert np.abs(image_slice - expected_slice).max() < 0.003
diff --git a/tests/pipelines/amused/test_amused_img2img.py b/tests/pipelines/amused/test_amused_img2img.py
deleted file mode 100644
index 2699bbe7f56f..000000000000
--- a/tests/pipelines/amused/test_amused_img2img.py
+++ /dev/null
@@ -1,216 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
-
-from diffusers import AmusedImg2ImgPipeline, AmusedScheduler, UVit2DModel, VQModel
-from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- require_torch_accelerator,
- slow,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class AmusedImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = AmusedImg2ImgPipeline
- params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "latents"}
- batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
- required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- transformer = UVit2DModel(
- hidden_size=8,
- use_bias=False,
- hidden_dropout=0.0,
- cond_embed_dim=8,
- micro_cond_encode_dim=2,
- micro_cond_embed_dim=10,
- encoder_hidden_size=8,
- vocab_size=32,
- codebook_size=8,
- in_channels=8,
- block_out_channels=8,
- num_res_blocks=1,
- downsample=True,
- upsample=True,
- block_num_heads=1,
- num_hidden_layers=1,
- num_attention_heads=1,
- attention_dropout=0.0,
- intermediate_size=8,
- layer_norm_eps=1e-06,
- ln_elementwise_affine=True,
- )
- scheduler = AmusedScheduler(mask_token_id=31)
- torch.manual_seed(0)
- vqvae = VQModel(
- act_fn="silu",
- block_out_channels=[8],
- down_block_types=["DownEncoderBlock2D"],
- in_channels=3,
- latent_channels=8,
- layers_per_block=1,
- norm_num_groups=8,
- num_vq_embeddings=32,
- out_channels=3,
- sample_size=8,
- up_block_types=["UpDecoderBlock2D"],
- mid_block_add_attention=False,
- lookup_from_codebook=True,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=8,
- intermediate_size=8,
- layer_norm_eps=1e-05,
- num_attention_heads=1,
- num_hidden_layers=1,
- pad_token_id=1,
- vocab_size=1000,
- projection_dim=8,
- )
- text_encoder = CLIPTextModelWithProjection(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- components = {
- "transformer": transformer,
- "scheduler": scheduler,
- "vqvae": vqvae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- image = torch.full((1, 3, 4, 4), 1.0, dtype=torch.float32, device=device)
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "output_type": "np",
- "image": image,
- }
- return inputs
-
- def test_inference_batch_consistent(self, batch_sizes=[2]):
- self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
-
- @unittest.skip("aMUSEd does not support lists of generators")
- def test_inference_batch_single_identical(self):
- ...
-
-
-@slow
-@require_torch_accelerator
-class AmusedImg2ImgPipelineSlowTests(unittest.TestCase):
- def test_amused_256(self):
- pipe = AmusedImg2ImgPipeline.from_pretrained("amused/amused-256")
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg")
- .resize((256, 256))
- .convert("RGB")
- )
- image = pipe(
- "winter mountains",
- image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 256, 256, 3)
- expected_slice = np.array([0.9993, 1.0, 0.9996, 1.0, 0.9995, 0.9925, 0.999, 0.9954, 1.0])
- assert np.abs(image_slice - expected_slice).max() < 0.01
-
- def test_amused_256_fp16(self):
- pipe = AmusedImg2ImgPipeline.from_pretrained("amused/amused-256", torch_dtype=torch.float16, variant="fp16")
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg")
- .resize((256, 256))
- .convert("RGB")
- )
- image = pipe(
- "winter mountains",
- image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 256, 256, 3)
- expected_slice = np.array([0.998, 0.998, 0.994, 0.9944, 0.996, 0.9908, 1.0, 1.0, 0.9986])
- assert np.abs(image_slice - expected_slice).max() < 0.01
-
- def test_amused_512(self):
- pipe = AmusedImg2ImgPipeline.from_pretrained("amused/amused-512")
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg")
- .resize((512, 512))
- .convert("RGB")
- )
- image = pipe(
- "winter mountains",
- image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.2809, 0.1879, 0.2027, 0.2418, 0.1852, 0.2145, 0.2484, 0.2425, 0.2317])
- assert np.abs(image_slice - expected_slice).max() < 0.1
-
- def test_amused_512_fp16(self):
- pipe = AmusedImg2ImgPipeline.from_pretrained("amused/amused-512", variant="fp16", torch_dtype=torch.float16)
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg")
- .resize((512, 512))
- .convert("RGB")
- )
- image = pipe(
- "winter mountains",
- image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.2795, 0.1867, 0.2028, 0.2450, 0.1856, 0.2140, 0.2473, 0.2406, 0.2313])
- assert np.abs(image_slice - expected_slice).max() < 0.1
diff --git a/tests/pipelines/amused/test_amused_inpaint.py b/tests/pipelines/amused/test_amused_inpaint.py
deleted file mode 100644
index 645379a7eab1..000000000000
--- a/tests/pipelines/amused/test_amused_inpaint.py
+++ /dev/null
@@ -1,251 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
-
-from diffusers import AmusedInpaintPipeline, AmusedScheduler, UVit2DModel, VQModel
-from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- require_torch_accelerator,
- slow,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class AmusedInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = AmusedInpaintPipeline
- params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"width", "height"}
- batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
- required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- transformer = UVit2DModel(
- hidden_size=8,
- use_bias=False,
- hidden_dropout=0.0,
- cond_embed_dim=8,
- micro_cond_encode_dim=2,
- micro_cond_embed_dim=10,
- encoder_hidden_size=8,
- vocab_size=32,
- codebook_size=32,
- in_channels=8,
- block_out_channels=8,
- num_res_blocks=1,
- downsample=True,
- upsample=True,
- block_num_heads=1,
- num_hidden_layers=1,
- num_attention_heads=1,
- attention_dropout=0.0,
- intermediate_size=8,
- layer_norm_eps=1e-06,
- ln_elementwise_affine=True,
- )
- scheduler = AmusedScheduler(mask_token_id=31)
- torch.manual_seed(0)
- vqvae = VQModel(
- act_fn="silu",
- block_out_channels=[8],
- down_block_types=["DownEncoderBlock2D"],
- in_channels=3,
- latent_channels=8,
- layers_per_block=1,
- norm_num_groups=8,
- num_vq_embeddings=32,
- out_channels=3,
- sample_size=8,
- up_block_types=["UpDecoderBlock2D"],
- mid_block_add_attention=False,
- lookup_from_codebook=True,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=8,
- intermediate_size=8,
- layer_norm_eps=1e-05,
- num_attention_heads=1,
- num_hidden_layers=1,
- pad_token_id=1,
- vocab_size=1000,
- projection_dim=8,
- )
- text_encoder = CLIPTextModelWithProjection(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- components = {
- "transformer": transformer,
- "scheduler": scheduler,
- "vqvae": vqvae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- image = torch.full((1, 3, 4, 4), 1.0, dtype=torch.float32, device=device)
- mask_image = torch.full((1, 1, 4, 4), 1.0, dtype=torch.float32, device=device)
- mask_image[0, 0, 0, 0] = 0
- mask_image[0, 0, 0, 1] = 0
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "output_type": "np",
- "image": image,
- "mask_image": mask_image,
- }
- return inputs
-
- def test_inference_batch_consistent(self, batch_sizes=[2]):
- self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
-
- @unittest.skip("aMUSEd does not support lists of generators")
- def test_inference_batch_single_identical(self):
- ...
-
-
-@slow
-@require_torch_accelerator
-class AmusedInpaintPipelineSlowTests(unittest.TestCase):
- def test_amused_256(self):
- pipe = AmusedInpaintPipeline.from_pretrained("amused/amused-256")
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg")
- .resize((256, 256))
- .convert("RGB")
- )
- mask_image = (
- load_image(
- "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png"
- )
- .resize((256, 256))
- .convert("L")
- )
- image = pipe(
- "winter mountains",
- image,
- mask_image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 256, 256, 3)
- expected_slice = np.array([0.0699, 0.0716, 0.0608, 0.0715, 0.0797, 0.0638, 0.0802, 0.0924, 0.0634])
- assert np.abs(image_slice - expected_slice).max() < 0.1
-
- def test_amused_256_fp16(self):
- pipe = AmusedInpaintPipeline.from_pretrained("amused/amused-256", variant="fp16", torch_dtype=torch.float16)
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg")
- .resize((256, 256))
- .convert("RGB")
- )
- mask_image = (
- load_image(
- "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png"
- )
- .resize((256, 256))
- .convert("L")
- )
- image = pipe(
- "winter mountains",
- image,
- mask_image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 256, 256, 3)
- expected_slice = np.array([0.0735, 0.0749, 0.065, 0.0739, 0.0805, 0.0667, 0.0802, 0.0923, 0.0622])
- assert np.abs(image_slice - expected_slice).max() < 0.1
-
- def test_amused_512(self):
- pipe = AmusedInpaintPipeline.from_pretrained("amused/amused-512")
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg")
- .resize((512, 512))
- .convert("RGB")
- )
- mask_image = (
- load_image(
- "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png"
- )
- .resize((512, 512))
- .convert("L")
- )
- image = pipe(
- "winter mountains",
- image,
- mask_image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0005, 0.0])
- assert np.abs(image_slice - expected_slice).max() < 0.05
-
- def test_amused_512_fp16(self):
- pipe = AmusedInpaintPipeline.from_pretrained("amused/amused-512", variant="fp16", torch_dtype=torch.float16)
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg")
- .resize((512, 512))
- .convert("RGB")
- )
- mask_image = (
- load_image(
- "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png"
- )
- .resize((512, 512))
- .convert("L")
- )
- image = pipe(
- "winter mountains",
- image,
- mask_image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.0227, 0.0157, 0.0098, 0.0213, 0.0250, 0.0127, 0.0280, 0.0380, 0.0095])
- assert np.abs(image_slice - expected_slice).max() < 0.003
diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py
index 4088d46df5b2..8d4cd4cf2c1a 100644
--- a/tests/pipelines/animatediff/test_animatediff.py
+++ b/tests/pipelines/animatediff/test_animatediff.py
@@ -19,7 +19,8 @@
)
from diffusers.models.attention import FreeNoiseTransformerBlock
from diffusers.utils import is_xformers_available, logging
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
numpy_cosine_similarity_distance,
require_accelerator,
@@ -27,7 +28,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
IPAdapterTesterMixin,
diff --git a/tests/pipelines/animatediff/test_animatediff_controlnet.py b/tests/pipelines/animatediff/test_animatediff_controlnet.py
index 7bde663b111e..4b0eb01d067c 100644
--- a/tests/pipelines/animatediff/test_animatediff_controlnet.py
+++ b/tests/pipelines/animatediff/test_animatediff_controlnet.py
@@ -21,8 +21,8 @@
from diffusers.models.attention import FreeNoiseTransformerBlock
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import require_accelerator, torch_device
+from ...testing_utils import require_accelerator, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
IPAdapterTesterMixin,
diff --git a/tests/pipelines/animatediff/test_animatediff_sdxl.py b/tests/pipelines/animatediff/test_animatediff_sdxl.py
index f9686ec005f7..b5dcd8779623 100644
--- a/tests/pipelines/animatediff/test_animatediff_sdxl.py
+++ b/tests/pipelines/animatediff/test_animatediff_sdxl.py
@@ -14,8 +14,8 @@
UNetMotionModel,
)
from diffusers.utils import is_xformers_available, logging
-from diffusers.utils.testing_utils import require_accelerator, torch_device
+from ...testing_utils import require_accelerator, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
IPAdapterTesterMixin,
diff --git a/tests/pipelines/animatediff/test_animatediff_sparsectrl.py b/tests/pipelines/animatediff/test_animatediff_sparsectrl.py
index 3e33326c8a87..6b9f672cc4a1 100644
--- a/tests/pipelines/animatediff/test_animatediff_sparsectrl.py
+++ b/tests/pipelines/animatediff/test_animatediff_sparsectrl.py
@@ -20,8 +20,8 @@
)
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import require_accelerator, torch_device
+from ...testing_utils import require_accelerator, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
IPAdapterTesterMixin,
diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py
index bc771e148eb2..1adb13dc4cc5 100644
--- a/tests/pipelines/animatediff/test_animatediff_video2video.py
+++ b/tests/pipelines/animatediff/test_animatediff_video2video.py
@@ -19,8 +19,8 @@
)
from diffusers.models.attention import FreeNoiseTransformerBlock
from diffusers.utils import is_xformers_available, logging
-from diffusers.utils.testing_utils import require_accelerator, torch_device
+from ...testing_utils import require_accelerator, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_PARAMS, VIDEO_TO_VIDEO_BATCH_PARAMS
from ..test_pipelines_common import IPAdapterTesterMixin, PipelineFromPipeTesterMixin, PipelineTesterMixin
diff --git a/tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py b/tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py
index 3babbbe4ba11..c71c8c8817dc 100644
--- a/tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py
+++ b/tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py
@@ -20,8 +20,8 @@
)
from diffusers.models.attention import FreeNoiseTransformerBlock
from diffusers.utils import is_xformers_available, logging
-from diffusers.utils.testing_utils import require_accelerator, torch_device
+from ...testing_utils import require_accelerator, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_PARAMS, VIDEO_TO_VIDEO_BATCH_PARAMS
from ..test_pipelines_common import IPAdapterTesterMixin, PipelineFromPipeTesterMixin, PipelineTesterMixin
diff --git a/tests/pipelines/audioldm/test_audioldm.py b/tests/pipelines/audioldm/test_audioldm.py
deleted file mode 100644
index aaf44985aafd..000000000000
--- a/tests/pipelines/audioldm/test_audioldm.py
+++ /dev/null
@@ -1,461 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-from transformers import (
- ClapTextConfig,
- ClapTextModelWithProjection,
- RobertaTokenizer,
- SpeechT5HifiGan,
- SpeechT5HifiGanConfig,
-)
-
-from diffusers import (
- AudioLDMPipeline,
- AutoencoderKL,
- DDIMScheduler,
- LMSDiscreteScheduler,
- PNDMScheduler,
- UNet2DConditionModel,
-)
-from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import enable_full_determinism, nightly, torch_device
-
-from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = AudioLDMPipeline
- params = TEXT_TO_AUDIO_PARAMS
- batch_params = TEXT_TO_AUDIO_BATCH_PARAMS
- required_optional_params = frozenset(
- [
- "num_inference_steps",
- "num_waveforms_per_prompt",
- "generator",
- "latents",
- "output_type",
- "return_dict",
- "callback",
- "callback_steps",
- ]
- )
-
- supports_dduf = False
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(8, 16),
- layers_per_block=1,
- norm_num_groups=8,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=(8, 16),
- class_embed_type="simple_projection",
- projection_class_embeddings_input_dim=8,
- class_embeddings_concat=True,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[8, 16],
- in_channels=1,
- out_channels=1,
- norm_num_groups=8,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- torch.manual_seed(0)
- text_encoder_config = ClapTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=8,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=1,
- num_hidden_layers=1,
- pad_token_id=1,
- vocab_size=1000,
- projection_dim=8,
- )
- text_encoder = ClapTextModelWithProjection(text_encoder_config)
- tokenizer = RobertaTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta", model_max_length=77)
-
- vocoder_config = SpeechT5HifiGanConfig(
- model_in_dim=8,
- sampling_rate=16000,
- upsample_initial_channel=16,
- upsample_rates=[2, 2],
- upsample_kernel_sizes=[4, 4],
- resblock_kernel_sizes=[3, 7],
- resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5]],
- normalize_before=False,
- )
-
- vocoder = SpeechT5HifiGan(vocoder_config)
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "vocoder": vocoder,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A hammer hitting a wooden surface",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- }
- return inputs
-
- def test_audioldm_ddim(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components()
- audioldm_pipe = AudioLDMPipeline(**components)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = audioldm_pipe(**inputs)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 256
-
- audio_slice = audio[:10]
- expected_slice = np.array(
- [-0.0050, 0.0050, -0.0060, 0.0033, -0.0026, 0.0033, -0.0027, 0.0033, -0.0028, 0.0033]
- )
-
- assert np.abs(audio_slice - expected_slice).max() < 1e-2
-
- def test_audioldm_prompt_embeds(self):
- components = self.get_dummy_components()
- audioldm_pipe = AudioLDMPipeline(**components)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt"] = 3 * [inputs["prompt"]]
-
- # forward
- output = audioldm_pipe(**inputs)
- audio_1 = output.audios[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = 3 * [inputs.pop("prompt")]
-
- text_inputs = audioldm_pipe.tokenizer(
- prompt,
- padding="max_length",
- max_length=audioldm_pipe.tokenizer.model_max_length,
- truncation=True,
- return_tensors="pt",
- )
- text_inputs = text_inputs["input_ids"].to(torch_device)
-
- prompt_embeds = audioldm_pipe.text_encoder(
- text_inputs,
- )
- prompt_embeds = prompt_embeds.text_embeds
- # additional L_2 normalization over each hidden-state
- prompt_embeds = F.normalize(prompt_embeds, dim=-1)
-
- inputs["prompt_embeds"] = prompt_embeds
-
- # forward
- output = audioldm_pipe(**inputs)
- audio_2 = output.audios[0]
-
- assert np.abs(audio_1 - audio_2).max() < 1e-2
-
- def test_audioldm_negative_prompt_embeds(self):
- components = self.get_dummy_components()
- audioldm_pipe = AudioLDMPipeline(**components)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
- negative_prompt = 3 * ["this is a negative prompt"]
- inputs["negative_prompt"] = negative_prompt
- inputs["prompt"] = 3 * [inputs["prompt"]]
-
- # forward
- output = audioldm_pipe(**inputs)
- audio_1 = output.audios[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = 3 * [inputs.pop("prompt")]
-
- embeds = []
- for p in [prompt, negative_prompt]:
- text_inputs = audioldm_pipe.tokenizer(
- p,
- padding="max_length",
- max_length=audioldm_pipe.tokenizer.model_max_length,
- truncation=True,
- return_tensors="pt",
- )
- text_inputs = text_inputs["input_ids"].to(torch_device)
-
- text_embeds = audioldm_pipe.text_encoder(
- text_inputs,
- )
- text_embeds = text_embeds.text_embeds
- # additional L_2 normalization over each hidden-state
- text_embeds = F.normalize(text_embeds, dim=-1)
-
- embeds.append(text_embeds)
-
- inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = embeds
-
- # forward
- output = audioldm_pipe(**inputs)
- audio_2 = output.audios[0]
-
- assert np.abs(audio_1 - audio_2).max() < 1e-2
-
- def test_audioldm_negative_prompt(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
- audioldm_pipe = AudioLDMPipeline(**components)
- audioldm_pipe = audioldm_pipe.to(device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- negative_prompt = "egg cracking"
- output = audioldm_pipe(**inputs, negative_prompt=negative_prompt)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 256
-
- audio_slice = audio[:10]
- expected_slice = np.array(
- [-0.0051, 0.0050, -0.0060, 0.0034, -0.0026, 0.0033, -0.0027, 0.0033, -0.0028, 0.0032]
- )
-
- assert np.abs(audio_slice - expected_slice).max() < 1e-2
-
- def test_audioldm_num_waveforms_per_prompt(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
- audioldm_pipe = AudioLDMPipeline(**components)
- audioldm_pipe = audioldm_pipe.to(device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A hammer hitting a wooden surface"
-
- # test num_waveforms_per_prompt=1 (default)
- audios = audioldm_pipe(prompt, num_inference_steps=2).audios
-
- assert audios.shape == (1, 256)
-
- # test num_waveforms_per_prompt=1 (default) for batch of prompts
- batch_size = 2
- audios = audioldm_pipe([prompt] * batch_size, num_inference_steps=2).audios
-
- assert audios.shape == (batch_size, 256)
-
- # test num_waveforms_per_prompt for single prompt
- num_waveforms_per_prompt = 2
- audios = audioldm_pipe(prompt, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt).audios
-
- assert audios.shape == (num_waveforms_per_prompt, 256)
-
- # test num_waveforms_per_prompt for batch of prompts
- batch_size = 2
- audios = audioldm_pipe(
- [prompt] * batch_size, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt
- ).audios
-
- assert audios.shape == (batch_size * num_waveforms_per_prompt, 256)
-
- def test_audioldm_audio_length_in_s(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- audioldm_pipe = AudioLDMPipeline(**components)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe.set_progress_bar_config(disable=None)
- vocoder_sampling_rate = audioldm_pipe.vocoder.config.sampling_rate
-
- inputs = self.get_dummy_inputs(device)
- output = audioldm_pipe(audio_length_in_s=0.016, **inputs)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) / vocoder_sampling_rate == 0.016
-
- output = audioldm_pipe(audio_length_in_s=0.032, **inputs)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) / vocoder_sampling_rate == 0.032
-
- def test_audioldm_vocoder_model_in_dim(self):
- components = self.get_dummy_components()
- audioldm_pipe = AudioLDMPipeline(**components)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- prompt = ["hey"]
-
- output = audioldm_pipe(prompt, num_inference_steps=1)
- audio_shape = output.audios.shape
- assert audio_shape == (1, 256)
-
- config = audioldm_pipe.vocoder.config
- config.model_in_dim *= 2
- audioldm_pipe.vocoder = SpeechT5HifiGan(config).to(torch_device)
- output = audioldm_pipe(prompt, num_inference_steps=1)
- audio_shape = output.audios.shape
- # waveform shape is unchanged, we just have 2x the number of mel channels in the spectrogram
- assert audio_shape == (1, 256)
-
- def test_attention_slicing_forward_pass(self):
- self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)
-
- def test_inference_batch_single_identical(self):
- self._test_inference_batch_single_identical()
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)
-
-
-@nightly
-class AudioLDMPipelineSlowTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
- generator = torch.Generator(device=generator_device).manual_seed(seed)
- latents = np.random.RandomState(seed).standard_normal((1, 8, 128, 16))
- latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
- inputs = {
- "prompt": "A hammer hitting a wooden surface",
- "latents": latents,
- "generator": generator,
- "num_inference_steps": 3,
- "guidance_scale": 2.5,
- }
- return inputs
-
- def test_audioldm(self):
- audioldm_pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm")
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- inputs["num_inference_steps"] = 25
- audio = audioldm_pipe(**inputs).audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 81920
-
- audio_slice = audio[77230:77240]
- expected_slice = np.array(
- [-0.4884, -0.4607, 0.0023, 0.5007, 0.5896, 0.5151, 0.3813, -0.0208, -0.3687, -0.4315]
- )
- max_diff = np.abs(expected_slice - audio_slice).max()
- assert max_diff < 1e-2
-
-
-@nightly
-class AudioLDMPipelineNightlyTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
- generator = torch.Generator(device=generator_device).manual_seed(seed)
- latents = np.random.RandomState(seed).standard_normal((1, 8, 128, 16))
- latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
- inputs = {
- "prompt": "A hammer hitting a wooden surface",
- "latents": latents,
- "generator": generator,
- "num_inference_steps": 3,
- "guidance_scale": 2.5,
- }
- return inputs
-
- def test_audioldm_lms(self):
- audioldm_pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm")
- audioldm_pipe.scheduler = LMSDiscreteScheduler.from_config(audioldm_pipe.scheduler.config)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- audio = audioldm_pipe(**inputs).audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 81920
-
- audio_slice = audio[27780:27790]
- expected_slice = np.array([-0.2131, -0.0873, -0.0124, -0.0189, 0.0569, 0.1373, 0.1883, 0.2886, 0.3297, 0.2212])
- max_diff = np.abs(expected_slice - audio_slice).max()
- assert max_diff < 3e-2
diff --git a/tests/pipelines/audioldm2/test_audioldm2.py b/tests/pipelines/audioldm2/test_audioldm2.py
index 66052392f07f..5ccba1dabbfe 100644
--- a/tests/pipelines/audioldm2/test_audioldm2.py
+++ b/tests/pipelines/audioldm2/test_audioldm2.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,15 +18,14 @@
import unittest
import numpy as np
+import pytest
import torch
from transformers import (
- ClapAudioConfig,
ClapConfig,
ClapFeatureExtractor,
ClapModel,
- ClapTextConfig,
GPT2Config,
- GPT2Model,
+ GPT2LMHeadModel,
RobertaTokenizer,
SpeechT5HifiGan,
SpeechT5HifiGanConfig,
@@ -44,8 +43,15 @@
LMSDiscreteScheduler,
PNDMScheduler,
)
-from diffusers.utils.testing_utils import enable_full_determinism, nightly, torch_device
-
+from diffusers.utils import is_transformers_version
+
+from ...testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ is_torch_version,
+ nightly,
+ torch_device,
+)
from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@@ -103,37 +109,35 @@ def get_dummy_components(self):
latent_channels=4,
)
torch.manual_seed(0)
- text_branch_config = ClapTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=8,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=1,
- num_hidden_layers=1,
- pad_token_id=1,
- vocab_size=1000,
- projection_dim=8,
- )
- audio_branch_config = ClapAudioConfig(
- spec_size=8,
- window_size=4,
- num_mel_bins=8,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- depths=[1, 1],
- num_attention_heads=[1, 1],
- num_hidden_layers=1,
- hidden_size=192,
- projection_dim=8,
- patch_size=2,
- patch_stride=2,
- patch_embed_input_channels=4,
- )
- text_encoder_config = ClapConfig.from_text_audio_configs(
- text_config=text_branch_config,
- audio_config=audio_branch_config,
- projection_dim=16,
+ text_branch_config = {
+ "bos_token_id": 0,
+ "eos_token_id": 2,
+ "hidden_size": 8,
+ "intermediate_size": 37,
+ "layer_norm_eps": 1e-05,
+ "num_attention_heads": 1,
+ "num_hidden_layers": 1,
+ "pad_token_id": 1,
+ "vocab_size": 1000,
+ "projection_dim": 8,
+ }
+ audio_branch_config = {
+ "spec_size": 8,
+ "window_size": 4,
+ "num_mel_bins": 8,
+ "intermediate_size": 37,
+ "layer_norm_eps": 1e-05,
+ "depths": [1, 1],
+ "num_attention_heads": [1, 1],
+ "num_hidden_layers": 1,
+ "hidden_size": 192,
+ "projection_dim": 8,
+ "patch_size": 2,
+ "patch_stride": 2,
+ "patch_embed_input_channels": 4,
+ }
+ text_encoder_config = ClapConfig(
+ text_config=text_branch_config, audio_config=audio_branch_config, projection_dim=16
)
text_encoder = ClapModel(text_encoder_config)
tokenizer = RobertaTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta", model_max_length=77)
@@ -162,7 +166,7 @@ def get_dummy_components(self):
n_ctx=99,
n_positions=99,
)
- language_model = GPT2Model(language_model_config)
+ language_model = GPT2LMHeadModel(language_model_config)
language_model.config.max_new_tokens = 8
torch.manual_seed(0)
@@ -213,6 +217,11 @@ def get_dummy_inputs(self, device, seed=0):
}
return inputs
+ @pytest.mark.xfail(
+ condition=is_transformers_version(">=", "4.54.1"),
+ reason="Test currently fails on Transformers version 4.54.1.",
+ strict=False,
+ )
def test_audioldm2_ddim(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -305,7 +314,6 @@ def test_audioldm2_negative_prompt_embeds(self):
components = self.get_dummy_components()
audioldm_pipe = AudioLDM2Pipeline(**components)
audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe = audioldm_pipe.to(torch_device)
audioldm_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
@@ -364,6 +372,11 @@ def test_audioldm2_negative_prompt_embeds(self):
assert np.abs(audio_1 - audio_2).max() < 1e-2
+ @pytest.mark.xfail(
+ condition=is_transformers_version(">=", "4.54.1"),
+ reason="Test currently fails on Transformers version 4.54.1.",
+ strict=False,
+ )
def test_audioldm2_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -474,6 +487,11 @@ def test_dict_tuple_outputs_equivalent(self):
# increase tolerance from 1e-4 -> 3e-4 to account for large composite model
super().test_dict_tuple_outputs_equivalent(expected_max_difference=3e-4)
+ @pytest.mark.xfail(
+ condition=is_torch_version(">=", "2.7"),
+ reason="Test currently fails on PyTorch 2.7.",
+ strict=False,
+ )
def test_inference_batch_single_identical(self):
# increase tolerance from 1e-4 -> 2e-4 to account for large composite model
self._test_inference_batch_single_identical(expected_max_diff=2e-4)
@@ -516,18 +534,30 @@ def test_sequential_cpu_offload_forward_pass(self):
def test_encode_prompt_works_in_isolation(self):
pass
+ @unittest.skip("Not supported yet due to CLAPModel.")
+ def test_sequential_offload_forward_pass_twice(self):
+ pass
+
+ @unittest.skip("Not supported yet, the second forward has mixed devices and `vocoder` is not offloaded.")
+ def test_cpu_offload_forward_pass_twice(self):
+ pass
+
+ @unittest.skip("Not supported yet. `vocoder` is not offloaded.")
+ def test_model_cpu_offload_forward_pass(self):
+ pass
+
@nightly
class AudioLDM2PipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
diff --git a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py
index c56aeb905ac3..1eb9d1035c33 100644
--- a/tests/pipelines/aura_flow/test_pipeline_aura_flow.py
+++ b/tests/pipelines/aura_flow/test_pipeline_aura_flow.py
@@ -106,9 +106,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -122,15 +122,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
@unittest.skip("xformers attention processor does not exist for AuraFlow")
def test_xformers_attention_forwardGenerator_pass(self):
diff --git a/tests/pipelines/blipdiffusion/test_blipdiffusion.py b/tests/pipelines/blipdiffusion/test_blipdiffusion.py
deleted file mode 100644
index e073f55aec9e..000000000000
--- a/tests/pipelines/blipdiffusion/test_blipdiffusion.py
+++ /dev/null
@@ -1,204 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import unittest
-
-import numpy as np
-import torch
-from PIL import Image
-from transformers import CLIPTokenizer
-from transformers.models.blip_2.configuration_blip_2 import Blip2Config
-from transformers.models.clip.configuration_clip import CLIPTextConfig
-
-from diffusers import AutoencoderKL, BlipDiffusionPipeline, PNDMScheduler, UNet2DConditionModel
-from diffusers.utils.testing_utils import enable_full_determinism
-from src.diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
-from src.diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
-from src.diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
-
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class BlipDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = BlipDiffusionPipeline
- params = [
- "prompt",
- "reference_image",
- "source_subject_category",
- "target_subject_category",
- ]
- batch_params = [
- "prompt",
- "reference_image",
- "source_subject_category",
- "target_subject_category",
- ]
- required_optional_params = [
- "generator",
- "height",
- "width",
- "latents",
- "guidance_scale",
- "num_inference_steps",
- "neg_prompt",
- "guidance_scale",
- "prompt_strength",
- "prompt_reps",
- ]
-
- supports_dduf = False
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- vocab_size=1000,
- hidden_size=8,
- intermediate_size=8,
- projection_dim=8,
- num_hidden_layers=1,
- num_attention_heads=1,
- max_position_embeddings=77,
- )
- text_encoder = ContextCLIPTextModel(text_encoder_config)
-
- vae = AutoencoderKL(
- in_channels=4,
- out_channels=4,
- down_block_types=("DownEncoderBlock2D",),
- up_block_types=("UpDecoderBlock2D",),
- block_out_channels=(8,),
- norm_num_groups=8,
- layers_per_block=1,
- act_fn="silu",
- latent_channels=4,
- sample_size=8,
- )
-
- blip_vision_config = {
- "hidden_size": 8,
- "intermediate_size": 8,
- "num_hidden_layers": 1,
- "num_attention_heads": 1,
- "image_size": 224,
- "patch_size": 14,
- "hidden_act": "quick_gelu",
- }
-
- blip_qformer_config = {
- "vocab_size": 1000,
- "hidden_size": 8,
- "num_hidden_layers": 1,
- "num_attention_heads": 1,
- "intermediate_size": 8,
- "max_position_embeddings": 512,
- "cross_attention_frequency": 1,
- "encoder_hidden_size": 8,
- }
- qformer_config = Blip2Config(
- vision_config=blip_vision_config,
- qformer_config=blip_qformer_config,
- num_query_tokens=8,
- tokenizer="hf-internal-testing/tiny-random-bert",
- )
- qformer = Blip2QFormerModel(qformer_config)
-
- unet = UNet2DConditionModel(
- block_out_channels=(8, 16),
- norm_num_groups=8,
- layers_per_block=1,
- sample_size=16,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=8,
- )
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- scheduler = PNDMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- set_alpha_to_one=False,
- skip_prk_steps=True,
- )
-
- vae.eval()
- qformer.eval()
- text_encoder.eval()
-
- image_processor = BlipImageProcessor()
-
- components = {
- "text_encoder": text_encoder,
- "vae": vae,
- "qformer": qformer,
- "unet": unet,
- "tokenizer": tokenizer,
- "scheduler": scheduler,
- "image_processor": image_processor,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- np.random.seed(seed)
- reference_image = np.random.rand(32, 32, 3) * 255
- reference_image = Image.fromarray(reference_image.astype("uint8")).convert("RGBA")
-
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "swimming underwater",
- "generator": generator,
- "reference_image": reference_image,
- "source_subject_category": "dog",
- "target_subject_category": "dog",
- "height": 32,
- "width": 32,
- "guidance_scale": 7.5,
- "num_inference_steps": 2,
- "output_type": "np",
- }
- return inputs
-
- def test_blipdiffusion(self):
- device = "cpu"
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- image = pipe(**self.get_dummy_inputs(device))[0]
- image_slice = image[0, -3:, -3:, 0]
-
- assert image.shape == (1, 16, 16, 4)
-
- expected_slice = np.array(
- [0.5329548, 0.8372512, 0.33269387, 0.82096875, 0.43657133, 0.3783, 0.5953028, 0.51934963, 0.42142007]
- )
-
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}"
-
- @unittest.skip("Test not supported because of complexities in deriving query_embeds.")
- def test_encode_prompt_works_in_isolation(self):
- pass
diff --git a/tests/pipelines/pia/__init__.py b/tests/pipelines/bria/__init__.py
similarity index 100%
rename from tests/pipelines/pia/__init__.py
rename to tests/pipelines/bria/__init__.py
diff --git a/tests/pipelines/bria/test_pipeline_bria.py b/tests/pipelines/bria/test_pipeline_bria.py
new file mode 100644
index 000000000000..844488e76f2e
--- /dev/null
+++ b/tests/pipelines/bria/test_pipeline_bria.py
@@ -0,0 +1,319 @@
+# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from huggingface_hub import hf_hub_download
+from transformers import T5EncoderModel, T5TokenizerFast
+
+from diffusers import (
+ AutoencoderKL,
+ BriaTransformer2DModel,
+ FlowMatchEulerDiscreteScheduler,
+)
+from diffusers.pipelines.bria import BriaPipeline
+
+# from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
+from tests.pipelines.test_pipelines_common import PipelineTesterMixin, to_np
+
+from ...testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ numpy_cosine_similarity_distance,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
+
+
+enable_full_determinism()
+
+
+class BriaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = BriaPipeline
+ params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
+ batch_params = frozenset(["prompt"])
+ test_xformers_attention = False
+
+ # there is no xformers processor for Flux
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = BriaTransformer2DModel(
+ patch_size=1,
+ in_channels=16,
+ num_layers=1,
+ num_single_layers=1,
+ attention_head_dim=8,
+ num_attention_heads=2,
+ joint_attention_dim=32,
+ pooled_projection_dim=None,
+ axes_dims_rope=[0, 4, 4],
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ act_fn="silu",
+ block_out_channels=(32,),
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D"],
+ latent_channels=4,
+ sample_size=32,
+ shift_factor=0,
+ scaling_factor=0.13025,
+ use_post_quant_conv=True,
+ use_quant_conv=True,
+ force_upcast=False,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ torch.manual_seed(0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = T5TokenizerFast.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ components = {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "transformer": transformer,
+ "vae": vae,
+ "image_encoder": None,
+ "feature_extractor": None,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "negative_prompt": "bad, ugly",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 16,
+ "width": 16,
+ "max_sequence_length": 48,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
+ def test_bria_different_prompts(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+ output_same_prompt = pipe(**inputs).images[0]
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt"] = "a different prompt"
+ output_different_prompts = pipe(**inputs).images[0]
+ max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+ assert max_diff > 1e-6
+
+ def test_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (72, 57)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+ inputs.update({"height": height, "width": width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
+
+ @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
+ @require_torch_accelerator
+ def test_save_load_float16(self, expected_max_diff=1e-2):
+ components = self.get_dummy_components()
+ for name, module in components.items():
+ if hasattr(module, "half"):
+ components[name] = module.to(torch_device).half()
+
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ for name, component in pipe_loaded.components.items():
+ if name == "vae":
+ continue
+ if hasattr(component, "dtype"):
+ self.assertTrue(
+ component.dtype == torch.float16,
+ f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_loaded = pipe_loaded(**inputs)[0]
+ max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+ self.assertLess(
+ max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
+ )
+
+ def test_bria_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(16, 16), (32, 32), (64, 64)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+ inputs.update({"height": height, "width": width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
+
+ def test_to_dtype(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
+ self.assertTrue([dtype == torch.float32 for dtype in model_dtypes] == [True, True, True])
+
+ def test_torch_dtype_dict(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ pipe.save_pretrained(tmpdirname)
+ torch_dtype_dict = {"transformer": torch.bfloat16, "default": torch.float16}
+ loaded_pipe = self.pipeline_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype_dict)
+
+ self.assertEqual(loaded_pipe.transformer.dtype, torch.bfloat16)
+ self.assertEqual(loaded_pipe.text_encoder.dtype, torch.float16)
+ self.assertEqual(loaded_pipe.vae.dtype, torch.float16)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ pipe.save_pretrained(tmpdirname)
+ torch_dtype_dict = {"default": torch.float16}
+ loaded_pipe = self.pipeline_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype_dict)
+
+ self.assertEqual(loaded_pipe.transformer.dtype, torch.float16)
+ self.assertEqual(loaded_pipe.text_encoder.dtype, torch.float16)
+ self.assertEqual(loaded_pipe.vae.dtype, torch.float16)
+
+
+@slow
+@require_torch_accelerator
+class BriaPipelineSlowTests(unittest.TestCase):
+ pipeline_class = BriaPipeline
+ repo_id = "briaai/BRIA-3.2"
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def get_inputs(self, device, seed=0):
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ prompt_embeds = torch.load(
+ hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt")
+ ).to(torch_device)
+
+ return {
+ "prompt_embeds": prompt_embeds,
+ "num_inference_steps": 2,
+ "guidance_scale": 0.0,
+ "max_sequence_length": 256,
+ "output_type": "np",
+ "generator": generator,
+ }
+
+ def test_bria_inference_bf16(self):
+ pipe = self.pipeline_class.from_pretrained(
+ self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, tokenizer=None
+ )
+ pipe.to(torch_device)
+
+ inputs = self.get_inputs(torch_device)
+
+ image = pipe(**inputs).images[0]
+ image_slice = image[0, :10, :10].flatten()
+
+ expected_slice = np.array(
+ [
+ 0.59729785,
+ 0.6153719,
+ 0.595112,
+ 0.5884763,
+ 0.59366125,
+ 0.5795311,
+ 0.58325,
+ 0.58449626,
+ 0.57737637,
+ 0.58432233,
+ 0.5867875,
+ 0.57824117,
+ 0.5819089,
+ 0.5830988,
+ 0.57730293,
+ 0.57647324,
+ 0.5769151,
+ 0.57312685,
+ 0.57926565,
+ 0.5823928,
+ 0.57783926,
+ 0.57162863,
+ 0.575649,
+ 0.5745547,
+ 0.5740556,
+ 0.5799735,
+ 0.57799566,
+ 0.5715559,
+ 0.5771242,
+ 0.5773058,
+ ],
+ dtype=np.float32,
+ )
+ max_diff = numpy_cosine_similarity_distance(expected_slice, image_slice)
+ self.assertLess(max_diff, 1e-4, f"Image slice is different from expected slice: {max_diff:.4f}")
diff --git a/tests/pipelines/semantic_stable_diffusion/__init__.py b/tests/pipelines/bria_fibo/__init__.py
similarity index 100%
rename from tests/pipelines/semantic_stable_diffusion/__init__.py
rename to tests/pipelines/bria_fibo/__init__.py
diff --git a/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py b/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py
new file mode 100644
index 000000000000..76b41114f859
--- /dev/null
+++ b/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py
@@ -0,0 +1,139 @@
+# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer
+from transformers.models.smollm3.modeling_smollm3 import SmolLM3Config, SmolLM3ForCausalLM
+
+from diffusers import (
+ AutoencoderKLWan,
+ BriaFiboPipeline,
+ FlowMatchEulerDiscreteScheduler,
+)
+from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
+from tests.pipelines.test_pipelines_common import PipelineTesterMixin
+
+from ...testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+
+
+enable_full_determinism()
+
+
+class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = BriaFiboPipeline
+ params = frozenset(["prompt", "height", "width", "guidance_scale"])
+ batch_params = frozenset(["prompt"])
+ test_xformers_attention = False
+ test_layerwise_casting = False
+ test_group_offloading = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = BriaFiboTransformer2DModel(
+ patch_size=1,
+ in_channels=16,
+ num_layers=1,
+ num_single_layers=1,
+ attention_head_dim=8,
+ num_attention_heads=2,
+ joint_attention_dim=64,
+ text_encoder_dim=32,
+ pooled_projection_dim=None,
+ axes_dims_rope=[0, 4, 4],
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=160,
+ decoder_base_dim=256,
+ num_res_blocks=2,
+ out_channels=12,
+ patch_size=2,
+ scale_factor_spatial=16,
+ scale_factor_temporal=4,
+ temperal_downsample=[False, True, True],
+ z_dim=16,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ torch.manual_seed(0)
+ text_encoder = SmolLM3ForCausalLM(SmolLM3Config(hidden_size=32))
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ components = {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "transformer": transformer,
+ "vae": vae,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "prompt": "{'text': 'A painting of a squirrel eating a burger'}",
+ "negative_prompt": "bad, ugly",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 32,
+ "width": 32,
+ "output_type": "np",
+ }
+ return inputs
+
+ @unittest.skip(reason="will not be supported due to dim-fusion")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
+ def test_bria_fibo_different_prompts(self):
+ pipe = self.pipeline_class(**self.get_dummy_components())
+ pipe = pipe.to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+ output_same_prompt = pipe(**inputs).images[0]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt"] = "a different prompt"
+ output_different_prompts = pipe(**inputs).images[0]
+
+ max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+ assert max_diff > 1e-6
+
+ def test_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components())
+ pipe = pipe.to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (64, 64), (32, 64)]
+ for height, width in height_width_pairs:
+ expected_height = height
+ expected_width = width
+
+ inputs.update({"height": height, "width": width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
diff --git a/tests/pipelines/chroma/__init__.py b/tests/pipelines/chroma/__init__.py
new file mode 100644
index 000000000000..8b137891791f
--- /dev/null
+++ b/tests/pipelines/chroma/__init__.py
@@ -0,0 +1 @@
+
diff --git a/tests/pipelines/chroma/test_pipeline_chroma.py b/tests/pipelines/chroma/test_pipeline_chroma.py
new file mode 100644
index 000000000000..3edd58b75f82
--- /dev/null
+++ b/tests/pipelines/chroma/test_pipeline_chroma.py
@@ -0,0 +1,160 @@
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKL, ChromaPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
+
+from ...testing_utils import torch_device
+from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist
+
+
+class ChromaPipelineFastTests(
+ unittest.TestCase,
+ PipelineTesterMixin,
+ FluxIPAdapterTesterMixin,
+):
+ pipeline_class = ChromaPipeline
+ params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
+ batch_params = frozenset(["prompt"])
+
+ # there is no xformers processor for Flux
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
+ torch.manual_seed(0)
+ transformer = ChromaTransformer2DModel(
+ patch_size=1,
+ in_channels=4,
+ num_layers=num_layers,
+ num_single_layers=num_single_layers,
+ attention_head_dim=16,
+ num_attention_heads=2,
+ joint_attention_dim=32,
+ axes_dims_rope=[4, 4, 8],
+ approximator_hidden_dim=32,
+ approximator_layers=1,
+ approximator_num_channels=16,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=1,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ shift_factor=0.0609,
+ scaling_factor=1.5035,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "transformer": transformer,
+ "vae": vae,
+ "image_encoder": None,
+ "feature_extractor": None,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "negative_prompt": "bad, ugly",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 8,
+ "width": 8,
+ "max_sequence_length": 48,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_chroma_different_prompts(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_same_prompt = pipe(**inputs).images[0]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt"] = "a different prompt"
+ output_different_prompts = pipe(**inputs).images[0]
+
+ max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+
+ # Outputs should be different here
+ # For some reasons, they don't show large differences
+ assert max_diff > 1e-6
+
+ def test_fused_qkv_projections(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ original_image_slice = image[0, -3:, -3:, -1]
+
+ # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
+ # to the pipeline level.
+ pipe.transformer.fuse_qkv_projections()
+ self.assertTrue(
+ check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+ ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
+ )
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_fused = image[0, -3:, -3:, -1]
+
+ pipe.transformer.unfuse_qkv_projections()
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_disabled = image[0, -3:, -3:, -1]
+
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
+
+ def test_chroma_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (72, 57)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+ inputs.update({"height": height, "width": width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
diff --git a/tests/pipelines/chroma/test_pipeline_chroma_img2img.py b/tests/pipelines/chroma/test_pipeline_chroma_img2img.py
new file mode 100644
index 000000000000..4ed1393037b9
--- /dev/null
+++ b/tests/pipelines/chroma/test_pipeline_chroma_img2img.py
@@ -0,0 +1,163 @@
+import random
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKL, ChromaImg2ImgPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
+
+from ...testing_utils import floats_tensor, torch_device
+from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist
+
+
+class ChromaImg2ImgPipelineFastTests(
+ unittest.TestCase,
+ PipelineTesterMixin,
+ FluxIPAdapterTesterMixin,
+):
+ pipeline_class = ChromaImg2ImgPipeline
+ params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
+ batch_params = frozenset(["prompt"])
+
+ # there is no xformers processor for Flux
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
+ torch.manual_seed(0)
+ transformer = ChromaTransformer2DModel(
+ patch_size=1,
+ in_channels=4,
+ num_layers=num_layers,
+ num_single_layers=num_single_layers,
+ attention_head_dim=16,
+ num_attention_heads=2,
+ joint_attention_dim=32,
+ axes_dims_rope=[4, 4, 8],
+ approximator_hidden_dim=32,
+ approximator_layers=1,
+ approximator_num_channels=16,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=1,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ shift_factor=0.0609,
+ scaling_factor=1.5035,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "transformer": transformer,
+ "vae": vae,
+ "image_encoder": None,
+ "feature_extractor": None,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "image": image,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 8,
+ "width": 8,
+ "max_sequence_length": 48,
+ "strength": 0.8,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_chroma_different_prompts(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_same_prompt = pipe(**inputs).images[0]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt"] = "a different prompt"
+ output_different_prompts = pipe(**inputs).images[0]
+
+ max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+
+ # Outputs should be different here
+ # For some reasons, they don't show large differences
+ assert max_diff > 1e-6
+
+ def test_fused_qkv_projections(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ original_image_slice = image[0, -3:, -3:, -1]
+
+ # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
+ # to the pipeline level.
+ pipe.transformer.fuse_qkv_projections()
+ self.assertTrue(
+ check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+ ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
+ )
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_fused = image[0, -3:, -3:, -1]
+
+ pipe.transformer.unfuse_qkv_projections()
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_disabled = image[0, -3:, -3:, -1]
+
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
+
+ def test_chroma_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (72, 57)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+ inputs.update({"height": height, "width": width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
diff --git a/tests/pipelines/stable_diffusion_gligen/__init__.py b/tests/pipelines/chronoedit/__init__.py
similarity index 100%
rename from tests/pipelines/stable_diffusion_gligen/__init__.py
rename to tests/pipelines/chronoedit/__init__.py
diff --git a/tests/pipelines/chronoedit/test_chronoedit.py b/tests/pipelines/chronoedit/test_chronoedit.py
new file mode 100644
index 000000000000..43e5b3159b1c
--- /dev/null
+++ b/tests/pipelines/chronoedit/test_chronoedit.py
@@ -0,0 +1,176 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from PIL import Image
+from transformers import (
+ AutoTokenizer,
+ CLIPImageProcessor,
+ CLIPVisionConfig,
+ CLIPVisionModelWithProjection,
+ T5EncoderModel,
+)
+
+from diffusers import (
+ AutoencoderKLWan,
+ ChronoEditPipeline,
+ ChronoEditTransformer3DModel,
+ FlowMatchEulerDiscreteScheduler,
+)
+
+from ...testing_utils import enable_full_determinism
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class ChronoEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = ChronoEditPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ # TODO: impl FlowDPMSolverMultistepScheduler
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = ChronoEditTransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=36,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ image_dim=4,
+ )
+
+ torch.manual_seed(0)
+ image_encoder_config = CLIPVisionConfig(
+ hidden_size=4,
+ projection_dim=4,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ image_size=32,
+ intermediate_size=16,
+ patch_size=1,
+ )
+ image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
+
+ torch.manual_seed(0)
+ image_processor = CLIPImageProcessor(crop_size=32, size=32)
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "image_encoder": image_encoder,
+ "image_processor": image_processor,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ image_height = 16
+ image_width = 16
+ image = Image.new("RGB", (image_width, image_height))
+ inputs = {
+ "image": image,
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "height": image_height,
+ "width": image_width,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "num_frames": 5,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+ self.assertEqual(generated_video.shape, (5, 3, 16, 16))
+
+ # fmt: off
+ expected_slice = torch.tensor([0.4525, 0.4520, 0.4485, 0.4534, 0.4523, 0.4522, 0.4529, 0.4528, 0.5022, 0.5064, 0.5011, 0.5061, 0.5028, 0.4979, 0.5117, 0.5192])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
+ def test_inference_batch_single_identical(self):
+ pass
+
+ @unittest.skip(
+ "ChronoEditPipeline has to run in mixed precision. Save/Load the entire pipeline in FP16 will result in errors"
+ )
+ def test_save_load_float16(self):
+ pass
diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py
index 388dc9ef7ec4..dca1725d8a74 100644
--- a/tests/pipelines/cogvideo/test_cogvideox.py
+++ b/tests/pipelines/cogvideo/test_cogvideox.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
+# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,17 +21,19 @@
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_accelerator,
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
@@ -44,7 +46,11 @@
class CogVideoXPipelineFastTests(
- PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
+ PipelineTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+ FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
+ unittest.TestCase,
):
pipeline_class = CogVideoXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
@@ -299,9 +305,9 @@ def test_fused_qkv_projections(self):
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -315,15 +321,15 @@ def test_fused_qkv_projections(self):
frames = pipe(**inputs).frames
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
@slow
@@ -334,12 +340,12 @@ class CogVideoXPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_cogvideox(self):
generator = torch.Generator("cpu").manual_seed(0)
diff --git a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py
index 2e962bd247b9..097e8df7b35f 100644
--- a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py
+++ b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
+# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,11 +21,11 @@
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLCogVideoX, CogVideoXFunControlPipeline, CogVideoXTransformer3DModel, DDIMScheduler
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
@@ -299,9 +299,9 @@ def test_fused_qkv_projections(self):
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -315,12 +315,12 @@ def test_fused_qkv_projections(self):
frames = pipe(**inputs).frames
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
diff --git a/tests/pipelines/cogvideo/test_cogvideox_image2video.py b/tests/pipelines/cogvideo/test_cogvideox_image2video.py
index cac47f1a83d4..1dd5e2ae1405 100644
--- a/tests/pipelines/cogvideo/test_cogvideox_image2video.py
+++ b/tests/pipelines/cogvideo/test_cogvideox_image2video.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
+# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -23,7 +23,8 @@
from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel, DDIMScheduler
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -31,7 +32,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
@@ -270,7 +270,7 @@ def test_vae_tiling(self, expected_diff_max: float = 0.3):
generator_device = "cpu"
components = self.get_dummy_components()
- # The reason to modify it this way is because I2V Transformer limits the generation to resolutions used during initalization.
+ # The reason to modify it this way is because I2V Transformer limits the generation to resolutions used during initialization.
# This limitation comes from using learned positional embeddings which cannot be generated on-the-fly like sincos or RoPE embeddings.
# See the if-statement on "self.use_learned_positional_embeddings" in diffusers/models/embeddings.py
components["transformer"] = CogVideoXTransformer3DModel.from_config(
@@ -317,9 +317,9 @@ def test_fused_qkv_projections(self):
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -333,15 +333,15 @@ def test_fused_qkv_projections(self):
frames = pipe(**inputs).frames
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
@slow
diff --git a/tests/pipelines/cogvideo/test_cogvideox_video2video.py b/tests/pipelines/cogvideo/test_cogvideox_video2video.py
index 4d836cb5e2a4..3a1da7c4e7f7 100644
--- a/tests/pipelines/cogvideo/test_cogvideox_video2video.py
+++ b/tests/pipelines/cogvideo/test_cogvideox_video2video.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
+# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,8 +21,8 @@
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXVideoToVideoPipeline, DDIMScheduler
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
@@ -298,9 +298,9 @@ def test_fused_qkv_projections(self):
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -314,12 +314,12 @@ def test_fused_qkv_projections(self):
frames = pipe(**inputs).frames
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
diff --git a/tests/pipelines/cogview3/test_cogview3plus.py b/tests/pipelines/cogview3/test_cogview3plus.py
index 79dffd230a75..819d4b952fc7 100644
--- a/tests/pipelines/cogview3/test_cogview3plus.py
+++ b/tests/pipelines/cogview3/test_cogview3plus.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
+# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,14 +21,15 @@
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_accelerator,
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
@@ -244,12 +245,12 @@ class CogView3PlusPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_cogview3plus(self):
generator = torch.Generator("cpu").manual_seed(0)
diff --git a/tests/pipelines/cogview4/test_cogview4.py b/tests/pipelines/cogview4/test_cogview4.py
index 2a97a0799d76..a1f0fc7a715b 100644
--- a/tests/pipelines/cogview4/test_cogview4.py
+++ b/tests/pipelines/cogview4/test_cogview4.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
+# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,8 +20,8 @@
from transformers import AutoTokenizer, GlmConfig, GlmForCausalLM
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
diff --git a/tests/pipelines/consisid/test_consisid.py b/tests/pipelines/consisid/test_consisid.py
index a39c17bb4f79..4fd9e536cddc 100644
--- a/tests/pipelines/consisid/test_consisid.py
+++ b/tests/pipelines/consisid/test_consisid.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
+# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -23,14 +23,15 @@
from diffusers import AutoencoderKLCogVideoX, ConsisIDPipeline, ConsisIDTransformer3DModel, DDIMScheduler
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
@@ -279,7 +280,7 @@ def test_vae_tiling(self, expected_diff_max: float = 0.4):
generator_device = "cpu"
components = self.get_dummy_components()
- # The reason to modify it this way is because ConsisID Transformer limits the generation to resolutions used during initalization.
+ # The reason to modify it this way is because ConsisID Transformer limits the generation to resolutions used during initialization.
# This limitation comes from using learned positional embeddings which cannot be generated on-the-fly like sincos or RoPE embeddings.
# See the if-statement on "self.use_learned_positional_embeddings" in diffusers/models/embeddings.py
components["transformer"] = ConsisIDTransformer3DModel.from_config(
@@ -316,19 +317,19 @@ def test_vae_tiling(self, expected_diff_max: float = 0.4):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class ConsisIDPipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_consisid(self):
generator = torch.Generator("cpu").manual_seed(0)
@@ -338,8 +339,8 @@ def test_consisid(self):
prompt = self.prompt
image = load_image("https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/example_images/2.png?raw=true")
- id_vit_hidden = [torch.ones([1, 2, 2])] * 1
- id_cond = torch.ones(1, 2)
+ id_vit_hidden = [torch.ones([1, 577, 1024])] * 5
+ id_cond = torch.ones(1, 1280)
videos = pipe(
image=image,
@@ -357,5 +358,5 @@ def test_consisid(self):
video = videos[0]
expected_video = torch.randn(1, 16, 480, 720, 3).numpy()
- max_diff = numpy_cosine_similarity_distance(video, expected_video)
+ max_diff = numpy_cosine_similarity_distance(video.cpu(), expected_video)
assert max_diff < 1e-3, f"Max diff is too high. got {video}"
diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py
index e255cb510c42..0ab0c0af2588 100644
--- a/tests/pipelines/consistency_models/test_consistency_models.py
+++ b/tests/pipelines/consistency_models/test_consistency_models.py
@@ -10,15 +10,17 @@
ConsistencyModelPipeline,
UNet2DModel,
)
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
+ Expectations,
+ backend_empty_cache,
enable_full_determinism,
nightly,
require_torch_2,
- require_torch_gpu,
+ require_torch_accelerator,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..pipeline_params import UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS, UNCONDITIONAL_IMAGE_GENERATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@@ -168,17 +170,17 @@ def test_consistency_model_pipeline_onestep_class_cond(self):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
class ConsistencyModelPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, seed=0, get_fixed_latents=False, device="cpu", dtype=torch.float32, shape=(1, 3, 64, 64)):
generator = torch.manual_seed(seed)
@@ -264,11 +266,19 @@ def test_consistency_model_cd_multistep_flash_attn(self):
# Ensure usage of flash attention in torch 2.0
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
image = pipe(**inputs).images
+
assert image.shape == (1, 64, 64, 3)
image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.1845, 0.1371, 0.1211, 0.2035, 0.1954, 0.1323, 0.1773, 0.1593, 0.1314])
+ expected_slices = Expectations(
+ {
+ ("xpu", 3): np.array([0.0816, 0.0518, 0.0445, 0.0594, 0.0739, 0.0534, 0.0805, 0.0457, 0.0765]),
+ ("cuda", 7): np.array([0.1845, 0.1371, 0.1211, 0.2035, 0.1954, 0.1323, 0.1773, 0.1593, 0.1314]),
+ ("cuda", 8): np.array([0.0816, 0.0518, 0.0445, 0.0594, 0.0739, 0.0534, 0.0805, 0.0457, 0.0765]),
+ }
+ )
+ expected_slice = expected_slices.get_expectation()
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py
index bb21c9ac8dcb..b142c2baf957 100644
--- a/tests/pipelines/controlnet/test_controlnet.py
+++ b/tests/pipelines/controlnet/test_controlnet.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +15,6 @@
import gc
import tempfile
-import traceback
import unittest
import numpy as np
@@ -33,24 +32,20 @@
)
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
backend_reset_peak_memory_stats,
enable_full_determinism,
- get_python_version,
- is_torch_compile,
load_image,
load_numpy,
- require_torch_2,
require_torch_accelerator,
- run_test_in_subprocess,
slow,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_BATCH_PARAMS,
@@ -68,52 +63,6 @@
enable_full_determinism()
-# Will be run via run_test_in_subprocess
-def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
- error = None
- try:
- _ = in_queue.get(timeout=timeout)
-
- controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
-
- pipe = StableDiffusionControlNetPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
- )
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- pipe.unet.to(memory_format=torch.channels_last)
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-
- pipe.controlnet.to(memory_format=torch.channels_last)
- pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- prompt = "bird"
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
- ).resize((512, 512))
-
- output = pipe(prompt, image, num_inference_steps=10, generator=generator, output_type="np")
- image = output.images[0]
-
- assert image.shape == (512, 512, 3)
-
- expected_image = load_numpy(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out_full.npy"
- )
- expected_image = np.resize(expected_image, (512, 512, 3))
-
- assert np.abs(expected_image - image).max() < 1.0
-
- except Exception:
- error = f"{traceback.format_exc()}"
-
- results = {"error": error}
- out_queue.put(results, timeout=timeout)
- out_queue.join()
-
-
class ControlNetPipelineFastTests(
IPAdapterTesterMixin,
PipelineLatentTesterMixin,
@@ -1053,15 +1002,6 @@ def test_canny_guess_mode_euler(self):
expected_slice = np.array([0.1655, 0.1721, 0.1623, 0.1685, 0.1711, 0.1646, 0.1651, 0.1631, 0.1494])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- @is_torch_compile
- @require_torch_2
- @unittest.skipIf(
- get_python_version == (3, 12),
- reason="Torch Dynamo isn't yet supported for Python 3.12.",
- )
- def test_stable_diffusion_compile(self):
- run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=None)
-
def test_v11_shuffle_global_pool_conditions(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11e_sd15_shuffle")
diff --git a/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py b/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py
deleted file mode 100644
index eedda4e21722..000000000000
--- a/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py
+++ /dev/null
@@ -1,228 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import unittest
-
-import numpy as np
-import torch
-from PIL import Image
-from transformers import CLIPTokenizer
-from transformers.models.blip_2.configuration_blip_2 import Blip2Config
-from transformers.models.clip.configuration_clip import CLIPTextConfig
-
-from diffusers import (
- AutoencoderKL,
- BlipDiffusionControlNetPipeline,
- ControlNetModel,
- PNDMScheduler,
- UNet2DConditionModel,
-)
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
-from src.diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
-from src.diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
-from src.diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
-
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class BlipDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = BlipDiffusionControlNetPipeline
- params = [
- "prompt",
- "reference_image",
- "source_subject_category",
- "target_subject_category",
- "condtioning_image",
- ]
- batch_params = [
- "prompt",
- "reference_image",
- "source_subject_category",
- "target_subject_category",
- "condtioning_image",
- ]
- required_optional_params = [
- "generator",
- "height",
- "width",
- "latents",
- "guidance_scale",
- "num_inference_steps",
- "neg_prompt",
- "guidance_scale",
- "prompt_strength",
- "prompt_reps",
- ]
-
- supports_dduf = False
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- vocab_size=1000,
- hidden_size=16,
- intermediate_size=16,
- projection_dim=16,
- num_hidden_layers=1,
- num_attention_heads=1,
- max_position_embeddings=77,
- )
- text_encoder = ContextCLIPTextModel(text_encoder_config)
-
- vae = AutoencoderKL(
- in_channels=4,
- out_channels=4,
- down_block_types=("DownEncoderBlock2D",),
- up_block_types=("UpDecoderBlock2D",),
- block_out_channels=(32,),
- layers_per_block=1,
- act_fn="silu",
- latent_channels=4,
- norm_num_groups=16,
- sample_size=16,
- )
-
- blip_vision_config = {
- "hidden_size": 16,
- "intermediate_size": 16,
- "num_hidden_layers": 1,
- "num_attention_heads": 1,
- "image_size": 224,
- "patch_size": 14,
- "hidden_act": "quick_gelu",
- }
-
- blip_qformer_config = {
- "vocab_size": 1000,
- "hidden_size": 16,
- "num_hidden_layers": 1,
- "num_attention_heads": 1,
- "intermediate_size": 16,
- "max_position_embeddings": 512,
- "cross_attention_frequency": 1,
- "encoder_hidden_size": 16,
- }
- qformer_config = Blip2Config(
- vision_config=blip_vision_config,
- qformer_config=blip_qformer_config,
- num_query_tokens=16,
- tokenizer="hf-internal-testing/tiny-random-bert",
- )
- qformer = Blip2QFormerModel(qformer_config)
-
- unet = UNet2DConditionModel(
- block_out_channels=(4, 16),
- layers_per_block=1,
- norm_num_groups=4,
- sample_size=16,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=16,
- )
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- scheduler = PNDMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- set_alpha_to_one=False,
- skip_prk_steps=True,
- )
- controlnet = ControlNetModel(
- block_out_channels=(4, 16),
- layers_per_block=1,
- in_channels=4,
- norm_num_groups=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- cross_attention_dim=16,
- conditioning_embedding_out_channels=(8, 16),
- )
-
- vae.eval()
- qformer.eval()
- text_encoder.eval()
-
- image_processor = BlipImageProcessor()
-
- components = {
- "text_encoder": text_encoder,
- "vae": vae,
- "qformer": qformer,
- "unet": unet,
- "tokenizer": tokenizer,
- "scheduler": scheduler,
- "controlnet": controlnet,
- "image_processor": image_processor,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- np.random.seed(seed)
- reference_image = np.random.rand(32, 32, 3) * 255
- reference_image = Image.fromarray(reference_image.astype("uint8")).convert("RGBA")
- cond_image = np.random.rand(32, 32, 3) * 255
- cond_image = Image.fromarray(cond_image.astype("uint8")).convert("RGBA")
-
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "swimming underwater",
- "generator": generator,
- "reference_image": reference_image,
- "condtioning_image": cond_image,
- "source_subject_category": "dog",
- "target_subject_category": "dog",
- "height": 32,
- "width": 32,
- "guidance_scale": 7.5,
- "num_inference_steps": 2,
- "output_type": "np",
- }
- return inputs
-
- def test_dict_tuple_outputs_equivalent(self):
- expected_slice = None
- if torch_device == "cpu":
- expected_slice = np.array([0.4803, 0.3865, 0.1422, 0.6119, 0.2283, 0.6365, 0.5453, 0.5205, 0.3581])
- super().test_dict_tuple_outputs_equivalent(expected_slice=expected_slice)
-
- def test_blipdiffusion_controlnet(self):
- device = "cpu"
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- image = pipe(**self.get_dummy_inputs(device))[0]
- image_slice = image[0, -3:, -3:, 0]
-
- assert image.shape == (1, 16, 16, 4)
- expected_slice = np.array([0.7953, 0.7136, 0.6597, 0.4779, 0.7389, 0.4111, 0.5826, 0.4150, 0.8422])
-
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
-
- @unittest.skip("Test not supported because of complexities in deriving query_embeds.")
- def test_encode_prompt_works_in_isolation(self):
- pass
diff --git a/tests/pipelines/controlnet/test_controlnet_img2img.py b/tests/pipelines/controlnet/test_controlnet_img2img.py
index 100765ee34cb..c5d438e93427 100644
--- a/tests/pipelines/controlnet/test_controlnet_img2img.py
+++ b/tests/pipelines/controlnet/test_controlnet_img2img.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -35,7 +35,10 @@
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
from diffusers.utils import load_image
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_numpy,
@@ -43,8 +46,6 @@
slow,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
@@ -412,12 +413,12 @@ class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_canny(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py
index b06590e13cb6..ebbe869e9e5e 100644
--- a/tests/pipelines/controlnet/test_controlnet_inpaint.py
+++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -35,7 +35,10 @@
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
from diffusers.utils import load_image
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_numpy,
@@ -44,8 +47,6 @@
slow,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
@@ -464,12 +465,12 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_canny(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py b/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py
index ca05db504485..c91f2c700c15 100644
--- a/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py
+++ b/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 Harutatsu Akiyama, Jinbin Bai, and HuggingFace Inc.
+# Copyright 2025 Harutatsu Akiyama, Jinbin Bai, and HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -37,13 +37,13 @@
UNet2DConditionModel,
)
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
require_torch_accelerator,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_BATCH_PARAMS,
diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py
index 503db2f574e2..42ec446dbfae 100644
--- a/tests/pipelines/controlnet/test_controlnet_sdxl.py
+++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -34,7 +34,9 @@
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
load_image,
@@ -42,8 +44,6 @@
slow,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_BATCH_PARAMS,
diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py b/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py
index bf5da16fcbb8..bd4a233741e8 100644
--- a/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py
+++ b/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -28,13 +28,13 @@
UNet2DConditionModel,
)
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
require_torch_accelerator,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
diff --git a/tests/pipelines/controlnet/test_flax_controlnet.py b/tests/pipelines/controlnet/test_flax_controlnet.py
deleted file mode 100644
index c71116dc7927..000000000000
--- a/tests/pipelines/controlnet/test_flax_controlnet.py
+++ /dev/null
@@ -1,127 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
-from diffusers.utils import is_flax_available, load_image
-from diffusers.utils.testing_utils import require_flax, slow
-
-
-if is_flax_available():
- import jax
- import jax.numpy as jnp
- from flax.jax_utils import replicate
- from flax.training.common_utils import shard
-
-
-@slow
-@require_flax
-class FlaxControlNetPipelineIntegrationTests(unittest.TestCase):
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
-
- def test_canny(self):
- controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
- "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.bfloat16
- )
- pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16
- )
- params["controlnet"] = controlnet_params
-
- prompts = "bird"
- num_samples = jax.device_count()
- prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
-
- canny_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
- )
- processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
-
- rng = jax.random.PRNGKey(0)
- rng = jax.random.split(rng, jax.device_count())
-
- p_params = replicate(params)
- prompt_ids = shard(prompt_ids)
- processed_image = shard(processed_image)
-
- images = pipe(
- prompt_ids=prompt_ids,
- image=processed_image,
- params=p_params,
- prng_seed=rng,
- num_inference_steps=50,
- jit=True,
- ).images
- assert images.shape == (jax.device_count(), 1, 768, 512, 3)
-
- images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
- image_slice = images[0, 253:256, 253:256, -1]
-
- output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
- expected_slice = jnp.array(
- [0.167969, 0.116699, 0.081543, 0.154297, 0.132812, 0.108887, 0.169922, 0.169922, 0.205078]
- )
-
- assert jnp.abs(output_slice - expected_slice).max() < 1e-2
-
- def test_pose(self):
- controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
- "lllyasviel/sd-controlnet-openpose", from_pt=True, dtype=jnp.bfloat16
- )
- pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16
- )
- params["controlnet"] = controlnet_params
-
- prompts = "Chef in the kitchen"
- num_samples = jax.device_count()
- prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
-
- pose_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose.png"
- )
- processed_image = pipe.prepare_image_inputs([pose_image] * num_samples)
-
- rng = jax.random.PRNGKey(0)
- rng = jax.random.split(rng, jax.device_count())
-
- p_params = replicate(params)
- prompt_ids = shard(prompt_ids)
- processed_image = shard(processed_image)
-
- images = pipe(
- prompt_ids=prompt_ids,
- image=processed_image,
- params=p_params,
- prng_seed=rng,
- num_inference_steps=50,
- jit=True,
- ).images
- assert images.shape == (jax.device_count(), 1, 768, 512, 3)
-
- images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
- image_slice = images[0, 253:256, 253:256, -1]
-
- output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
- expected_slice = jnp.array(
- [[0.271484, 0.261719, 0.275391, 0.277344, 0.279297, 0.291016, 0.294922, 0.302734, 0.302734]]
- )
-
- assert jnp.abs(output_slice - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux.py b/tests/pipelines/controlnet_flux/test_controlnet_flux.py
index 9a270c2bbf07..0895d9de3581 100644
--- a/tests/pipelines/controlnet_flux/test_controlnet_flux.py
+++ b/tests/pipelines/controlnet_flux/test_controlnet_flux.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc and The InstantX Team.
+# Copyright 2025 HuggingFace Inc and The InstantX Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@
import unittest
import numpy as np
-import pytest
import torch
from huggingface_hub import hf_hub_download
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
@@ -30,16 +29,16 @@
)
from diffusers.models import FluxControlNetModel
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
nightly,
numpy_cosine_similarity_distance,
- require_big_gpu_with_torch_cuda,
+ require_big_accelerator,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin
@@ -178,9 +177,9 @@ def test_controlnet_flux(self):
[0.47387695, 0.63134766, 0.5605469, 0.61621094, 0.7207031, 0.7089844, 0.70410156, 0.6113281, 0.64160156]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ )
@unittest.skip("xFormersAttnProcessor does not work with SD3 Joint Attention")
def test_xformers_attention_forwardGenerator_pass(self):
@@ -210,8 +209,7 @@ def test_flux_image_output_shape(self):
@nightly
-@require_big_gpu_with_torch_cuda
-@pytest.mark.big_gpu_with_torch_cuda
+@require_big_accelerator
class FluxControlNetPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxControlNetPipeline
diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
index 59ccb9237819..3d8378a5786d 100644
--- a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
+++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
@@ -11,16 +11,12 @@
FluxControlNetModel,
FluxTransformer2DModel,
)
-from diffusers.utils.testing_utils import (
- torch_device,
-)
from diffusers.utils.torch_utils import randn_tensor
-from ..test_pipelines_common import (
- PipelineTesterMixin,
- check_qkv_fusion_matches_attn_procs_length,
- check_qkv_fusion_processors_exist,
+from ...testing_utils import (
+ torch_device,
)
+from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -170,12 +166,10 @@ def test_fused_qkv_projections(self):
original_image_slice = image[0, -3:, -3:, -1]
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- assert check_qkv_fusion_matches_attn_procs_length(
- pipe.transformer, pipe.transformer.original_attn_processors
- ), "Something wrong with the attention processors concerning the fused QKV projections."
+ self.assertTrue(
+ check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+ ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
+ )
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
@@ -186,15 +180,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py
index 94d97e9962b7..3ba475deb8a8 100644
--- a/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py
+++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py
@@ -20,13 +20,13 @@
FluxControlNetModel,
FluxTransformer2DModel,
)
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py b/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py
index f7b3db05c8af..bf31f2abcffb 100644
--- a/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py
+++ b/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc and Tencent Hunyuan Team.
+# Copyright 2025 HuggingFace Inc and Tencent Hunyuan Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -28,15 +28,15 @@
)
from diffusers.models import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..test_pipelines_common import PipelineTesterMixin
@@ -155,16 +155,16 @@ def test_controlnet_hunyuandit(self):
if torch_device == "xpu":
expected_slice = np.array(
- [0.6376953, 0.84375, 0.58691406, 0.48046875, 0.43652344, 0.5517578, 0.54248047, 0.5644531, 0.48217773]
+ [0.6948242, 0.89160156, 0.59375, 0.5078125, 0.57910156, 0.6035156, 0.58447266, 0.53564453, 0.52246094]
)
else:
expected_slice = np.array(
[0.6953125, 0.89208984, 0.59375, 0.5078125, 0.5786133, 0.6035156, 0.5839844, 0.53564453, 0.52246094]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ )
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(
diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py
index 2cd57ce56d52..34c34b7a2ce7 100644
--- a/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py
+++ b/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -26,12 +26,12 @@
StableDiffusion3ControlNetInpaintingPipeline,
)
from diffusers.models import SD3ControlNetModel
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..test_pipelines_common import PipelineTesterMixin
@@ -194,9 +194,9 @@ def test_controlnet_inpaint_sd3(self):
[0.51708984, 0.7421875, 0.4580078, 0.6435547, 0.65625, 0.43603516, 0.5151367, 0.65722656, 0.60839844]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ )
@unittest.skip("xFormersAttnProcessor does not work with SD3 Joint Attention")
def test_xformers_attention_forwardGenerator_pass(self):
diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
index 84ce09acbe1a..2b6cf8d1e8be 100644
--- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
+++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc and The InstantX Team.
+# Copyright 2025 HuggingFace Inc and The InstantX Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,6 @@
from typing import Optional
import numpy as np
-import pytest
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -30,7 +29,9 @@
)
from diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -38,8 +39,6 @@
slow,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..test_pipelines_common import PipelineTesterMixin
@@ -202,9 +201,9 @@ def run_pipe(self, components, use_sd35=False):
else:
expected_slice = np.array([1.0000, 0.9072, 0.4209, 0.2744, 0.5737, 0.3840, 0.6113, 0.6250, 0.6328])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f"Expected: {expected_slice}, got: {image_slice.flatten()}"
+ )
def test_controlnet_sd3(self):
components = self.get_dummy_components()
@@ -221,7 +220,6 @@ def test_xformers_attention_forwardGenerator_pass(self):
@slow
@require_big_accelerator
-@pytest.mark.big_gpu_with_torch_cuda
class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3ControlNetPipeline
diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py
deleted file mode 100644
index 74af4b6775cc..000000000000
--- a/tests/pipelines/controlnet_xs/test_controlnetxs.py
+++ /dev/null
@@ -1,409 +0,0 @@
-# coding=utf-8
-# Copyright 2023 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import traceback
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AsymmetricAutoencoderKL,
- AutoencoderKL,
- AutoencoderTiny,
- ConsistencyDecoderVAE,
- ControlNetXSAdapter,
- DDIMScheduler,
- LCMScheduler,
- StableDiffusionControlNetXSPipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- is_torch_compile,
- load_image,
- load_numpy,
- require_accelerator,
- require_torch_2,
- require_torch_accelerator,
- run_test_in_subprocess,
- slow,
- torch_device,
-)
-from diffusers.utils.torch_utils import randn_tensor
-
-from ...models.autoencoders.vae import (
- get_asym_autoencoder_kl_config,
- get_autoencoder_kl_config,
- get_autoencoder_tiny_config,
- get_consistency_vae_config,
-)
-from ..pipeline_params import (
- IMAGE_TO_IMAGE_IMAGE_PARAMS,
- TEXT_TO_IMAGE_BATCH_PARAMS,
- TEXT_TO_IMAGE_IMAGE_PARAMS,
- TEXT_TO_IMAGE_PARAMS,
-)
-from ..test_pipelines_common import (
- PipelineKarrasSchedulerTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
- SDFunctionTesterMixin,
-)
-
-
-enable_full_determinism()
-
-
-def to_np(tensor):
- if isinstance(tensor, torch.Tensor):
- tensor = tensor.detach().cpu().numpy()
-
- return tensor
-
-
-# Will be run via run_test_in_subprocess
-def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
- error = None
- try:
- _ = in_queue.get(timeout=timeout)
-
- controlnet = ControlNetXSAdapter.from_pretrained(
- "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
- )
- pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
- "stabilityai/stable-diffusion-2-1-base",
- controlnet=controlnet,
- safety_checker=None,
- torch_dtype=torch.float16,
- )
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- pipe.unet.to(memory_format=torch.channels_last)
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- prompt = "bird"
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
- ).resize((512, 512))
-
- output = pipe(prompt, image, num_inference_steps=10, generator=generator, output_type="np")
- image = output.images[0]
-
- assert image.shape == (512, 512, 3)
-
- expected_image = load_numpy(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out_full.npy"
- )
- expected_image = np.resize(expected_image, (512, 512, 3))
-
- assert np.abs(expected_image - image).max() < 1.0
-
- except Exception:
- error = f"{traceback.format_exc()}"
-
- results = {"error": error}
- out_queue.put(results, timeout=timeout)
- out_queue.join()
-
-
-class ControlNetXSPipelineFastTests(
- PipelineLatentTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineTesterMixin,
- SDFunctionTesterMixin,
- unittest.TestCase,
-):
- pipeline_class = StableDiffusionControlNetXSPipeline
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- test_attention_slicing = False
- test_layerwise_casting = True
- test_group_offloading = True
-
- def get_dummy_components(self, time_cond_proj_dim=None):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(4, 8),
- layers_per_block=2,
- sample_size=16,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=8,
- norm_num_groups=4,
- time_cond_proj_dim=time_cond_proj_dim,
- use_linear_projection=True,
- )
- torch.manual_seed(0)
- controlnet = ControlNetXSAdapter.from_unet(
- unet=unet,
- size_ratio=1,
- learn_time_embedding=True,
- conditioning_embedding_out_channels=(2, 2),
- )
- torch.manual_seed(0)
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[4, 8],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- norm_num_groups=2,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=8,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "controlnet": controlnet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
-
- controlnet_embedder_scale_factor = 2
- image = randn_tensor(
- (1, 3, 8 * controlnet_embedder_scale_factor, 8 * controlnet_embedder_scale_factor),
- generator=generator,
- device=torch.device(device),
- )
-
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "numpy",
- "image": image,
- }
-
- return inputs
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
-
- def test_inference_batch_single_identical(self):
- self._test_inference_batch_single_identical(expected_max_diff=2e-3)
-
- def test_controlnet_lcm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components(time_cond_proj_dim=8)
- sd_pipe = StableDiffusionControlNetXSPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = sd_pipe(**inputs)
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 16, 16, 3)
- expected_slice = np.array([0.745, 0.753, 0.767, 0.543, 0.523, 0.502, 0.314, 0.521, 0.478])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_to_dtype(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
-
- # pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the dtype from pipe.components
- model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
- self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
-
- pipe.to(dtype=torch.float16)
- model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
- self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
-
- def test_multi_vae(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- block_out_channels = pipe.vae.config.block_out_channels
- norm_num_groups = pipe.vae.config.norm_num_groups
-
- vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny]
- configs = [
- get_autoencoder_kl_config(block_out_channels, norm_num_groups),
- get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups),
- get_consistency_vae_config(block_out_channels, norm_num_groups),
- get_autoencoder_tiny_config(block_out_channels),
- ]
-
- out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
-
- for vae_cls, config in zip(vae_classes, configs):
- vae = vae_cls(**config)
- vae = vae.to(torch_device)
- components["vae"] = vae
- vae_pipe = self.pipeline_class(**components)
-
- # pipeline creates a new UNetControlNetXSModel under the hood, which aren't on device.
- # So we need to move the new pipe to device.
- vae_pipe.to(torch_device)
- vae_pipe.set_progress_bar_config(disable=None)
-
- out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
-
- assert out_vae_np.shape == out_np.shape
-
- @require_accelerator
- def test_to_device(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
-
- pipe.to("cpu")
- # pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the device from pipe.components
- model_devices = [
- component.device.type for component in pipe.components.values() if hasattr(component, "device")
- ]
- self.assertTrue(all(device == "cpu" for device in model_devices))
-
- output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
- self.assertTrue(np.isnan(output_cpu).sum() == 0)
-
- pipe.to(torch_device)
- model_devices = [
- component.device.type for component in pipe.components.values() if hasattr(component, "device")
- ]
- self.assertTrue(all(device == torch_device for device in model_devices))
-
- output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
- self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
-
-
-@slow
-@require_torch_accelerator
-class ControlNetXSPipelineSlowTests(unittest.TestCase):
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_canny(self):
- controlnet = ControlNetXSAdapter.from_pretrained(
- "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
- )
- pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
- "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
- )
- pipe.enable_model_cpu_offload(device=torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- prompt = "bird"
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
- )
-
- output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
-
- image = output.images[0]
-
- assert image.shape == (768, 512, 3)
-
- original_image = image[-3:, -3:, -1].flatten()
- expected_image = np.array([0.1963, 0.229, 0.2659, 0.2109, 0.2332, 0.2827, 0.2534, 0.2422, 0.2808])
- assert np.allclose(original_image, expected_image, atol=1e-04)
-
- def test_depth(self):
- controlnet = ControlNetXSAdapter.from_pretrained(
- "UmerHA/Testing-ConrolNetXS-SD2.1-depth", torch_dtype=torch.float16
- )
- pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
- "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
- )
- pipe.enable_model_cpu_offload(device=torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- prompt = "Stormtrooper's lecture"
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
- )
-
- output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
-
- image = output.images[0]
-
- assert image.shape == (512, 512, 3)
-
- original_image = image[-3:, -3:, -1].flatten()
- expected_image = np.array([0.4844, 0.4937, 0.4956, 0.4663, 0.5039, 0.5044, 0.4565, 0.4883, 0.4941])
- assert np.allclose(original_image, expected_image, atol=1e-04)
-
- @is_torch_compile
- @require_torch_2
- def test_stable_diffusion_compile(self):
- run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=None)
diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py
deleted file mode 100644
index 24a8b9cd5739..000000000000
--- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py
+++ /dev/null
@@ -1,393 +0,0 @@
-# coding=utf-8
-# Copyright 2023 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
-
-from diffusers import (
- AsymmetricAutoencoderKL,
- AutoencoderKL,
- AutoencoderTiny,
- ConsistencyDecoderVAE,
- ControlNetXSAdapter,
- EulerDiscreteScheduler,
- StableDiffusionXLControlNetXSPipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- load_image,
- require_torch_accelerator,
- slow,
- torch_device,
-)
-from diffusers.utils.torch_utils import randn_tensor
-
-from ...models.autoencoders.vae import (
- get_asym_autoencoder_kl_config,
- get_autoencoder_kl_config,
- get_autoencoder_tiny_config,
- get_consistency_vae_config,
-)
-from ..pipeline_params import (
- IMAGE_TO_IMAGE_IMAGE_PARAMS,
- TEXT_TO_IMAGE_BATCH_PARAMS,
- TEXT_TO_IMAGE_IMAGE_PARAMS,
- TEXT_TO_IMAGE_PARAMS,
-)
-from ..test_pipelines_common import (
- PipelineKarrasSchedulerTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
-)
-
-
-enable_full_determinism()
-
-
-class StableDiffusionXLControlNetXSPipelineFastTests(
- PipelineLatentTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineTesterMixin,
- unittest.TestCase,
-):
- pipeline_class = StableDiffusionXLControlNetXSPipeline
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- test_attention_slicing = False
- test_layerwise_casting = True
- test_group_offloading = True
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(4, 8),
- layers_per_block=2,
- sample_size=16,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- use_linear_projection=True,
- norm_num_groups=4,
- # SD2-specific config below
- attention_head_dim=(2, 4),
- addition_embed_type="text_time",
- addition_time_embed_dim=8,
- transformer_layers_per_block=(1, 2),
- projection_class_embeddings_input_dim=56, # 6 * 8 (addition_time_embed_dim) + 8 (cross_attention_dim)
- cross_attention_dim=8,
- )
- torch.manual_seed(0)
- controlnet = ControlNetXSAdapter.from_unet(
- unet=unet,
- size_ratio=0.5,
- learn_time_embedding=True,
- conditioning_embedding_out_channels=(2, 2),
- )
- torch.manual_seed(0)
- scheduler = EulerDiscreteScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- steps_offset=1,
- beta_schedule="scaled_linear",
- timestep_spacing="leading",
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[4, 8],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- norm_num_groups=2,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=4,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- # SD2-specific config below
- hidden_act="gelu",
- projection_dim=8,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
- tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "controlnet": controlnet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "text_encoder_2": text_encoder_2,
- "tokenizer_2": tokenizer_2,
- "feature_extractor": None,
- }
- return components
-
- # Copied from test_controlnet_sdxl.py
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
-
- controlnet_embedder_scale_factor = 2
- image = randn_tensor(
- (1, 3, 8 * controlnet_embedder_scale_factor, 8 * controlnet_embedder_scale_factor),
- generator=generator,
- device=torch.device(device),
- )
-
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "np",
- "image": image,
- }
-
- return inputs
-
- # Copied from test_controlnet_sdxl.py
- def test_attention_slicing_forward_pass(self):
- return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- # Copied from test_controlnet_sdxl.py
- def test_xformers_attention_forwardGenerator_pass(self):
- self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
-
- # Copied from test_controlnet_sdxl.py
- def test_inference_batch_single_identical(self):
- self._test_inference_batch_single_identical(expected_max_diff=2e-3)
-
- @unittest.skip("We test this functionality elsewhere already.")
- def test_save_load_optional_components(self):
- pass
-
- @require_torch_accelerator
- # Copied from test_controlnet_sdxl.py
- def test_stable_diffusion_xl_offloads(self):
- pipes = []
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components).to(torch_device)
- pipes.append(sd_pipe)
-
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_model_cpu_offload(device=torch_device)
- pipes.append(sd_pipe)
-
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_sequential_cpu_offload(device=torch_device)
- pipes.append(sd_pipe)
-
- image_slices = []
- for pipe in pipes:
- pipe.unet.set_default_attn_processor()
-
- inputs = self.get_dummy_inputs(torch_device)
- image = pipe(**inputs).images
-
- image_slices.append(image[0, -3:, -3:, -1].flatten())
-
- assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
- assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
-
- # Copied from test_controlnet_sdxl.py
- def test_stable_diffusion_xl_multi_prompts(self):
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components).to(torch_device)
-
- # forward with single prompt
- inputs = self.get_dummy_inputs(torch_device)
- output = sd_pipe(**inputs)
- image_slice_1 = output.images[0, -3:, -3:, -1]
-
- # forward with same prompt duplicated
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt_2"] = inputs["prompt"]
- output = sd_pipe(**inputs)
- image_slice_2 = output.images[0, -3:, -3:, -1]
-
- # ensure the results are equal
- assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
-
- # forward with different prompt
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt_2"] = "different prompt"
- output = sd_pipe(**inputs)
- image_slice_3 = output.images[0, -3:, -3:, -1]
-
- # ensure the results are not equal
- assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
-
- # manually set a negative_prompt
- inputs = self.get_dummy_inputs(torch_device)
- inputs["negative_prompt"] = "negative prompt"
- output = sd_pipe(**inputs)
- image_slice_1 = output.images[0, -3:, -3:, -1]
-
- # forward with same negative_prompt duplicated
- inputs = self.get_dummy_inputs(torch_device)
- inputs["negative_prompt"] = "negative prompt"
- inputs["negative_prompt_2"] = inputs["negative_prompt"]
- output = sd_pipe(**inputs)
- image_slice_2 = output.images[0, -3:, -3:, -1]
-
- # ensure the results are equal
- assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
-
- # forward with different negative_prompt
- inputs = self.get_dummy_inputs(torch_device)
- inputs["negative_prompt"] = "negative prompt"
- inputs["negative_prompt_2"] = "different negative prompt"
- output = sd_pipe(**inputs)
- image_slice_3 = output.images[0, -3:, -3:, -1]
-
- # ensure the results are not equal
- assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
-
- # Copied from test_controlnetxs.py
- def test_to_dtype(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
-
- # pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the dtype from pipe.components
- model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
- self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
-
- pipe.to(dtype=torch.float16)
- model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
- self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
-
- def test_multi_vae(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- block_out_channels = pipe.vae.config.block_out_channels
- norm_num_groups = pipe.vae.config.norm_num_groups
-
- vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny]
- configs = [
- get_autoencoder_kl_config(block_out_channels, norm_num_groups),
- get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups),
- get_consistency_vae_config(block_out_channels, norm_num_groups),
- get_autoencoder_tiny_config(block_out_channels),
- ]
-
- out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
-
- for vae_cls, config in zip(vae_classes, configs):
- vae = vae_cls(**config)
- vae = vae.to(torch_device)
- components["vae"] = vae
- vae_pipe = self.pipeline_class(**components)
-
- # pipeline creates a new UNetControlNetXSModel under the hood, which aren't on device.
- # So we need to move the new pipe to device.
- vae_pipe.to(torch_device)
- vae_pipe.set_progress_bar_config(disable=None)
-
- out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
-
- assert out_vae_np.shape == out_np.shape
-
-
-@slow
-@require_torch_accelerator
-class StableDiffusionXLControlNetXSPipelineSlowTests(unittest.TestCase):
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_canny(self):
- controlnet = ControlNetXSAdapter.from_pretrained(
- "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16
- )
- pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
- )
- pipe.enable_sequential_cpu_offload(device=torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- prompt = "bird"
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
- )
-
- images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
-
- assert images[0].shape == (768, 512, 3)
-
- original_image = images[0, -3:, -3:, -1].flatten()
- expected_image = np.array([0.3202, 0.3151, 0.3328, 0.3172, 0.337, 0.3381, 0.3378, 0.3389, 0.3224])
- assert np.allclose(original_image, expected_image, atol=1e-04)
-
- def test_depth(self):
- controlnet = ControlNetXSAdapter.from_pretrained(
- "UmerHA/Testing-ConrolNetXS-SDXL-depth", torch_dtype=torch.float16
- )
- pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
- )
- pipe.enable_sequential_cpu_offload(device=torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- prompt = "Stormtrooper's lecture"
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
- )
-
- images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
-
- assert images[0].shape == (512, 512, 3)
-
- original_image = images[0, -3:, -3:, -1].flatten()
- expected_image = np.array([0.5448, 0.5437, 0.5426, 0.5543, 0.553, 0.5475, 0.5595, 0.5602, 0.5529])
- assert np.allclose(original_image, expected_image, atol=1e-04)
diff --git a/tests/pipelines/stable_diffusion_gligen_text_image/__init__.py b/tests/pipelines/cosmos/__init__.py
similarity index 100%
rename from tests/pipelines/stable_diffusion_gligen_text_image/__init__.py
rename to tests/pipelines/cosmos/__init__.py
diff --git a/tests/pipelines/cosmos/cosmos_guardrail.py b/tests/pipelines/cosmos/cosmos_guardrail.py
new file mode 100644
index 000000000000..4de14fbaaf9d
--- /dev/null
+++ b/tests/pipelines/cosmos/cosmos_guardrail.py
@@ -0,0 +1,47 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# ===== This file is an implementation of a dummy guardrail for the fast tests =====
+
+from typing import Union
+
+import numpy as np
+import torch
+
+from diffusers.configuration_utils import ConfigMixin
+from diffusers.models.modeling_utils import ModelMixin
+
+
+class DummyCosmosSafetyChecker(ModelMixin, ConfigMixin):
+ def __init__(self) -> None:
+ super().__init__()
+
+ self._dtype = torch.float32
+
+ def check_text_safety(self, prompt: str) -> bool:
+ return True
+
+ def check_video_safety(self, frames: np.ndarray) -> np.ndarray:
+ return frames
+
+ def to(self, device: Union[str, torch.device] = None, dtype: torch.dtype = None) -> None:
+ self._dtype = dtype
+
+ @property
+ def device(self) -> torch.device:
+ return None
+
+ @property
+ def dtype(self) -> torch.dtype:
+ return self._dtype
diff --git a/tests/pipelines/cosmos/test_cosmos.py b/tests/pipelines/cosmos/test_cosmos.py
new file mode 100644
index 000000000000..32eea9c98c2c
--- /dev/null
+++ b/tests/pipelines/cosmos/test_cosmos.py
@@ -0,0 +1,354 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import json
+import os
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLCosmos, CosmosTextToWorldPipeline, CosmosTransformer3DModel, EDMEulerScheduler
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+from .cosmos_guardrail import DummyCosmosSafetyChecker
+
+
+enable_full_determinism()
+
+
+class CosmosTextToWorldPipelineWrapper(CosmosTextToWorldPipeline):
+ @staticmethod
+ def from_pretrained(*args, **kwargs):
+ kwargs["safety_checker"] = DummyCosmosSafetyChecker()
+ return CosmosTextToWorldPipeline.from_pretrained(*args, **kwargs)
+
+
+class CosmosTextToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = CosmosTextToWorldPipelineWrapper
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ supports_dduf = False
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = CosmosTransformer3DModel(
+ in_channels=4,
+ out_channels=4,
+ num_attention_heads=2,
+ attention_head_dim=16,
+ num_layers=2,
+ mlp_ratio=2,
+ text_embed_dim=32,
+ adaln_lora_dim=4,
+ max_size=(4, 32, 32),
+ patch_size=(1, 2, 2),
+ rope_scale=(2.0, 1.0, 1.0),
+ concat_padding_mask=True,
+ extra_pos_embed_type="learnable",
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLCosmos(
+ in_channels=3,
+ out_channels=3,
+ latent_channels=4,
+ encoder_block_out_channels=(8, 8, 8, 8),
+ decode_block_out_channels=(8, 8, 8, 8),
+ attention_resolutions=(8,),
+ resolution=64,
+ num_layers=2,
+ patch_size=4,
+ patch_type="haar",
+ scaling_factor=1.0,
+ spatial_compression_ratio=4,
+ temporal_compression_ratio=4,
+ )
+
+ torch.manual_seed(0)
+ scheduler = EDMEulerScheduler(
+ sigma_min=0.002,
+ sigma_max=80,
+ sigma_data=0.5,
+ sigma_schedule="karras",
+ num_train_timesteps=1000,
+ prediction_type="epsilon",
+ rho=7.0,
+ final_sigmas_type="sigma_min",
+ )
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ # We cannot run the Cosmos Guardrail for fast tests due to the large model size
+ "safety_checker": DummyCosmosSafetyChecker(),
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "bad quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 3.0,
+ "height": 32,
+ "width": 32,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+ self.assertEqual(generated_video.shape, (9, 3, 32, 32))
+
+ # fmt: off
+ expected_slice = torch.tensor([0.0, 0.9686, 0.8549, 0.8078, 0.0, 0.8431, 1.0, 0.4863, 0.7098, 0.1098, 0.8157, 0.4235, 0.6353, 0.2549, 0.5137, 0.5333])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+ # Test passing in a everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-2)
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ def test_save_load_optional_components(self, expected_max_difference=1e-4):
+ self.pipeline_class._optional_components.remove("safety_checker")
+ super().test_save_load_optional_components(expected_max_difference=expected_max_difference)
+ self.pipeline_class._optional_components.append("safety_checker")
+
+ def test_serialization_with_variants(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ model_components = [
+ component_name
+ for component_name, component in pipe.components.items()
+ if isinstance(component, torch.nn.Module)
+ ]
+ model_components.remove("safety_checker")
+ variant = "fp16"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)
+
+ with open(f"{tmpdir}/model_index.json", "r") as f:
+ config = json.load(f)
+
+ for subfolder in os.listdir(tmpdir):
+ if not os.path.isfile(subfolder) and subfolder in model_components:
+ folder_path = os.path.join(tmpdir, subfolder)
+ is_folder = os.path.isdir(folder_path) and subfolder in config
+ assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))
+
+ def test_torch_dtype_dict(self):
+ components = self.get_dummy_components()
+ if not components:
+ self.skipTest("No dummy components defined.")
+
+ pipe = self.pipeline_class(**components)
+
+ specified_key = next(iter(components.keys()))
+
+ with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
+ pipe.save_pretrained(tmpdirname, safe_serialization=False)
+ torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
+ loaded_pipe = self.pipeline_class.from_pretrained(
+ tmpdirname, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=torch_dtype_dict
+ )
+
+ for name, component in loaded_pipe.components.items():
+ if name == "safety_checker":
+ continue
+ if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
+ expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
+ self.assertEqual(
+ component.dtype,
+ expected_dtype,
+ f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
+ )
+
+ @unittest.skip(
+ "The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in "
+ "a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is "
+ "too large and slow to run on CI."
+ )
+ def test_encode_prompt_works_in_isolation(self):
+ pass
diff --git a/tests/pipelines/cosmos/test_cosmos2_text2image.py b/tests/pipelines/cosmos/test_cosmos2_text2image.py
new file mode 100644
index 000000000000..8e3c5e4c29f4
--- /dev/null
+++ b/tests/pipelines/cosmos/test_cosmos2_text2image.py
@@ -0,0 +1,341 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import json
+import os
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKLWan,
+ Cosmos2TextToImagePipeline,
+ CosmosTransformer3DModel,
+ FlowMatchEulerDiscreteScheduler,
+)
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+from .cosmos_guardrail import DummyCosmosSafetyChecker
+
+
+enable_full_determinism()
+
+
+class Cosmos2TextToImagePipelineWrapper(Cosmos2TextToImagePipeline):
+ @staticmethod
+ def from_pretrained(*args, **kwargs):
+ kwargs["safety_checker"] = DummyCosmosSafetyChecker()
+ return Cosmos2TextToImagePipeline.from_pretrained(*args, **kwargs)
+
+
+class Cosmos2TextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = Cosmos2TextToImagePipelineWrapper
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ supports_dduf = False
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = CosmosTransformer3DModel(
+ in_channels=16,
+ out_channels=16,
+ num_attention_heads=2,
+ attention_head_dim=16,
+ num_layers=2,
+ mlp_ratio=2,
+ text_embed_dim=32,
+ adaln_lora_dim=4,
+ max_size=(4, 32, 32),
+ patch_size=(1, 2, 2),
+ rope_scale=(2.0, 1.0, 1.0),
+ concat_padding_mask=True,
+ extra_pos_embed_type="learnable",
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ # We cannot run the Cosmos Guardrail for fast tests due to the large model size
+ "safety_checker": DummyCosmosSafetyChecker(),
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "bad quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 3.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+
+ # fmt: off
+ expected_slice = torch.tensor([0.451, 0.451, 0.4471, 0.451, 0.451, 0.451, 0.451, 0.451, 0.4784, 0.4784, 0.4784, 0.4784, 0.4784, 0.4902, 0.4588, 0.5333])
+ # fmt: on
+
+ generated_slice = generated_image.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+ # Test passing in a everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-2)
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ def test_save_load_optional_components(self, expected_max_difference=1e-4):
+ self.pipeline_class._optional_components.remove("safety_checker")
+ super().test_save_load_optional_components(expected_max_difference=expected_max_difference)
+ self.pipeline_class._optional_components.append("safety_checker")
+
+ def test_serialization_with_variants(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ model_components = [
+ component_name
+ for component_name, component in pipe.components.items()
+ if isinstance(component, torch.nn.Module)
+ ]
+ model_components.remove("safety_checker")
+ variant = "fp16"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)
+
+ with open(f"{tmpdir}/model_index.json", "r") as f:
+ config = json.load(f)
+
+ for subfolder in os.listdir(tmpdir):
+ if not os.path.isfile(subfolder) and subfolder in model_components:
+ folder_path = os.path.join(tmpdir, subfolder)
+ is_folder = os.path.isdir(folder_path) and subfolder in config
+ assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))
+
+ def test_torch_dtype_dict(self):
+ components = self.get_dummy_components()
+ if not components:
+ self.skipTest("No dummy components defined.")
+
+ pipe = self.pipeline_class(**components)
+
+ specified_key = next(iter(components.keys()))
+
+ with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
+ pipe.save_pretrained(tmpdirname, safe_serialization=False)
+ torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
+ loaded_pipe = self.pipeline_class.from_pretrained(
+ tmpdirname, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=torch_dtype_dict
+ )
+
+ for name, component in loaded_pipe.components.items():
+ if name == "safety_checker":
+ continue
+ if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
+ expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
+ self.assertEqual(
+ component.dtype,
+ expected_dtype,
+ f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
+ )
+
+ @unittest.skip(
+ "The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in "
+ "a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is "
+ "too large and slow to run on CI."
+ )
+ def test_encode_prompt_works_in_isolation(self):
+ pass
diff --git a/tests/pipelines/cosmos/test_cosmos2_video2world.py b/tests/pipelines/cosmos/test_cosmos2_video2world.py
new file mode 100644
index 000000000000..b0ca0e160d98
--- /dev/null
+++ b/tests/pipelines/cosmos/test_cosmos2_video2world.py
@@ -0,0 +1,355 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import json
+import os
+import tempfile
+import unittest
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKLWan,
+ Cosmos2VideoToWorldPipeline,
+ CosmosTransformer3DModel,
+ FlowMatchEulerDiscreteScheduler,
+)
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+from .cosmos_guardrail import DummyCosmosSafetyChecker
+
+
+enable_full_determinism()
+
+
+class Cosmos2VideoToWorldPipelineWrapper(Cosmos2VideoToWorldPipeline):
+ @staticmethod
+ def from_pretrained(*args, **kwargs):
+ kwargs["safety_checker"] = DummyCosmosSafetyChecker()
+ return Cosmos2VideoToWorldPipeline.from_pretrained(*args, **kwargs)
+
+
+class Cosmos2VideoToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = Cosmos2VideoToWorldPipelineWrapper
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image", "video"})
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ supports_dduf = False
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = CosmosTransformer3DModel(
+ in_channels=16 + 1,
+ out_channels=16,
+ num_attention_heads=2,
+ attention_head_dim=16,
+ num_layers=2,
+ mlp_ratio=2,
+ text_embed_dim=32,
+ adaln_lora_dim=4,
+ max_size=(4, 32, 32),
+ patch_size=(1, 2, 2),
+ rope_scale=(2.0, 1.0, 1.0),
+ concat_padding_mask=True,
+ extra_pos_embed_type="learnable",
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler(use_karras_sigmas=True)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ # We cannot run the Cosmos Guardrail for fast tests due to the large model size
+ "safety_checker": DummyCosmosSafetyChecker(),
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ image_height = 32
+ image_width = 32
+ image = PIL.Image.new("RGB", (image_width, image_height))
+
+ inputs = {
+ "image": image,
+ "prompt": "dance monkey",
+ "negative_prompt": "bad quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 3.0,
+ "height": image_height,
+ "width": image_width,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+ self.assertEqual(generated_video.shape, (9, 3, 32, 32))
+
+ # fmt: off
+ expected_slice = torch.tensor([0.451, 0.451, 0.4471, 0.451, 0.451, 0.451, 0.451, 0.451, 0.5098, 0.5137, 0.5176, 0.5098, 0.5255, 0.5412, 0.5098, 0.5059])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+ def test_components_function(self):
+ init_components = self.get_dummy_components()
+ init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))}
+ pipe = self.pipeline_class(**init_components)
+ self.assertTrue(hasattr(pipe, "components"))
+ self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+ # Test passing in a everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-2)
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ def test_save_load_optional_components(self, expected_max_difference=1e-4):
+ self.pipeline_class._optional_components.remove("safety_checker")
+ super().test_save_load_optional_components(expected_max_difference=expected_max_difference)
+ self.pipeline_class._optional_components.append("safety_checker")
+
+ def test_serialization_with_variants(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ model_components = [
+ component_name
+ for component_name, component in pipe.components.items()
+ if isinstance(component, torch.nn.Module)
+ ]
+ model_components.remove("safety_checker")
+ variant = "fp16"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)
+
+ with open(f"{tmpdir}/model_index.json", "r") as f:
+ config = json.load(f)
+
+ for subfolder in os.listdir(tmpdir):
+ if not os.path.isfile(subfolder) and subfolder in model_components:
+ folder_path = os.path.join(tmpdir, subfolder)
+ is_folder = os.path.isdir(folder_path) and subfolder in config
+ assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))
+
+ def test_torch_dtype_dict(self):
+ components = self.get_dummy_components()
+ if not components:
+ self.skipTest("No dummy components defined.")
+
+ pipe = self.pipeline_class(**components)
+
+ specified_key = next(iter(components.keys()))
+
+ with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
+ pipe.save_pretrained(tmpdirname, safe_serialization=False)
+ torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
+ loaded_pipe = self.pipeline_class.from_pretrained(
+ tmpdirname, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=torch_dtype_dict
+ )
+
+ for name, component in loaded_pipe.components.items():
+ if name == "safety_checker":
+ continue
+ if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
+ expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
+ self.assertEqual(
+ component.dtype,
+ expected_dtype,
+ f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
+ )
+
+ @unittest.skip(
+ "The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in "
+ "a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is "
+ "too large and slow to run on CI."
+ )
+ def test_encode_prompt_works_in_isolation(self):
+ pass
diff --git a/tests/pipelines/cosmos/test_cosmos_video2world.py b/tests/pipelines/cosmos/test_cosmos_video2world.py
new file mode 100644
index 000000000000..2633c2007ac2
--- /dev/null
+++ b/tests/pipelines/cosmos/test_cosmos_video2world.py
@@ -0,0 +1,367 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import json
+import os
+import tempfile
+import unittest
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLCosmos, CosmosTransformer3DModel, CosmosVideoToWorldPipeline, EDMEulerScheduler
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+from .cosmos_guardrail import DummyCosmosSafetyChecker
+
+
+enable_full_determinism()
+
+
+class CosmosVideoToWorldPipelineWrapper(CosmosVideoToWorldPipeline):
+ @staticmethod
+ def from_pretrained(*args, **kwargs):
+ kwargs["safety_checker"] = DummyCosmosSafetyChecker()
+ return CosmosVideoToWorldPipeline.from_pretrained(*args, **kwargs)
+
+
+class CosmosVideoToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = CosmosVideoToWorldPipelineWrapper
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image", "video"})
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ supports_dduf = False
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = CosmosTransformer3DModel(
+ in_channels=4 + 1,
+ out_channels=4,
+ num_attention_heads=2,
+ attention_head_dim=16,
+ num_layers=2,
+ mlp_ratio=2,
+ text_embed_dim=32,
+ adaln_lora_dim=4,
+ max_size=(4, 32, 32),
+ patch_size=(1, 2, 2),
+ rope_scale=(2.0, 1.0, 1.0),
+ concat_padding_mask=True,
+ extra_pos_embed_type="learnable",
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLCosmos(
+ in_channels=3,
+ out_channels=3,
+ latent_channels=4,
+ encoder_block_out_channels=(8, 8, 8, 8),
+ decode_block_out_channels=(8, 8, 8, 8),
+ attention_resolutions=(8,),
+ resolution=64,
+ num_layers=2,
+ patch_size=4,
+ patch_type="haar",
+ scaling_factor=1.0,
+ spatial_compression_ratio=4,
+ temporal_compression_ratio=4,
+ )
+
+ torch.manual_seed(0)
+ scheduler = EDMEulerScheduler(
+ sigma_min=0.002,
+ sigma_max=80,
+ sigma_data=0.5,
+ sigma_schedule="karras",
+ num_train_timesteps=1000,
+ prediction_type="epsilon",
+ rho=7.0,
+ final_sigmas_type="sigma_min",
+ )
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ # We cannot run the Cosmos Guardrail for fast tests due to the large model size
+ "safety_checker": DummyCosmosSafetyChecker(),
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ image_height = 32
+ image_width = 32
+ image = PIL.Image.new("RGB", (image_width, image_height))
+
+ inputs = {
+ "image": image,
+ "prompt": "dance monkey",
+ "negative_prompt": "bad quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 3.0,
+ "height": image_height,
+ "width": image_width,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+ self.assertEqual(generated_video.shape, (9, 3, 32, 32))
+
+ # fmt: off
+ expected_slice = torch.tensor([0.0, 0.8275, 0.7529, 0.7294, 0.0, 0.6, 1.0, 0.3804, 0.6667, 0.0863, 0.8784, 0.5922, 0.6627, 0.2784, 0.5725, 0.7765])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+ def test_components_function(self):
+ init_components = self.get_dummy_components()
+ init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))}
+ pipe = self.pipeline_class(**init_components)
+ self.assertTrue(hasattr(pipe, "components"))
+ self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+ # Test passing in a everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-2)
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ def test_save_load_optional_components(self, expected_max_difference=1e-4):
+ self.pipeline_class._optional_components.remove("safety_checker")
+ super().test_save_load_optional_components(expected_max_difference=expected_max_difference)
+ self.pipeline_class._optional_components.append("safety_checker")
+
+ def test_serialization_with_variants(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ model_components = [
+ component_name
+ for component_name, component in pipe.components.items()
+ if isinstance(component, torch.nn.Module)
+ ]
+ model_components.remove("safety_checker")
+ variant = "fp16"
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)
+
+ with open(f"{tmpdir}/model_index.json", "r") as f:
+ config = json.load(f)
+
+ for subfolder in os.listdir(tmpdir):
+ if not os.path.isfile(subfolder) and subfolder in model_components:
+ folder_path = os.path.join(tmpdir, subfolder)
+ is_folder = os.path.isdir(folder_path) and subfolder in config
+ assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))
+
+ def test_torch_dtype_dict(self):
+ components = self.get_dummy_components()
+ if not components:
+ self.skipTest("No dummy components defined.")
+
+ pipe = self.pipeline_class(**components)
+
+ specified_key = next(iter(components.keys()))
+
+ with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
+ pipe.save_pretrained(tmpdirname, safe_serialization=False)
+ torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
+ loaded_pipe = self.pipeline_class.from_pretrained(
+ tmpdirname, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=torch_dtype_dict
+ )
+
+ for name, component in loaded_pipe.components.items():
+ if name == "safety_checker":
+ continue
+ if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
+ expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
+ self.assertEqual(
+ component.dtype,
+ expected_dtype,
+ f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
+ )
+
+ @unittest.skip(
+ "The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in "
+ "a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is "
+ "too large and slow to run on CI."
+ )
+ def test_encode_prompt_works_in_isolation(self):
+ pass
diff --git a/tests/pipelines/dance_diffusion/test_dance_diffusion.py b/tests/pipelines/dance_diffusion/test_dance_diffusion.py
deleted file mode 100644
index 1f60c0b421f3..000000000000
--- a/tests/pipelines/dance_diffusion/test_dance_diffusion.py
+++ /dev/null
@@ -1,167 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-
-from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
-from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, skip_mps, torch_device
-
-from ..pipeline_params import UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS, UNCONDITIONAL_AUDIO_GENERATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class DanceDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = DanceDiffusionPipeline
- params = UNCONDITIONAL_AUDIO_GENERATION_PARAMS
- required_optional_params = PipelineTesterMixin.required_optional_params - {
- "callback",
- "latents",
- "callback_steps",
- "output_type",
- "num_images_per_prompt",
- }
- batch_params = UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS
- test_attention_slicing = False
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet1DModel(
- block_out_channels=(32, 32, 64),
- extra_in_channels=16,
- sample_size=512,
- sample_rate=16_000,
- in_channels=2,
- out_channels=2,
- flip_sin_to_cos=True,
- use_timestep_embedding=False,
- time_embedding_type="fourier",
- mid_block_type="UNetMidBlock1D",
- down_block_types=("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
- up_block_types=("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
- )
- scheduler = IPNDMScheduler()
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "batch_size": 1,
- "generator": generator,
- "num_inference_steps": 4,
- }
- return inputs
-
- def test_dance_diffusion(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- pipe = DanceDiffusionPipeline(**components)
- pipe = pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = pipe(**inputs)
- audio = output.audios
-
- audio_slice = audio[0, -3:, -3:]
-
- assert audio.shape == (1, 2, components["unet"].sample_size)
- expected_slice = np.array([-0.7265, 1.0000, -0.8388, 0.1175, 0.9498, -1.0000])
- assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2
-
- @skip_mps
- def test_save_load_local(self):
- return super().test_save_load_local()
-
- @skip_mps
- def test_dict_tuple_outputs_equivalent(self):
- return super().test_dict_tuple_outputs_equivalent(expected_max_difference=3e-3)
-
- @skip_mps
- def test_save_load_optional_components(self):
- return super().test_save_load_optional_components()
-
- @skip_mps
- def test_attention_slicing_forward_pass(self):
- return super().test_attention_slicing_forward_pass()
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(expected_max_diff=3e-3)
-
-
-@nightly
-@require_torch_gpu
-class PipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- def test_dance_diffusion(self):
- device = torch_device
-
- pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k")
- pipe = pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.manual_seed(0)
- output = pipe(generator=generator, num_inference_steps=100, audio_length_in_s=4.096)
- audio = output.audios
-
- audio_slice = audio[0, -3:, -3:]
-
- assert audio.shape == (1, 2, pipe.unet.config.sample_size)
- expected_slice = np.array([-0.0192, -0.0231, -0.0318, -0.0059, 0.0002, -0.0020])
-
- assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_dance_diffusion_fp16(self):
- device = torch_device
-
- pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", torch_dtype=torch.float16)
- pipe = pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.manual_seed(0)
- output = pipe(generator=generator, num_inference_steps=100, audio_length_in_s=4.096)
- audio = output.audios
-
- audio_slice = audio[0, -3:, -3:]
-
- assert audio.shape == (1, 2, pipe.unet.config.sample_size)
- expected_slice = np.array([-0.0367, -0.0488, -0.0771, -0.0525, -0.0444, -0.0341])
-
- assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/ddim/test_ddim.py b/tests/pipelines/ddim/test_ddim.py
index f7e0093c515a..731635bea605 100644
--- a/tests/pipelines/ddim/test_ddim.py
+++ b/tests/pipelines/ddim/test_ddim.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,8 +19,8 @@
import torch
from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_accelerator, slow, torch_device
+from ...testing_utils import enable_full_determinism, require_torch_accelerator, slow, torch_device
from ..pipeline_params import UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS, UNCONDITIONAL_IMAGE_GENERATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/ddpm/test_ddpm.py b/tests/pipelines/ddpm/test_ddpm.py
index 750885db2c23..04ee741d8eb8 100644
--- a/tests/pipelines/ddpm/test_ddpm.py
+++ b/tests/pipelines/ddpm/test_ddpm.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,7 +19,8 @@
import torch
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_accelerator, slow, torch_device
+
+from ...testing_utils import enable_full_determinism, require_torch_accelerator, slow, torch_device
enable_full_determinism()
diff --git a/tests/pipelines/deepfloyd_if/__init__.py b/tests/pipelines/deepfloyd_if/__init__.py
index 094254a61875..d47374b07e22 100644
--- a/tests/pipelines/deepfloyd_if/__init__.py
+++ b/tests/pipelines/deepfloyd_if/__init__.py
@@ -7,8 +7,8 @@
from diffusers import DDPMScheduler, UNet2DConditionModel
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.pipelines.deepfloyd_if import IFWatermarker
-from diffusers.utils.testing_utils import torch_device
+from ...testing_utils import torch_device
from ..test_pipelines_common import to_np
diff --git a/tests/pipelines/deepfloyd_if/test_if.py b/tests/pipelines/deepfloyd_if/test_if.py
index 295b29f12e8c..e1870ddcbae9 100644
--- a/tests/pipelines/deepfloyd_if/test_if.py
+++ b/tests/pipelines/deepfloyd_if/test_if.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -23,8 +23,10 @@
)
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
+ backend_max_memory_allocated,
backend_reset_max_memory_allocated,
backend_reset_peak_memory_stats,
load_numpy,
@@ -36,7 +38,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
from . import IFPipelineTesterMixin
@@ -135,7 +136,7 @@ def test_if_text_to_image(self):
image = output.images[0]
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
assert mem_bytes < 12 * 10**9
expected_image = load_numpy(
diff --git a/tests/pipelines/deepfloyd_if/test_if_img2img.py b/tests/pipelines/deepfloyd_if/test_if_img2img.py
index da06dc355896..9d3c96052be6 100644
--- a/tests/pipelines/deepfloyd_if/test_if_img2img.py
+++ b/tests/pipelines/deepfloyd_if/test_if_img2img.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,8 +22,10 @@
from diffusers import IFImg2ImgPipeline
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
+ backend_max_memory_allocated,
backend_reset_max_memory_allocated,
backend_reset_peak_memory_stats,
floats_tensor,
@@ -36,7 +38,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
@@ -151,7 +152,7 @@ def test_if_img2img(self):
)
image = output.images[0]
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
assert mem_bytes < 12 * 10**9
expected_image = load_numpy(
diff --git a/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py
index 77f2f9c7bb64..e2114910edb0 100644
--- a/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py
+++ b/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,7 +22,8 @@
from diffusers import IFImg2ImgSuperResolutionPipeline
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
@@ -37,7 +38,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
diff --git a/tests/pipelines/deepfloyd_if/test_if_inpainting.py b/tests/pipelines/deepfloyd_if/test_if_inpainting.py
index a62d95725774..2679e0b77690 100644
--- a/tests/pipelines/deepfloyd_if/test_if_inpainting.py
+++ b/tests/pipelines/deepfloyd_if/test_if_inpainting.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,7 +22,8 @@
from diffusers import IFInpaintingPipeline
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
@@ -37,7 +38,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
diff --git a/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py
index f98284bef646..3d64556c6e41 100644
--- a/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py
+++ b/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,7 +22,8 @@
from diffusers import IFInpaintingSuperResolutionPipeline
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
@@ -37,7 +38,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
diff --git a/tests/pipelines/deepfloyd_if/test_if_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_superresolution.py
index 435b0cc6ec07..fa7c0fb2e062 100644
--- a/tests/pipelines/deepfloyd_if/test_if_superresolution.py
+++ b/tests/pipelines/deepfloyd_if/test_if_superresolution.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,7 +22,8 @@
from diffusers import IFSuperResolutionPipeline
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
@@ -37,7 +38,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
from . import IFPipelineTesterMixin
diff --git a/tests/pipelines/dit/test_dit.py b/tests/pipelines/dit/test_dit.py
index 30883ac4a63d..cd5c08ced3fc 100644
--- a/tests/pipelines/dit/test_dit.py
+++ b/tests/pipelines/dit/test_dit.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,8 +21,16 @@
from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DiTTransformer2DModel, DPMSolverMultistepScheduler
from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import enable_full_determinism, load_numpy, nightly, require_torch_gpu, torch_device
+from ...testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ load_numpy,
+ nightly,
+ numpy_cosine_similarity_distance,
+ require_torch_accelerator,
+ torch_device,
+)
from ..pipeline_params import (
CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS,
CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS,
@@ -107,23 +115,23 @@ def test_xformers_attention_forwardGenerator_pass(self):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
class DiTPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_dit_256(self):
generator = torch.manual_seed(0)
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256")
- pipe.to("cuda")
+ pipe.to(torch_device)
words = ["vase", "umbrella", "white shark", "white wolf"]
ids = pipe.get_label_ids(words)
@@ -139,7 +147,7 @@ def test_dit_256(self):
def test_dit_512(self):
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-512")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
- pipe.to("cuda")
+ pipe.to(torch_device)
words = ["vase", "umbrella"]
ids = pipe.get_label_ids(words)
@@ -149,8 +157,10 @@ def test_dit_512(self):
for word, image in zip(words, images):
expected_image = load_numpy(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- f"/dit/{word}_512.npy"
+ f"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/dit/{word}_512.npy"
)
- assert np.abs((expected_image - image).max()) < 1e-1
+ expected_slice = expected_image.flatten()
+ output_slice = image.flatten()
+
+ assert numpy_cosine_similarity_distance(expected_slice, output_slice) < 1e-2
diff --git a/tests/pipelines/easyanimate/test_easyanimate.py b/tests/pipelines/easyanimate/test_easyanimate.py
index 13d5c2f49b11..5cb2a232bb87 100644
--- a/tests/pipelines/easyanimate/test_easyanimate.py
+++ b/tests/pipelines/easyanimate/test_easyanimate.py
@@ -26,14 +26,15 @@
EasyAnimateTransformer3DModel,
FlowMatchEulerDiscreteScheduler,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -47,6 +48,7 @@ class EasyAnimatePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ test_xformers_attention = False
required_optional_params = frozenset(
[
"num_inference_steps",
@@ -256,19 +258,19 @@ def test_encode_prompt_works_in_isolation(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class EasyAnimatePipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_EasyAnimate(self):
generator = torch.Generator("cpu").manual_seed(0)
diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py
index 6a560367a5b8..74499bfa607a 100644
--- a/tests/pipelines/flux/test_pipeline_flux.py
+++ b/tests/pipelines/flux/test_pipeline_flux.py
@@ -2,7 +2,6 @@
import unittest
import numpy as np
-import pytest
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
@@ -14,7 +13,9 @@
FluxPipeline,
FluxTransformer2DModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ Expectations,
backend_empty_cache,
nightly,
numpy_cosine_similarity_distance,
@@ -22,23 +23,25 @@
slow,
torch_device,
)
-
from ..test_pipelines_common import (
FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
- check_qkv_fusion_matches_attn_procs_length,
- check_qkv_fusion_processors_exist,
+ TaylorSeerCacheTesterMixin,
+ check_qkv_fused_layers_exist,
)
class FluxPipelineFastTests(
- unittest.TestCase,
PipelineTesterMixin,
FluxIPAdapterTesterMixin,
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
+ TaylorSeerCacheTesterMixin,
+ unittest.TestCase,
):
pipeline_class = FluxPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
@@ -154,7 +157,7 @@ def test_flux_different_prompts(self):
# Outputs should be different here
# For some reasons, they don't show large differences
- assert max_diff > 1e-6
+ self.assertGreater(max_diff, 1e-6, "Outputs should be different for different prompts.")
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -170,12 +173,10 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- assert check_qkv_fusion_matches_attn_procs_length(
- pipe.transformer, pipe.transformer.original_attn_processors
- ), "Something wrong with the attention processors concerning the fused QKV projections."
+ self.assertTrue(
+ check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+ ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
+ )
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
@@ -186,15 +187,18 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ self.assertTrue(
+ np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3),
+ ("Fusion of QKV projections shouldn't affect the outputs."),
+ )
+ self.assertTrue(
+ np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3),
+ ("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
+ )
+ self.assertTrue(
+ np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2),
+ ("Original outputs should match when fused QKV projections are disabled."),
+ )
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
@@ -208,7 +212,11 @@ def test_flux_image_output_shape(self):
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
- assert (output_height, output_width) == (expected_height, expected_width)
+ self.assertEqual(
+ (output_height, output_width),
+ (expected_height, expected_width),
+ f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
+ )
def test_flux_true_cfg(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
@@ -219,12 +227,13 @@ def test_flux_true_cfg(self):
inputs["negative_prompt"] = "bad quality"
inputs["true_cfg_scale"] = 2.0
true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
- assert not np.allclose(no_true_cfg_out, true_cfg_out)
+ self.assertFalse(
+ np.allclose(no_true_cfg_out, true_cfg_out), "Outputs should be different when true_cfg_scale is set."
+ )
@nightly
@require_big_accelerator
-@pytest.mark.big_gpu_with_torch_cuda
class FluxPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxPipeline
repo_id = "black-forest-labs/FLUX.1-schnell"
@@ -269,50 +278,25 @@ def test_flux_inference(self):
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
- expected_slice = np.array(
- [
- 0.3242,
- 0.3203,
- 0.3164,
- 0.3164,
- 0.3125,
- 0.3125,
- 0.3281,
- 0.3242,
- 0.3203,
- 0.3301,
- 0.3262,
- 0.3242,
- 0.3281,
- 0.3242,
- 0.3203,
- 0.3262,
- 0.3262,
- 0.3164,
- 0.3262,
- 0.3281,
- 0.3184,
- 0.3281,
- 0.3281,
- 0.3203,
- 0.3281,
- 0.3281,
- 0.3164,
- 0.3320,
- 0.3320,
- 0.3203,
- ],
- dtype=np.float32,
+ # fmt: off
+
+ expected_slices = Expectations(
+ {
+ ("cuda", None): np.array([0.3242, 0.3203, 0.3164, 0.3164, 0.3125, 0.3125, 0.3281, 0.3242, 0.3203, 0.3301, 0.3262, 0.3242, 0.3281, 0.3242, 0.3203, 0.3262, 0.3262, 0.3164, 0.3262, 0.3281, 0.3184, 0.3281, 0.3281, 0.3203, 0.3281, 0.3281, 0.3164, 0.3320, 0.3320, 0.3203], dtype=np.float32,),
+ ("xpu", 3): np.array([0.3301, 0.3281, 0.3359, 0.3203, 0.3203, 0.3281, 0.3281, 0.3301, 0.3340, 0.3281, 0.3320, 0.3359, 0.3281, 0.3301, 0.3320, 0.3242, 0.3301, 0.3281, 0.3242, 0.3320, 0.3320, 0.3281, 0.3320, 0.3320, 0.3262, 0.3320, 0.3301, 0.3301, 0.3359, 0.3320], dtype=np.float32,),
+ }
)
+ expected_slice = expected_slices.get_expectation()
+ # fmt: on
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
-
- assert max_diff < 1e-4
+ self.assertLess(
+ max_diff, 1e-4, f"Image slice is different from expected slice: {image_slice} != {expected_slice}"
+ )
@slow
@require_big_accelerator
-@pytest.mark.big_gpu_with_torch_cuda
class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxPipeline
repo_id = "black-forest-labs/FLUX.1-dev"
@@ -378,42 +362,14 @@ def test_flux_ip_adapter_inference(self):
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
+ # fmt: off
expected_slice = np.array(
- [
- 0.1855,
- 0.1680,
- 0.1406,
- 0.1953,
- 0.1699,
- 0.1465,
- 0.2012,
- 0.1738,
- 0.1484,
- 0.2051,
- 0.1797,
- 0.1523,
- 0.2012,
- 0.1719,
- 0.1445,
- 0.2070,
- 0.1777,
- 0.1465,
- 0.2090,
- 0.1836,
- 0.1484,
- 0.2129,
- 0.1875,
- 0.1523,
- 0.2090,
- 0.1816,
- 0.1484,
- 0.2110,
- 0.1836,
- 0.1543,
- ],
+ [0.1855, 0.1680, 0.1406, 0.1953, 0.1699, 0.1465, 0.2012, 0.1738, 0.1484, 0.2051, 0.1797, 0.1523, 0.2012, 0.1719, 0.1445, 0.2070, 0.1777, 0.1465, 0.2090, 0.1836, 0.1484, 0.2129, 0.1875, 0.1523, 0.2090, 0.1816, 0.1484, 0.2110, 0.1836, 0.1543],
dtype=np.float32,
)
+ # fmt: on
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
-
- assert max_diff < 1e-4, f"{image_slice} != {expected_slice}"
+ self.assertLess(
+ max_diff, 1e-4, f"Image slice is different from expected slice: {image_slice} != {expected_slice}"
+ )
diff --git a/tests/pipelines/flux/test_pipeline_flux_control.py b/tests/pipelines/flux/test_pipeline_flux_control.py
index d8293952adcb..7e966470a336 100644
--- a/tests/pipelines/flux/test_pipeline_flux_control.py
+++ b/tests/pipelines/flux/test_pipeline_flux_control.py
@@ -6,13 +6,9 @@
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel
-from diffusers.utils.testing_utils import torch_device
-from ..test_pipelines_common import (
- PipelineTesterMixin,
- check_qkv_fusion_matches_attn_procs_length,
- check_qkv_fusion_processors_exist,
-)
+from ...testing_utils import torch_device
+from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -140,12 +136,10 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- assert check_qkv_fusion_matches_attn_procs_length(
- pipe.transformer, pipe.transformer.original_attn_processors
- ), "Something wrong with the attention processors concerning the fused QKV projections."
+ self.assertTrue(
+ check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+ ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
+ )
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
@@ -156,15 +150,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
diff --git a/tests/pipelines/flux/test_pipeline_flux_control_img2img.py b/tests/pipelines/flux/test_pipeline_flux_control_img2img.py
index 966543f63aeb..e56136f2e91b 100644
--- a/tests/pipelines/flux/test_pipeline_flux_control_img2img.py
+++ b/tests/pipelines/flux/test_pipeline_flux_control_img2img.py
@@ -11,8 +11,8 @@
FluxControlImg2ImgPipeline,
FluxTransformer2DModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
index 44ce2a4dedfc..e42c5fc2aab5 100644
--- a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
+++ b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
@@ -11,15 +11,11 @@
FluxControlInpaintPipeline,
FluxTransformer2DModel,
)
-from diffusers.utils.testing_utils import (
- torch_device,
-)
-from ..test_pipelines_common import (
- PipelineTesterMixin,
- check_qkv_fusion_matches_attn_procs_length,
- check_qkv_fusion_processors_exist,
+from ...testing_utils import (
+ torch_device,
)
+from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -134,12 +130,10 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- assert check_qkv_fusion_matches_attn_procs_length(
- pipe.transformer, pipe.transformer.original_attn_processors
- ), "Something wrong with the attention processors concerning the fused QKV projections."
+ self.assertTrue(
+ check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+ ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
+ )
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
@@ -150,15 +144,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
diff --git a/tests/pipelines/flux/test_pipeline_flux_fill.py b/tests/pipelines/flux/test_pipeline_flux_fill.py
index 04d4c68db8f3..25a4a3354820 100644
--- a/tests/pipelines/flux/test_pipeline_flux_fill.py
+++ b/tests/pipelines/flux/test_pipeline_flux_fill.py
@@ -6,12 +6,12 @@
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxFillPipeline, FluxTransformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/flux/test_pipeline_flux_img2img.py b/tests/pipelines/flux/test_pipeline_flux_img2img.py
index 6d33ca721b6c..6f435760aef5 100644
--- a/tests/pipelines/flux/test_pipeline_flux_img2img.py
+++ b/tests/pipelines/flux/test_pipeline_flux_img2img.py
@@ -6,12 +6,12 @@
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxImg2ImgPipeline, FluxTransformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
-
from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin
diff --git a/tests/pipelines/flux/test_pipeline_flux_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_inpaint.py
index 161348455ca4..6324ff236e10 100644
--- a/tests/pipelines/flux/test_pipeline_flux_inpaint.py
+++ b/tests/pipelines/flux/test_pipeline_flux_inpaint.py
@@ -6,12 +6,12 @@
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxInpaintPipeline, FluxTransformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
-
from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin
diff --git a/tests/pipelines/flux/test_pipeline_flux_kontext.py b/tests/pipelines/flux/test_pipeline_flux_kontext.py
new file mode 100644
index 000000000000..5c78964ea54f
--- /dev/null
+++ b/tests/pipelines/flux/test_pipeline_flux_kontext.py
@@ -0,0 +1,177 @@
+import unittest
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKL,
+ FasterCacheConfig,
+ FlowMatchEulerDiscreteScheduler,
+ FluxKontextPipeline,
+ FluxTransformer2DModel,
+)
+
+from ...testing_utils import torch_device
+from ..test_pipelines_common import (
+ FasterCacheTesterMixin,
+ FluxIPAdapterTesterMixin,
+ PipelineTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+)
+
+
+class FluxKontextPipelineFastTests(
+ unittest.TestCase,
+ PipelineTesterMixin,
+ FluxIPAdapterTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+ FasterCacheTesterMixin,
+):
+ pipeline_class = FluxKontextPipeline
+ params = frozenset(
+ ["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]
+ )
+ batch_params = frozenset(["image", "prompt"])
+
+ # there is no xformers processor for Flux
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ faster_cache_config = FasterCacheConfig(
+ spatial_attention_block_skip_range=2,
+ spatial_attention_timestep_skip_range=(-1, 901),
+ unconditional_batch_skip_range=2,
+ attention_weight_callback=lambda _: 0.5,
+ is_guidance_distilled=True,
+ )
+
+ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
+ torch.manual_seed(0)
+ transformer = FluxTransformer2DModel(
+ patch_size=1,
+ in_channels=4,
+ num_layers=num_layers,
+ num_single_layers=num_single_layers,
+ attention_head_dim=16,
+ num_attention_heads=2,
+ joint_attention_dim=32,
+ pooled_projection_dim=32,
+ axes_dims_rope=[4, 4, 8],
+ )
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = CLIPTextModel(clip_text_encoder_config)
+
+ torch.manual_seed(0)
+ text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=1,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ shift_factor=0.0609,
+ scaling_factor=1.5035,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "transformer": transformer,
+ "vae": vae,
+ "image_encoder": None,
+ "feature_extractor": None,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ image = PIL.Image.new("RGB", (32, 32), 0)
+ inputs = {
+ "image": image,
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 8,
+ "width": 8,
+ "max_area": 8 * 8,
+ "max_sequence_length": 48,
+ "output_type": "np",
+ "_auto_resize": False,
+ }
+ return inputs
+
+ def test_flux_different_prompts(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_same_prompt = pipe(**inputs).images[0]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt_2"] = "a different prompt"
+ output_different_prompts = pipe(**inputs).images[0]
+
+ max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+
+ # Outputs should be different here
+ # For some reasons, they don't show large differences
+ assert max_diff > 1e-6
+
+ def test_flux_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (72, 57)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+ inputs.update({"height": height, "width": width, "max_area": height * width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
+
+ def test_flux_true_cfg(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs.pop("generator")
+
+ no_true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
+ inputs["negative_prompt"] = "bad quality"
+ inputs["true_cfg_scale"] = 2.0
+ true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
+ assert not np.allclose(no_true_cfg_out, true_cfg_out)
diff --git a/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py
new file mode 100644
index 000000000000..9a2e32056dcb
--- /dev/null
+++ b/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py
@@ -0,0 +1,190 @@
+import random
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKL,
+ FasterCacheConfig,
+ FlowMatchEulerDiscreteScheduler,
+ FluxKontextInpaintPipeline,
+ FluxTransformer2DModel,
+)
+
+from ...testing_utils import floats_tensor, torch_device
+from ..test_pipelines_common import (
+ FasterCacheTesterMixin,
+ FluxIPAdapterTesterMixin,
+ PipelineTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+)
+
+
+class FluxKontextInpaintPipelineFastTests(
+ unittest.TestCase,
+ PipelineTesterMixin,
+ FluxIPAdapterTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+ FasterCacheTesterMixin,
+):
+ pipeline_class = FluxKontextInpaintPipeline
+ params = frozenset(
+ ["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]
+ )
+ batch_params = frozenset(["image", "prompt"])
+
+ # there is no xformers processor for Flux
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ faster_cache_config = FasterCacheConfig(
+ spatial_attention_block_skip_range=2,
+ spatial_attention_timestep_skip_range=(-1, 901),
+ unconditional_batch_skip_range=2,
+ attention_weight_callback=lambda _: 0.5,
+ is_guidance_distilled=True,
+ )
+
+ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
+ torch.manual_seed(0)
+ transformer = FluxTransformer2DModel(
+ patch_size=1,
+ in_channels=4,
+ num_layers=num_layers,
+ num_single_layers=num_single_layers,
+ attention_head_dim=16,
+ num_attention_heads=2,
+ joint_attention_dim=32,
+ pooled_projection_dim=32,
+ axes_dims_rope=[4, 4, 8],
+ )
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = CLIPTextModel(clip_text_encoder_config)
+
+ torch.manual_seed(0)
+ text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=1,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ shift_factor=0.0609,
+ scaling_factor=1.5035,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "transformer": transformer,
+ "vae": vae,
+ "image_encoder": None,
+ "feature_extractor": None,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+ mask_image = torch.ones((1, 1, 32, 32)).to(device)
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "image": image,
+ "mask_image": mask_image,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 48,
+ "strength": 0.8,
+ "output_type": "np",
+ "_auto_resize": False,
+ }
+ return inputs
+
+ def test_flux_inpaint_different_prompts(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_same_prompt = pipe(**inputs).images[0]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt_2"] = "a different prompt"
+ output_different_prompts = pipe(**inputs).images[0]
+
+ max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+
+ # Outputs should be different here
+ # For some reasons, they don't show large differences
+ assert max_diff > 1e-6
+
+ def test_flux_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (72, 56)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+ # Because output shape is the same as the input shape, we need to create a dummy image and mask image
+ image = floats_tensor((1, 3, height, width), rng=random.Random(0)).to(torch_device)
+ mask_image = torch.ones((1, 1, height, width)).to(torch_device)
+
+ inputs.update(
+ {
+ "height": height,
+ "width": width,
+ "max_area": height * width,
+ "image": image,
+ "mask_image": mask_image,
+ }
+ )
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
+
+ def test_flux_true_cfg(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs.pop("generator")
+
+ no_true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
+ inputs["negative_prompt"] = "bad quality"
+ inputs["true_cfg_scale"] = 2.0
+ true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
+ assert not np.allclose(no_true_cfg_out, true_cfg_out)
diff --git a/tests/pipelines/flux/test_pipeline_flux_redux.py b/tests/pipelines/flux/test_pipeline_flux_redux.py
index 2cd73a51a173..bbeee28e6a62 100644
--- a/tests/pipelines/flux/test_pipeline_flux_redux.py
+++ b/tests/pipelines/flux/test_pipeline_flux_redux.py
@@ -2,12 +2,13 @@
import unittest
import numpy as np
-import pytest
import torch
from diffusers import FluxPipeline, FluxPriorReduxPipeline
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ Expectations,
backend_empty_cache,
numpy_cosine_similarity_distance,
require_big_accelerator,
@@ -18,10 +19,9 @@
@slow
@require_big_accelerator
-@pytest.mark.big_gpu_with_torch_cuda
class FluxReduxSlowTests(unittest.TestCase):
pipeline_class = FluxPriorReduxPipeline
- repo_id = "YiYiXu/yiyi-redux" # update to "black-forest-labs/FLUX.1-Redux-dev" once PR is merged
+ repo_id = "black-forest-labs/FLUX.1-Redux-dev"
base_pipeline_class = FluxPipeline
base_repo_id = "black-forest-labs/FLUX.1-schnell"
@@ -69,41 +69,82 @@ def test_flux_redux_inference(self):
image = pipe_base(**base_pipeline_inputs, **redux_pipeline_output).images[0]
image_slice = image[0, :10, :10]
- expected_slice = np.array(
- [
- 0.30078125,
- 0.37890625,
- 0.46875,
- 0.28125,
- 0.36914062,
- 0.47851562,
- 0.28515625,
- 0.375,
- 0.4765625,
- 0.28125,
- 0.375,
- 0.48046875,
- 0.27929688,
- 0.37695312,
- 0.47851562,
- 0.27734375,
- 0.38085938,
- 0.4765625,
- 0.2734375,
- 0.38085938,
- 0.47265625,
- 0.27539062,
- 0.37890625,
- 0.47265625,
- 0.27734375,
- 0.37695312,
- 0.47070312,
- 0.27929688,
- 0.37890625,
- 0.47460938,
- ],
- dtype=np.float32,
+ expected_slices = Expectations(
+ {
+ ("cuda", 7): np.array(
+ [
+ 0.30078125,
+ 0.37890625,
+ 0.46875,
+ 0.28125,
+ 0.36914062,
+ 0.47851562,
+ 0.28515625,
+ 0.375,
+ 0.4765625,
+ 0.28125,
+ 0.375,
+ 0.48046875,
+ 0.27929688,
+ 0.37695312,
+ 0.47851562,
+ 0.27734375,
+ 0.38085938,
+ 0.4765625,
+ 0.2734375,
+ 0.38085938,
+ 0.47265625,
+ 0.27539062,
+ 0.37890625,
+ 0.47265625,
+ 0.27734375,
+ 0.37695312,
+ 0.47070312,
+ 0.27929688,
+ 0.37890625,
+ 0.47460938,
+ ],
+ dtype=np.float32,
+ ),
+ ("xpu", 3): np.array(
+ [
+ 0.20507812,
+ 0.30859375,
+ 0.3984375,
+ 0.18554688,
+ 0.30078125,
+ 0.41015625,
+ 0.19921875,
+ 0.3125,
+ 0.40625,
+ 0.19726562,
+ 0.3125,
+ 0.41601562,
+ 0.19335938,
+ 0.31445312,
+ 0.4140625,
+ 0.1953125,
+ 0.3203125,
+ 0.41796875,
+ 0.19726562,
+ 0.32421875,
+ 0.41992188,
+ 0.19726562,
+ 0.32421875,
+ 0.41992188,
+ 0.20117188,
+ 0.32421875,
+ 0.41796875,
+ 0.203125,
+ 0.32617188,
+ 0.41796875,
+ ],
+ dtype=np.float32,
+ ),
+ }
)
+ expected_slice = expected_slices.get_expectation()
+
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
assert max_diff < 1e-4
diff --git a/tests/pipelines/stable_diffusion_k_diffusion/__init__.py b/tests/pipelines/flux2/__init__.py
similarity index 100%
rename from tests/pipelines/stable_diffusion_k_diffusion/__init__.py
rename to tests/pipelines/flux2/__init__.py
diff --git a/tests/pipelines/flux2/test_pipeline_flux2.py b/tests/pipelines/flux2/test_pipeline_flux2.py
new file mode 100644
index 000000000000..4404dbc51047
--- /dev/null
+++ b/tests/pipelines/flux2/test_pipeline_flux2.py
@@ -0,0 +1,190 @@
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoProcessor, Mistral3Config, Mistral3ForConditionalGeneration
+
+from diffusers import (
+ AutoencoderKLFlux2,
+ FlowMatchEulerDiscreteScheduler,
+ Flux2Pipeline,
+ Flux2Transformer2DModel,
+)
+
+from ...testing_utils import (
+ torch_device,
+)
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+ check_qkv_fused_layers_exist,
+)
+
+
+class Flux2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = Flux2Pipeline
+ params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
+ batch_params = frozenset(["prompt"])
+
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ supports_dduf = False
+
+ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
+ torch.manual_seed(0)
+ transformer = Flux2Transformer2DModel(
+ patch_size=1,
+ in_channels=4,
+ num_layers=num_layers,
+ num_single_layers=num_single_layers,
+ attention_head_dim=16,
+ num_attention_heads=2,
+ joint_attention_dim=16,
+ timestep_guidance_channels=256, # Hardcoded in original code
+ axes_dims_rope=[4, 4, 4, 4],
+ )
+
+ config = Mistral3Config(
+ text_config={
+ "model_type": "mistral",
+ "vocab_size": 32000,
+ "hidden_size": 16,
+ "intermediate_size": 37,
+ "max_position_embeddings": 512,
+ "num_attention_heads": 4,
+ "num_hidden_layers": 1,
+ "num_key_value_heads": 2,
+ "rms_norm_eps": 1e-05,
+ "rope_theta": 1000000000.0,
+ "sliding_window": None,
+ "bos_token_id": 2,
+ "eos_token_id": 3,
+ "pad_token_id": 4,
+ },
+ vision_config={
+ "model_type": "pixtral",
+ "hidden_size": 16,
+ "num_hidden_layers": 1,
+ "num_attention_heads": 4,
+ "intermediate_size": 37,
+ "image_size": 30,
+ "patch_size": 6,
+ "num_channels": 3,
+ },
+ bos_token_id=2,
+ eos_token_id=3,
+ pad_token_id=4,
+ model_dtype="mistral3",
+ image_seq_length=4,
+ vision_feature_layer=-1,
+ image_token_index=1,
+ )
+ torch.manual_seed(0)
+ text_encoder = Mistral3ForConditionalGeneration(config)
+ tokenizer = AutoProcessor.from_pretrained(
+ "hf-internal-testing/Mistral-Small-3.1-24B-Instruct-2503-only-processor"
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLFlux2(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ down_block_types=("DownEncoderBlock2D",),
+ up_block_types=("UpDecoderBlock2D",),
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=1,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "transformer": transformer,
+ "vae": vae,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "prompt": "a dog is dancing",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 8,
+ "width": 8,
+ "max_sequence_length": 8,
+ "output_type": "np",
+ "text_encoder_out_layers": (1,),
+ }
+ return inputs
+
+ def test_fused_qkv_projections(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ original_image_slice = image[0, -3:, -3:, -1]
+
+ # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
+ # to the pipeline level.
+ pipe.transformer.fuse_qkv_projections()
+ self.assertTrue(
+ check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+ ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
+ )
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_fused = image[0, -3:, -3:, -1]
+
+ pipe.transformer.unfuse_qkv_projections()
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ image_slice_disabled = image[0, -3:, -3:, -1]
+
+ self.assertTrue(
+ np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3),
+ ("Fusion of QKV projections shouldn't affect the outputs."),
+ )
+ self.assertTrue(
+ np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3),
+ ("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
+ )
+ self.assertTrue(
+ np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2),
+ ("Original outputs should match when fused QKV projections are disabled."),
+ )
+
+ def test_flux_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (72, 57)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+ inputs.update({"height": height, "width": width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ self.assertEqual(
+ (output_height, output_width),
+ (expected_height, expected_width),
+ f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
+ )
diff --git a/tests/pipelines/stable_diffusion_ldm3d/__init__.py b/tests/pipelines/hidream_image/__init__.py
similarity index 100%
rename from tests/pipelines/stable_diffusion_ldm3d/__init__.py
rename to tests/pipelines/hidream_image/__init__.py
diff --git a/tests/pipelines/hidream_image/test_pipeline_hidream.py b/tests/pipelines/hidream_image/test_pipeline_hidream.py
new file mode 100644
index 000000000000..ddf39ba4c1e6
--- /dev/null
+++ b/tests/pipelines/hidream_image/test_pipeline_hidream.py
@@ -0,0 +1,160 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from transformers import (
+ AutoTokenizer,
+ CLIPTextConfig,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ LlamaForCausalLM,
+ T5EncoderModel,
+)
+
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ HiDreamImagePipeline,
+ HiDreamImageTransformer2DModel,
+)
+
+from ...testing_utils import enable_full_determinism
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class HiDreamImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = HiDreamImagePipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "prompt_embeds", "negative_prompt_embeds"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = PipelineTesterMixin.required_optional_params
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = HiDreamImageTransformer2DModel(
+ patch_size=2,
+ in_channels=4,
+ out_channels=4,
+ num_layers=1,
+ num_single_layers=1,
+ attention_head_dim=8,
+ num_attention_heads=4,
+ caption_channels=[32, 16],
+ text_emb_dim=64,
+ num_routed_experts=4,
+ num_activated_experts=2,
+ axes_dims_rope=(4, 2, 2),
+ max_resolution=(32, 32),
+ llama_layers=(0, 1),
+ ).eval()
+ torch.manual_seed(0)
+ vae = AutoencoderKL(scaling_factor=0.3611, shift_factor=0.1159)
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ max_position_embeddings=128,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = CLIPTextModelWithProjection(clip_text_encoder_config)
+
+ torch.manual_seed(0)
+ text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)
+
+ torch.manual_seed(0)
+ text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ text_encoder_4 = LlamaForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+ text_encoder_4.generation_config.pad_token_id = 1
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer_4 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ components = {
+ "scheduler": scheduler,
+ "vae": vae,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer_2": tokenizer_2,
+ "text_encoder_3": text_encoder_3,
+ "tokenizer_3": tokenizer_3,
+ "text_encoder_4": text_encoder_4,
+ "tokenizer_4": tokenizer_4,
+ "transformer": transformer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs)[0]
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (128, 128, 3))
+
+ # fmt: off
+ expected_slice = np.array([0.4507, 0.5256, 0.4205, 0.5791, 0.4848, 0.4831, 0.4443, 0.5107, 0.6586, 0.3163, 0.7318, 0.5933, 0.6252, 0.5512, 0.5357, 0.5983])
+ # fmt: on
+
+ generated_slice = generated_image.flatten()
+ generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(np.allclose(generated_slice, expected_slice, atol=1e-3))
+
+ def test_inference_batch_single_identical(self):
+ super().test_inference_batch_single_identical(expected_max_diff=3e-4)
diff --git a/tests/pipelines/stable_diffusion_panorama/__init__.py b/tests/pipelines/hunyuan_image_21/__init__.py
similarity index 100%
rename from tests/pipelines/stable_diffusion_panorama/__init__.py
rename to tests/pipelines/hunyuan_image_21/__init__.py
diff --git a/tests/pipelines/hunyuan_image_21/test_hunyuanimage.py b/tests/pipelines/hunyuan_image_21/test_hunyuanimage.py
new file mode 100644
index 000000000000..e4b2c686b8b1
--- /dev/null
+++ b/tests/pipelines/hunyuan_image_21/test_hunyuanimage.py
@@ -0,0 +1,290 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from transformers import (
+ ByT5Tokenizer,
+ Qwen2_5_VLConfig,
+ Qwen2_5_VLForConditionalGeneration,
+ Qwen2Tokenizer,
+ T5Config,
+ T5EncoderModel,
+)
+
+from diffusers import (
+ AdaptiveProjectedMixGuidance,
+ AutoencoderKLHunyuanImage,
+ FlowMatchEulerDiscreteScheduler,
+ HunyuanImagePipeline,
+ HunyuanImageTransformer2DModel,
+)
+
+from ...testing_utils import enable_full_determinism
+from ..test_pipelines_common import (
+ FirstBlockCacheTesterMixin,
+ PipelineTesterMixin,
+ to_np,
+)
+
+
+enable_full_determinism()
+
+
+class HunyuanImagePipelineFastTests(
+ PipelineTesterMixin,
+ FirstBlockCacheTesterMixin,
+ unittest.TestCase,
+):
+ pipeline_class = HunyuanImagePipeline
+ params = frozenset(["prompt", "height", "width"])
+ batch_params = frozenset(["prompt", "negative_prompt"])
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+ test_attention_slicing = False
+ supports_dduf = False
+
+ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1, guidance_embeds: bool = False):
+ torch.manual_seed(0)
+ transformer = HunyuanImageTransformer2DModel(
+ in_channels=4,
+ out_channels=4,
+ num_attention_heads=4,
+ attention_head_dim=8,
+ num_layers=num_layers,
+ num_single_layers=num_single_layers,
+ num_refiner_layers=1,
+ patch_size=(1, 1),
+ guidance_embeds=guidance_embeds,
+ text_embed_dim=32,
+ text_embed_2_dim=32,
+ rope_axes_dim=(4, 4),
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLHunyuanImage(
+ in_channels=3,
+ out_channels=3,
+ latent_channels=4,
+ block_out_channels=(32, 64, 64, 64),
+ layers_per_block=1,
+ scaling_factor=0.476986,
+ spatial_compression_ratio=8,
+ sample_size=128,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+
+ if not guidance_embeds:
+ torch.manual_seed(0)
+ guider = AdaptiveProjectedMixGuidance(adaptive_projected_guidance_start_step=2)
+ ocr_guider = AdaptiveProjectedMixGuidance(adaptive_projected_guidance_start_step=3)
+ else:
+ guider = None
+ ocr_guider = None
+ torch.manual_seed(0)
+ config = Qwen2_5_VLConfig(
+ text_config={
+ "hidden_size": 32,
+ "intermediate_size": 32,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 2,
+ "rope_scaling": {
+ "mrope_section": [2, 2, 4],
+ "rope_type": "default",
+ "type": "default",
+ },
+ "rope_theta": 1000000.0,
+ },
+ vision_config={
+ "depth": 2,
+ "hidden_size": 32,
+ "intermediate_size": 32,
+ "num_heads": 2,
+ "out_hidden_size": 32,
+ },
+ hidden_size=32,
+ vocab_size=152064,
+ vision_end_token_id=151653,
+ vision_start_token_id=151652,
+ vision_token_id=151654,
+ )
+ text_encoder = Qwen2_5_VLForConditionalGeneration(config)
+ tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
+
+ torch.manual_seed(0)
+ t5_config = T5Config(
+ d_model=32,
+ d_kv=4,
+ d_ff=16,
+ num_layers=2,
+ num_heads=2,
+ relative_attention_num_buckets=8,
+ relative_attention_max_distance=32,
+ vocab_size=256,
+ feed_forward_proj="gated-gelu",
+ dense_act_fn="gelu_new",
+ is_encoder_decoder=False,
+ use_cache=False,
+ tie_word_embeddings=False,
+ )
+ text_encoder_2 = T5EncoderModel(t5_config)
+ tokenizer_2 = ByT5Tokenizer()
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "guider": guider,
+ "ocr_guider": ocr_guider,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 5,
+ "height": 16,
+ "width": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (3, 16, 16))
+
+ expected_slice_np = np.array(
+ [0.6252659, 0.51482046, 0.60799813, 0.59267783, 0.488082, 0.5857634, 0.523781, 0.58028054, 0.5674121]
+ )
+ output_slice = generated_image[0, -3:, -3:].flatten().cpu().numpy()
+
+ self.assertTrue(
+ np.abs(output_slice - expected_slice_np).max() < 1e-3,
+ f"output_slice: {output_slice}, expected_slice_np: {expected_slice_np}",
+ )
+
+ def test_inference_guider(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ pipe.guider = pipe.guider.new(guidance_scale=1000)
+ pipe.ocr_guider = pipe.ocr_guider.new(guidance_scale=1000)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (3, 16, 16))
+
+ expected_slice_np = np.array(
+ [0.61494756, 0.49616697, 0.60327923, 0.6115793, 0.49047345, 0.56977504, 0.53066164, 0.58880305, 0.5570612]
+ )
+ output_slice = generated_image[0, -3:, -3:].flatten().cpu().numpy()
+
+ self.assertTrue(
+ np.abs(output_slice - expected_slice_np).max() < 1e-3,
+ f"output_slice: {output_slice}, expected_slice_np: {expected_slice_np}",
+ )
+
+ def test_inference_with_distilled_guidance(self):
+ device = "cpu"
+
+ components = self.get_dummy_components(guidance_embeds=True)
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ inputs["distilled_guidance_scale"] = 3.5
+ image = pipe(**inputs).images
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (3, 16, 16))
+
+ expected_slice_np = np.array(
+ [0.63667065, 0.5187377, 0.66757566, 0.6320319, 0.4913387, 0.54813194, 0.5335031, 0.5736143, 0.5461346]
+ )
+ output_slice = generated_image[0, -3:, -3:].flatten().cpu().numpy()
+
+ self.assertTrue(
+ np.abs(output_slice - expected_slice_np).max() < 1e-3,
+ f"output_slice: {output_slice}, expected_slice_np: {expected_slice_np}",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(tile_sample_min_size=96)
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ @unittest.skip("TODO: Test not supported for now because needs to be adjusted to work with guiders.")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py b/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py
index 5802bde87a61..27b5bde31050 100644
--- a/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py
+++ b/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
+# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,9 +24,11 @@
CLIPTextModel,
CLIPTokenizer,
LlamaConfig,
- LlamaModel,
- LlamaTokenizer,
+ LlamaTokenizerFast,
+ LlavaConfig,
+ LlavaForConditionalGeneration,
)
+from transformers.models.clip import CLIPVisionConfig
from diffusers import (
AutoencoderKLHunyuanVideo,
@@ -34,8 +36,8 @@
HunyuanVideoImageToVideoPipeline,
HunyuanVideoTransformer3DModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
@@ -116,7 +118,7 @@ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
- llama_text_encoder_config = LlamaConfig(
+ text_config = LlamaConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=16,
@@ -124,11 +126,21 @@ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=2,
- pad_token_id=1,
+ pad_token_id=100,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
+ vision_config = CLIPVisionConfig(
+ hidden_size=8,
+ intermediate_size=37,
+ projection_dim=32,
+ num_attention_heads=4,
+ num_hidden_layers=2,
+ image_size=224,
+ )
+ llava_text_encoder_config = LlavaConfig(vision_config, text_config, pad_token_id=100, image_token_index=101)
+
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
@@ -144,8 +156,8 @@ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
)
torch.manual_seed(0)
- text_encoder = LlamaModel(llama_text_encoder_config)
- tokenizer = LlamaTokenizer.from_pretrained("finetrainers/dummy-hunyaunvideo", subfolder="tokenizer")
+ text_encoder = LlavaForConditionalGeneration(llava_text_encoder_config)
+ tokenizer = LlamaTokenizerFast.from_pretrained("finetrainers/dummy-hunyaunvideo", subfolder="tokenizer")
torch.manual_seed(0)
text_encoder_2 = CLIPTextModel(clip_text_encoder_config)
@@ -153,14 +165,14 @@ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
torch.manual_seed(0)
image_processor = CLIPImageProcessor(
- crop_size=336,
+ crop_size=224,
do_center_crop=True,
do_normalize=True,
do_resize=True,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
resample=3,
- size=336,
+ size=224,
)
components = {
@@ -190,6 +202,10 @@ def get_dummy_inputs(self, device, seed=0):
"prompt_template": {
"template": "{}",
"crop_start": 0,
+ "image_emb_len": 49,
+ "image_emb_start": 5,
+ "image_emb_end": 54,
+ "double_return_token_id": 0,
},
"generator": generator,
"num_inference_steps": 2,
@@ -197,7 +213,7 @@ def get_dummy_inputs(self, device, seed=0):
"height": image_height,
"width": image_width,
"num_frames": 9,
- "max_sequence_length": 16,
+ "max_sequence_length": 64,
"output_type": "pt",
}
return inputs
@@ -213,12 +229,19 @@ def test_inference(self):
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
-
# NOTE: The expected video has 4 lesser frames because they are dropped in the pipeline
self.assertEqual(generated_video.shape, (5, 3, 16, 16))
- expected_video = torch.randn(5, 3, 16, 16)
- max_diff = np.abs(generated_video - expected_video).max()
- self.assertLessEqual(max_diff, 1e10)
+
+ # fmt: off
+ expected_slice = torch.tensor([0.444, 0.479, 0.4485, 0.5752, 0.3539, 0.1548, 0.2706, 0.3593, 0.5323, 0.6635, 0.6795, 0.5255, 0.5091, 0.345, 0.4276, 0.4128])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(
+ torch.allclose(generated_slice, expected_slice, atol=1e-3),
+ "The generated video does not match the expected slice.",
+ )
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_skyreels_image2video.py b/tests/pipelines/hunyuan_video/test_hunyuan_skyreels_image2video.py
index bd3190de532d..7ebe797febfa 100644
--- a/tests/pipelines/hunyuan_video/test_hunyuan_skyreels_image2video.py
+++ b/tests/pipelines/hunyuan_video/test_hunyuan_skyreels_image2video.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
+# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -26,8 +26,8 @@
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoTransformer3DModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
@@ -192,11 +192,18 @@ def test_inference(self):
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
-
self.assertEqual(generated_video.shape, (9, 3, 16, 16))
- expected_video = torch.randn(9, 3, 16, 16)
- max_diff = np.abs(generated_video - expected_video).max()
- self.assertLessEqual(max_diff, 1e10)
+
+ # fmt: off
+ expected_slice = torch.tensor([0.5832, 0.5498, 0.4839, 0.4744, 0.4515, 0.4832, 0.496, 0.563, 0.5918, 0.5979, 0.5101, 0.6168, 0.6613, 0.536, 0.55, 0.5775])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(
+ torch.allclose(generated_slice, expected_slice, atol=1e-3),
+ "The generated video does not match the expected slice.",
+ )
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py
index aa4f045966c3..57a6daebad1f 100644
--- a/tests/pipelines/hunyuan_video/test_hunyuan_video.py
+++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
+# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -26,15 +26,14 @@
HunyuanVideoPipeline,
HunyuanVideoTransformer3DModel,
)
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- torch_device,
-)
+from ...testing_utils import enable_full_determinism, torch_device
from ..test_pipelines_common import (
FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
+ TaylorSeerCacheTesterMixin,
to_np,
)
@@ -43,7 +42,12 @@
class HunyuanVideoPipelineFastTests(
- PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
+ PipelineTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+ FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
+ TaylorSeerCacheTesterMixin,
+ unittest.TestCase,
):
pipeline_class = HunyuanVideoPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
@@ -201,11 +205,18 @@ def test_inference(self):
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
-
self.assertEqual(generated_video.shape, (9, 3, 16, 16))
- expected_video = torch.randn(9, 3, 16, 16)
- max_diff = np.abs(generated_video - expected_video).max()
- self.assertLessEqual(max_diff, 1e10)
+
+ # fmt: off
+ expected_slice = torch.tensor([0.3946, 0.4649, 0.3196, 0.4569, 0.3312, 0.3687, 0.3216, 0.3972, 0.4469, 0.3888, 0.3929, 0.3802, 0.3479, 0.3888, 0.3825, 0.3542])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(
+ torch.allclose(generated_slice, expected_slice, atol=1e-3),
+ "The generated video does not match the expected slice.",
+ )
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py b/tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py
new file mode 100644
index 000000000000..51c258b15c38
--- /dev/null
+++ b/tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py
@@ -0,0 +1,404 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import (
+ CLIPTextConfig,
+ CLIPTextModel,
+ CLIPTokenizer,
+ LlamaConfig,
+ LlamaModel,
+ LlamaTokenizer,
+ SiglipImageProcessor,
+ SiglipVisionModel,
+)
+
+from diffusers import (
+ AutoencoderKLHunyuanVideo,
+ FasterCacheConfig,
+ FlowMatchEulerDiscreteScheduler,
+ HunyuanVideoFramepackPipeline,
+ HunyuanVideoFramepackTransformer3DModel,
+)
+
+from ...testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+from ..test_pipelines_common import (
+ FasterCacheTesterMixin,
+ PipelineTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+ to_np,
+)
+
+
+enable_full_determinism()
+
+
+class HunyuanVideoFramepackPipelineFastTests(
+ PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
+):
+ pipeline_class = HunyuanVideoFramepackPipeline
+ params = frozenset(
+ ["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]
+ )
+ batch_params = frozenset(["image", "prompt"])
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+
+ supports_dduf = False
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ faster_cache_config = FasterCacheConfig(
+ spatial_attention_block_skip_range=2,
+ spatial_attention_timestep_skip_range=(-1, 901),
+ unconditional_batch_skip_range=2,
+ attention_weight_callback=lambda _: 0.5,
+ is_guidance_distilled=True,
+ )
+
+ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
+ torch.manual_seed(0)
+ transformer = HunyuanVideoFramepackTransformer3DModel(
+ in_channels=4,
+ out_channels=4,
+ num_attention_heads=2,
+ attention_head_dim=10,
+ num_layers=num_layers,
+ num_single_layers=num_single_layers,
+ num_refiner_layers=1,
+ patch_size=2,
+ patch_size_t=1,
+ guidance_embeds=True,
+ text_embed_dim=16,
+ pooled_projection_dim=8,
+ rope_axes_dim=(2, 4, 4),
+ image_condition_type=None,
+ has_image_proj=True,
+ image_proj_dim=32,
+ has_clean_x_embedder=True,
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLHunyuanVideo(
+ in_channels=3,
+ out_channels=3,
+ latent_channels=4,
+ down_block_types=(
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ ),
+ up_block_types=(
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ ),
+ block_out_channels=(8, 8, 8, 8),
+ layers_per_block=1,
+ act_fn="silu",
+ norm_num_groups=4,
+ scaling_factor=0.476986,
+ spatial_compression_ratio=8,
+ temporal_compression_ratio=4,
+ mid_block_add_attention=True,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+
+ llama_text_encoder_config = LlamaConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=16,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=2,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=8,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=2,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = LlamaModel(llama_text_encoder_config)
+ tokenizer = LlamaTokenizer.from_pretrained("finetrainers/dummy-hunyaunvideo", subfolder="tokenizer")
+
+ torch.manual_seed(0)
+ text_encoder_2 = CLIPTextModel(clip_text_encoder_config)
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ feature_extractor = SiglipImageProcessor.from_pretrained(
+ "hf-internal-testing/tiny-random-SiglipVisionModel", size={"height": 30, "width": 30}
+ )
+ image_encoder = SiglipVisionModel.from_pretrained("hf-internal-testing/tiny-random-SiglipVisionModel")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "feature_extractor": feature_extractor,
+ "image_encoder": image_encoder,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ image_height = 32
+ image_width = 32
+ image = Image.new("RGB", (image_width, image_height))
+ inputs = {
+ "image": image,
+ "prompt": "dance monkey",
+ "prompt_template": {
+ "template": "{}",
+ "crop_start": 0,
+ },
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 4.5,
+ "height": image_height,
+ "width": image_width,
+ "num_frames": 9,
+ "latent_window_size": 3,
+ "max_sequence_length": 256,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+ self.assertEqual(generated_video.shape, (13, 3, 32, 32))
+
+ # fmt: off
+ expected_slice = torch.tensor([0.363, 0.3384, 0.3426, 0.3512, 0.3372, 0.3276, 0.417, 0.4061, 0.5221, 0.467, 0.4813, 0.4556, 0.4107, 0.3945, 0.4049, 0.4551])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(
+ torch.allclose(generated_slice, expected_slice, atol=1e-3),
+ "The generated video does not match the expected slice.",
+ )
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+ # Test passing in a everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ # Seems to require higher tolerance than the other tests
+ expected_diff_max = 0.6
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ def test_float16_inference(self, expected_max_diff=0.2):
+ # NOTE: this test needs a higher tolerance because of multiple forwards through
+ # the model, which compounds the overall fp32 vs fp16 numerical differences. It
+ # shouldn't be expected that the results are the same, so we bump the tolerance.
+ return super().test_float16_inference(expected_max_diff)
+
+ @unittest.skip("The image_encoder uses SiglipVisionModel, which does not support sequential CPU offloading.")
+ def test_sequential_cpu_offload_forward_pass(self):
+ # https://github.com/huggingface/transformers/blob/21cb353b7b4f77c6f5f5c3341d660f86ff416d04/src/transformers/models/siglip/modeling_siglip.py#L803
+ # This is because it instantiates it's attention layer from torch.nn.MultiheadAttention, which calls to
+ # `torch.nn.functional.multi_head_attention_forward` with the weights and bias. Since the hook is never
+ # triggered with a forward pass call, the weights stay on the CPU. There are more examples where we skip
+ # this test because of MHA (example: HunyuanDiT because of AttentionPooling layer).
+ pass
+
+ @unittest.skip("The image_encoder uses SiglipVisionModel, which does not support sequential CPU offloading.")
+ def test_sequential_offload_forward_pass_twice(self):
+ # https://github.com/huggingface/transformers/blob/21cb353b7b4f77c6f5f5c3341d660f86ff416d04/src/transformers/models/siglip/modeling_siglip.py#L803
+ # This is because it instantiates it's attention layer from torch.nn.MultiheadAttention, which calls to
+ # `torch.nn.functional.multi_head_attention_forward` with the weights and bias. Since the hook is never
+ # triggered with a forward pass call, the weights stay on the CPU. There are more examples where we skip
+ # this test because of MHA (example: HunyuanDiT because of AttentionPooling layer).
+ pass
+
+ # TODO(aryan): Create a dummy gemma model with smol vocab size
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_consistent(self):
+ pass
+
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_single_identical(self):
+ pass
diff --git a/tests/pipelines/hunyuan_video1_5/__init__.py b/tests/pipelines/hunyuan_video1_5/__init__.py
new file mode 100644
index 000000000000..8fb044d9cf83
--- /dev/null
+++ b/tests/pipelines/hunyuan_video1_5/__init__.py
@@ -0,0 +1 @@
+# Copyright 2025 The HuggingFace Team.
diff --git a/tests/pipelines/hunyuan_video1_5/test_hunyuan_1_5.py b/tests/pipelines/hunyuan_video1_5/test_hunyuan_1_5.py
new file mode 100644
index 000000000000..993c7ef6e4bb
--- /dev/null
+++ b/tests/pipelines/hunyuan_video1_5/test_hunyuan_1_5.py
@@ -0,0 +1,187 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from transformers import ByT5Tokenizer, Qwen2_5_VLTextConfig, Qwen2_5_VLTextModel, Qwen2Tokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKLHunyuanVideo15,
+ FlowMatchEulerDiscreteScheduler,
+ HunyuanVideo15Pipeline,
+ HunyuanVideo15Transformer3DModel,
+)
+from diffusers.guiders import ClassifierFreeGuidance
+
+from ...testing_utils import enable_full_determinism
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class HunyuanVideo15PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = HunyuanVideo15Pipeline
+ params = frozenset(
+ [
+ "prompt",
+ "negative_prompt",
+ "height",
+ "width",
+ "prompt_embeds",
+ "prompt_embeds_mask",
+ "negative_prompt_embeds",
+ "negative_prompt_embeds_mask",
+ "prompt_embeds_2",
+ "prompt_embeds_mask_2",
+ "negative_prompt_embeds_2",
+ "negative_prompt_embeds_mask_2",
+ ]
+ )
+ batch_params = ["prompt", "negative_prompt"]
+ required_optional_params = frozenset(["num_inference_steps", "generator", "latents", "return_dict"])
+ test_attention_slicing = False
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = False
+ supports_dduf = False
+
+ def get_dummy_components(self, num_layers: int = 1):
+ torch.manual_seed(0)
+ transformer = HunyuanVideo15Transformer3DModel(
+ in_channels=9,
+ out_channels=4,
+ num_attention_heads=2,
+ attention_head_dim=8,
+ num_layers=num_layers,
+ num_refiner_layers=1,
+ mlp_ratio=2.0,
+ patch_size=1,
+ patch_size_t=1,
+ text_embed_dim=16,
+ text_embed_2_dim=32,
+ image_embed_dim=12,
+ rope_axes_dim=(2, 2, 4),
+ target_size=16,
+ task_type="t2v",
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLHunyuanVideo15(
+ in_channels=3,
+ out_channels=3,
+ latent_channels=4,
+ block_out_channels=(16, 16),
+ layers_per_block=1,
+ spatial_compression_ratio=4,
+ temporal_compression_ratio=2,
+ downsample_match_channel=False,
+ upsample_match_channel=False,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+
+ torch.manual_seed(0)
+ qwen_config = Qwen2_5_VLTextConfig(
+ **{
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 2,
+ "rope_scaling": {
+ "mrope_section": [1, 1, 2],
+ "rope_type": "default",
+ "type": "default",
+ },
+ "rope_theta": 1000000.0,
+ }
+ )
+ text_encoder = Qwen2_5_VLTextModel(qwen_config)
+ tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
+
+ torch.manual_seed(0)
+ text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer_2 = ByT5Tokenizer()
+
+ guider = ClassifierFreeGuidance(guidance_scale=1.0)
+
+ components = {
+ "transformer": transformer.eval(),
+ "vae": vae.eval(),
+ "scheduler": scheduler,
+ "text_encoder": text_encoder.eval(),
+ "text_encoder_2": text_encoder_2.eval(),
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "guider": guider,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ inputs = {
+ "prompt": "monkey",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "height": 16,
+ "width": 16,
+ "num_frames": 9,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ result = pipe(**inputs)
+ video = result.frames
+
+ generated_video = video[0]
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+
+ # fmt: off
+ expected_slice = torch.tensor([0.4296, 0.5549, 0.3088, 0.9115, 0.5049, 0.7926, 0.5549, 0.8618, 0.5091, 0.5075, 0.7117, 0.5292, 0.7053, 0.4864, 0.5206, 0.3878])
+ # fmt: on
+
+ self.assertTrue(
+ torch.abs(generated_slice - expected_slice).max() < 1e-3,
+ f"output_slice: {generated_slice}, expected_slice: {expected_slice}",
+ )
+
+ @unittest.skip("TODO: Test not supported for now because needs to be adjusted to work with guiders.")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
+ @unittest.skip("Needs to be revisited.")
+ def test_inference_batch_consistent(self):
+ super().test_inference_batch_consistent()
+
+ @unittest.skip("Needs to be revisited.")
+ def test_inference_batch_single_identical(self):
+ super().test_inference_batch_single_identical()
diff --git a/tests/pipelines/hunyuandit/test_hunyuan_dit.py b/tests/pipelines/hunyuandit/test_hunyuan_dit.py
index 5b1a82eda227..2a329f10bc80 100644
--- a/tests/pipelines/hunyuandit/test_hunyuan_dit.py
+++ b/tests/pipelines/hunyuandit/test_hunyuan_dit.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,20 +21,16 @@
import torch
from transformers import AutoTokenizer, BertModel, T5EncoderModel
-from diffusers import (
- AutoencoderKL,
- DDPMScheduler,
- HunyuanDiT2DModel,
- HunyuanDiTPipeline,
-)
-from diffusers.utils.testing_utils import (
+from diffusers import AutoencoderKL, DDPMScheduler, HunyuanDiT2DModel, HunyuanDiTPipeline
+
+from ...testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_accelerator,
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
@@ -128,14 +124,22 @@ def test_inference(self):
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
- @unittest.skip("Not supported.")
+ @unittest.skip("The HunyuanDiT Attention pooling layer does not support sequential CPU offloading.")
def test_sequential_cpu_offload_forward_pass(self):
# TODO(YiYi) need to fix later
+ # This is because it instantiates it's attention layer from torch.nn.MultiheadAttention, which calls to
+ # `torch.nn.functional.multi_head_attention_forward` with the weights and bias. Since the hook is never
+ # triggered with a forward pass call, the weights stay on the CPU. There are more examples where we skip
+ # this test because of MHA (example: HunyuanVideo Framepack)
pass
- @unittest.skip("Not supported.")
+ @unittest.skip("The HunyuanDiT Attention pooling layer does not support sequential CPU offloading.")
def test_sequential_offload_forward_pass_twice(self):
# TODO(YiYi) need to fix later
+ # This is because it instantiates it's attention layer from torch.nn.MultiheadAttention, which calls to
+ # `torch.nn.functional.multi_head_attention_forward` with the weights and bias. Since the hook is never
+ # triggered with a forward pass call, the weights stay on the CPU. There are more examples where we skip
+ # this test because of MHA (example: HunyuanVideo Framepack)
pass
def test_inference_batch_single_identical(self):
@@ -179,9 +183,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -197,15 +201,15 @@ def test_fused_qkv_projections(self):
image_disabled = pipe(**inputs)[0]
image_slice_disabled = image_disabled[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
@unittest.skip(
"Test not supported as `encode_prompt` is called two times separately which deivates from about 99% of the pipelines we have."
@@ -315,12 +319,12 @@ class HunyuanDiTPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_hunyuan_dit_1024(self):
generator = torch.Generator("cpu").manual_seed(0)
diff --git a/tests/pipelines/i2vgen_xl/test_i2vgenxl.py b/tests/pipelines/i2vgen_xl/test_i2vgenxl.py
deleted file mode 100644
index 868a40c9fb53..000000000000
--- a/tests/pipelines/i2vgen_xl/test_i2vgenxl.py
+++ /dev/null
@@ -1,276 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import random
-import unittest
-
-import numpy as np
-import torch
-from transformers import (
- CLIPImageProcessor,
- CLIPTextConfig,
- CLIPTextModel,
- CLIPTokenizer,
- CLIPVisionConfig,
- CLIPVisionModelWithProjection,
-)
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- I2VGenXLPipeline,
-)
-from diffusers.models.unets import I2VGenXLUNet
-from diffusers.utils import is_xformers_available, load_image
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- floats_tensor,
- numpy_cosine_similarity_distance,
- require_torch_accelerator,
- skip_mps,
- slow,
- torch_device,
-)
-
-from ..test_pipelines_common import PipelineTesterMixin, SDFunctionTesterMixin
-
-
-enable_full_determinism()
-
-
-@skip_mps
-class I2VGenXLPipelineFastTests(SDFunctionTesterMixin, PipelineTesterMixin, unittest.TestCase):
- pipeline_class = I2VGenXLPipeline
- params = frozenset(["prompt", "negative_prompt", "image"])
- batch_params = frozenset(["prompt", "negative_prompt", "image", "generator"])
- # No `output_type`.
- required_optional_params = frozenset(["num_inference_steps", "generator", "latents", "return_dict"])
-
- supports_dduf = False
- test_layerwise_casting = True
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
-
- torch.manual_seed(0)
- unet = I2VGenXLUNet(
- block_out_channels=(4, 8),
- layers_per_block=1,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
- up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
- cross_attention_dim=4,
- attention_head_dim=4,
- num_attention_heads=None,
- norm_num_groups=2,
- )
-
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=(8,),
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=32,
- norm_num_groups=2,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=4,
- intermediate_size=16,
- layer_norm_eps=1e-05,
- num_attention_heads=2,
- num_hidden_layers=2,
- pad_token_id=1,
- vocab_size=1000,
- hidden_act="gelu",
- projection_dim=32,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- torch.manual_seed(0)
- vision_encoder_config = CLIPVisionConfig(
- hidden_size=4,
- projection_dim=4,
- num_hidden_layers=2,
- num_attention_heads=2,
- image_size=32,
- intermediate_size=16,
- patch_size=1,
- )
- image_encoder = CLIPVisionModelWithProjection(vision_encoder_config)
-
- torch.manual_seed(0)
- feature_extractor = CLIPImageProcessor(crop_size=32, size=32)
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "image_encoder": image_encoder,
- "tokenizer": tokenizer,
- "feature_extractor": feature_extractor,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
-
- input_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "image": input_image,
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "pt",
- "num_frames": 4,
- "width": 32,
- "height": 32,
- }
- return inputs
-
- def test_text_to_video_default_case(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- inputs["output_type"] = "np"
- frames = pipe(**inputs).frames
-
- image_slice = frames[0][0][-3:, -3:, -1]
-
- assert frames[0][0].shape == (32, 32, 3)
- expected_slice = np.array([0.5146, 0.6525, 0.6032, 0.5204, 0.5675, 0.4125, 0.3016, 0.5172, 0.4095])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_save_load_local(self):
- super().test_save_load_local(expected_max_difference=0.006)
-
- def test_sequential_cpu_offload_forward_pass(self):
- super().test_sequential_cpu_offload_forward_pass(expected_max_diff=0.008)
-
- def test_dict_tuple_outputs_equivalent(self):
- super().test_dict_tuple_outputs_equivalent(expected_max_difference=0.008)
-
- def test_save_load_optional_components(self):
- super().test_save_load_optional_components(expected_max_difference=0.008)
-
- @unittest.skip("Deprecated functionality")
- def test_attention_slicing_forward_pass(self):
- pass
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False, expected_max_diff=1e-2)
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=0.008)
-
- def test_model_cpu_offload_forward_pass(self):
- super().test_model_cpu_offload_forward_pass(expected_max_diff=0.008)
-
- def test_num_videos_per_prompt(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- inputs["output_type"] = "np"
- frames = pipe(**inputs, num_videos_per_prompt=2).frames
-
- assert frames.shape == (2, 4, 32, 32, 3)
- assert frames[0][0].shape == (32, 32, 3)
-
- image_slice = frames[0][0][-3:, -3:, -1]
- expected_slice = np.array([0.5146, 0.6525, 0.6032, 0.5204, 0.5675, 0.4125, 0.3016, 0.5172, 0.4095])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- @unittest.skip("Test not supported for now.")
- def test_encode_prompt_works_in_isolation(self):
- pass
-
-
-@slow
-@require_torch_accelerator
-class I2VGenXLPipelineSlowTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_i2vgen_xl(self):
- pipe = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
- pipe.enable_model_cpu_offload(device=torch_device)
- pipe.set_progress_bar_config(disable=None)
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
- )
-
- generator = torch.Generator("cpu").manual_seed(0)
- num_frames = 3
-
- output = pipe(
- image=image,
- prompt="my cat",
- num_frames=num_frames,
- generator=generator,
- num_inference_steps=3,
- output_type="np",
- )
-
- image = output.frames[0]
- assert image.shape == (num_frames, 704, 1280, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.5482, 0.6244, 0.6274, 0.4584, 0.5935, 0.5937, 0.4579, 0.5767, 0.5892])
- assert numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice.flatten()) < 1e-3
diff --git a/tests/pipelines/stable_diffusion_safe/__init__.py b/tests/pipelines/ip_adapters/__init__.py
similarity index 100%
rename from tests/pipelines/stable_diffusion_safe/__init__.py
rename to tests/pipelines/ip_adapters/__init__.py
diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
index d5d4c20e471f..32590111cdf3 100644
--- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
+++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,7 +33,9 @@
)
from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ Expectations,
backend_empty_cache,
enable_full_determinism,
is_flaky,
@@ -664,7 +666,50 @@ def test_instant_style_multiple_masks(self):
images = pipeline(**inputs).images
image_slice = images[0, :3, :3, -1].flatten()
- expected_slice = np.array([0.2323, 0.1026, 0.1338, 0.0638, 0.0662, 0.0000, 0.0000, 0.0000, 0.0199])
+ expected_slices = Expectations(
+ {
+ ("xpu", 3): np.array(
+ [
+ 0.2520,
+ 0.1050,
+ 0.1510,
+ 0.0997,
+ 0.0893,
+ 0.0019,
+ 0.0000,
+ 0.0000,
+ 0.0210,
+ ]
+ ),
+ ("cuda", 7): np.array(
+ [
+ 0.2323,
+ 0.1026,
+ 0.1338,
+ 0.0638,
+ 0.0662,
+ 0.0000,
+ 0.0000,
+ 0.0000,
+ 0.0199,
+ ]
+ ),
+ ("cuda", 8): np.array(
+ [
+ 0.2518,
+ 0.1059,
+ 0.1553,
+ 0.0977,
+ 0.0852,
+ 0.0000,
+ 0.0000,
+ 0.0000,
+ 0.0220,
+ ]
+ ),
+ }
+ )
+ expected_slice = expected_slices.get_expectation()
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
assert max_diff < 5e-4
diff --git a/tests/pipelines/kandinsky/test_kandinsky.py b/tests/pipelines/kandinsky/test_kandinsky.py
index 30144e37a9d4..6207e71df8cd 100644
--- a/tests/pipelines/kandinsky/test_kandinsky.py
+++ b/tests/pipelines/kandinsky/test_kandinsky.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,12 +18,15 @@
import unittest
import numpy as np
+import pytest
import torch
from transformers import XLMRobertaTokenizerFast
from diffusers import DDIMScheduler, KandinskyPipeline, KandinskyPriorPipeline, UNet2DConditionModel, VQModel
from diffusers.pipelines.kandinsky.text_encoder import MCLIPConfig, MultilingualCLIP
-from diffusers.utils.testing_utils import (
+from diffusers.utils import is_transformers_version
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -32,7 +35,6 @@
slow,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
@@ -215,6 +217,11 @@ def get_dummy_inputs(self, device, seed=0):
dummy = Dummies()
return dummy.get_dummy_inputs(device=device, seed=seed)
+ @pytest.mark.xfail(
+ condition=is_transformers_version(">=", "4.56.2"),
+ reason="Latest transformers changes the slices",
+ strict=False,
+ )
def test_kandinsky(self):
device = "cpu"
@@ -240,12 +247,12 @@ def test_kandinsky(self):
expected_slice = np.array([1.0000, 1.0000, 0.2766, 1.0000, 0.5447, 0.1737, 1.0000, 0.4316, 0.9024])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_accelerator
def test_offloads(self):
diff --git a/tests/pipelines/kandinsky/test_kandinsky_combined.py b/tests/pipelines/kandinsky/test_kandinsky_combined.py
index c5f27a9cc9a9..eba897659700 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_combined.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_combined.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,10 +16,12 @@
import unittest
import numpy as np
+import pytest
from diffusers import KandinskyCombinedPipeline, KandinskyImg2ImgCombinedPipeline, KandinskyInpaintCombinedPipeline
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_accelerator, torch_device
+from diffusers.utils import is_transformers_version
+from ...testing_utils import enable_full_determinism, require_torch_accelerator, torch_device
from ..test_pipelines_common import PipelineTesterMixin
from .test_kandinsky import Dummies
from .test_kandinsky_img2img import Dummies as Img2ImgDummies
@@ -73,6 +75,11 @@ def get_dummy_inputs(self, device, seed=0):
)
return inputs
+ @pytest.mark.xfail(
+ condition=is_transformers_version(">=", "4.56.2"),
+ reason="Latest transformers changes the slices",
+ strict=False,
+ )
def test_kandinsky(self):
device = "cpu"
@@ -98,12 +105,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.2893, 0.1464, 0.4603, 0.3529, 0.4612, 0.7701, 0.4027, 0.3051, 0.5155])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_accelerator
def test_offloads(self):
@@ -181,6 +188,11 @@ def get_dummy_inputs(self, device, seed=0):
inputs.pop("negative_image_embeds")
return inputs
+ @pytest.mark.xfail(
+ condition=is_transformers_version(">=", "4.56.2"),
+ reason="Latest transformers changes the slices",
+ strict=False,
+ )
def test_kandinsky(self):
device = "cpu"
@@ -206,12 +218,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.4852, 0.4136, 0.4539, 0.4781, 0.4680, 0.5217, 0.4973, 0.4089, 0.4977])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_accelerator
def test_offloads(self):
@@ -292,6 +304,11 @@ def get_dummy_inputs(self, device, seed=0):
inputs.pop("negative_image_embeds")
return inputs
+ @pytest.mark.xfail(
+ condition=is_transformers_version(">=", "4.56.2"),
+ reason="Latest transformers changes the slices",
+ strict=False,
+ )
def test_kandinsky(self):
device = "cpu"
@@ -318,12 +335,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.0320, 0.0860, 0.4013, 0.0518, 0.2484, 0.5847, 0.4411, 0.2321, 0.4593])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_accelerator
def test_offloads(self):
diff --git a/tests/pipelines/kandinsky/test_kandinsky_img2img.py b/tests/pipelines/kandinsky/test_kandinsky_img2img.py
index 26361ce18b82..6d1b43a24fd9 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_img2img.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_img2img.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@
import unittest
import numpy as np
+import pytest
import torch
from PIL import Image
from transformers import XLMRobertaTokenizerFast
@@ -31,7 +32,9 @@
VQModel,
)
from diffusers.pipelines.kandinsky.text_encoder import MCLIPConfig, MultilingualCLIP
-from diffusers.utils.testing_utils import (
+from diffusers.utils import is_transformers_version
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -42,7 +45,6 @@
slow,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
@@ -237,6 +239,11 @@ def get_dummy_inputs(self, device, seed=0):
dummies = Dummies()
return dummies.get_dummy_inputs(device=device, seed=seed)
+ @pytest.mark.xfail(
+ condition=is_transformers_version(">=", "4.56.2"),
+ reason="Latest transformers changes the slices",
+ strict=False,
+ )
def test_kandinsky_img2img(self):
device = "cpu"
@@ -261,12 +268,12 @@ def test_kandinsky_img2img(self):
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5816, 0.5872, 0.4634, 0.5982, 0.4767, 0.4710, 0.4669, 0.4717, 0.4966])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_accelerator
def test_offloads(self):
@@ -321,7 +328,7 @@ def test_kandinsky_img2img(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)
prompt = "A red cartoon frog, 4k"
@@ -387,7 +394,7 @@ def test_kandinsky_img2img_ddpm(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/frog.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/frog.png"
)
prompt = "A red cartoon frog, 4k"
diff --git a/tests/pipelines/kandinsky/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py
index e30c601b6011..e2f4aa2a4f14 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_inpaint.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,13 +18,16 @@
import unittest
import numpy as np
+import pytest
import torch
from PIL import Image
from transformers import XLMRobertaTokenizerFast
from diffusers import DDIMScheduler, KandinskyInpaintPipeline, KandinskyPriorPipeline, UNet2DConditionModel, VQModel
from diffusers.pipelines.kandinsky.text_encoder import MCLIPConfig, MultilingualCLIP
-from diffusers.utils.testing_utils import (
+from diffusers.utils import is_transformers_version
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -34,7 +37,6 @@
require_torch_accelerator,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
@@ -231,6 +233,11 @@ def get_dummy_inputs(self, device, seed=0):
dummies = Dummies()
return dummies.get_dummy_inputs(device=device, seed=seed)
+ @pytest.mark.xfail(
+ condition=is_transformers_version(">=", "4.56.2"),
+ reason="Latest transformers changes the slices",
+ strict=False,
+ )
def test_kandinsky_inpaint(self):
device = "cpu"
@@ -256,12 +263,12 @@ def test_kandinsky_inpaint(self):
expected_slice = np.array([0.8222, 0.8896, 0.4373, 0.8088, 0.4905, 0.2609, 0.6816, 0.4291, 0.5129])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
@@ -319,7 +326,7 @@ def test_kandinsky_inpaint(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)
mask = np.zeros((768, 768), dtype=np.float32)
mask[:250, 250:-250] = 1
diff --git a/tests/pipelines/kandinsky/test_kandinsky_prior.py b/tests/pipelines/kandinsky/test_kandinsky_prior.py
index abb53bfb792f..903a1e5decfa 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_prior.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_prior.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -28,8 +28,8 @@
)
from diffusers import KandinskyPriorPipeline, PriorTransformer, UnCLIPScheduler
-from diffusers.utils.testing_utils import enable_full_determinism, skip_mps, torch_device
+from ...testing_utils import enable_full_determinism, skip_mps, torch_device
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky.py b/tests/pipelines/kandinsky2_2/test_kandinsky.py
index fea49d47b7bb..38294aa4c111 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,7 +21,8 @@
import torch
from diffusers import DDIMScheduler, KandinskyV22Pipeline, KandinskyV22PriorPipeline, UNet2DConditionModel, VQModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -31,7 +32,6 @@
slow,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
@@ -210,13 +210,13 @@ def test_kandinsky(self):
expected_slice = np.array([0.3420, 0.9505, 0.3919, 1.0000, 0.5188, 0.3109, 0.6139, 0.5624, 0.6811])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=1e-1)
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
index 90f8b2034109..62f5853da9a5 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,8 +22,8 @@
KandinskyV22Img2ImgCombinedPipeline,
KandinskyV22InpaintCombinedPipeline,
)
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_accelerator, torch_device
+from ...testing_utils import enable_full_determinism, require_accelerator, require_torch_accelerator, torch_device
from ..test_pipelines_common import PipelineTesterMixin
from .test_kandinsky import Dummies
from .test_kandinsky_img2img import Dummies as Img2ImgDummies
@@ -103,12 +103,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.3076, 0.2729, 0.5668, 0.0522, 0.3384, 0.7028, 0.4908, 0.3659, 0.6243])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_accelerator
def test_offloads(self):
@@ -227,12 +227,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.4445, 0.4287, 0.4596, 0.3919, 0.3730, 0.5039, 0.4834, 0.4269, 0.5521])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_accelerator
def test_offloads(self):
@@ -350,12 +350,12 @@ def test_kandinsky(self):
expected_slice = np.array([0.5039, 0.4926, 0.4898, 0.4978, 0.4838, 0.4942, 0.4738, 0.4702, 0.4816])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_accelerator
def test_offloads(self):
@@ -388,7 +388,7 @@ def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=1e-2)
def test_float16_inference(self):
- super().test_float16_inference(expected_max_diff=5e-1)
+ super().test_float16_inference(expected_max_diff=8e-1)
def test_dict_tuple_outputs_equivalent(self):
super().test_dict_tuple_outputs_equivalent(expected_max_difference=5e-4)
@@ -402,6 +402,7 @@ def test_save_load_local(self):
def test_save_load_optional_components(self):
super().test_save_load_optional_components(expected_max_difference=5e-4)
+ @require_accelerator
def test_sequential_cpu_offload_forward_pass(self):
super().test_sequential_cpu_offload_forward_pass(expected_max_diff=5e-4)
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
index 1f3219e0d69e..8f8e58a8c4c8 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -27,16 +27,19 @@
UNet2DConditionModel,
VQModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ Expectations,
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
load_numpy,
nightly,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
+ torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
@@ -210,13 +213,13 @@ def test_kandinsky_controlnet(self):
[0.6959826, 0.868279, 0.7558092, 0.68769467, 0.85805804, 0.65977496, 0.44885302, 0.5959111, 0.4251595]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=1e-1)
@@ -226,19 +229,19 @@ def test_inference_batch_single_identical(self):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
class KandinskyV22ControlnetPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_kandinsky_controlnet(self):
expected_image = load_numpy(
@@ -287,6 +290,12 @@ def test_kandinsky_controlnet(self):
image = output.images[0]
assert image.shape == (512, 512, 3)
-
max_diff = numpy_cosine_similarity_distance(expected_image.flatten(), image.flatten())
- assert max_diff < 1e-4
+ expected_max_diffs = Expectations(
+ {
+ ("xpu", 3): 2e-3,
+ ("cuda", 7): 2e-4,
+ }
+ )
+ expected_max_diff = expected_max_diffs.get_expectation()
+ assert max_diff < expected_max_diff
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py
index 20944aa3d6f8..a4346605929b 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -28,16 +28,18 @@
UNet2DConditionModel,
VQModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
load_numpy,
nightly,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
+ torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
@@ -218,12 +220,12 @@ def test_kandinsky_controlnet_img2img(self):
expected_slice = np.array(
[0.54985034, 0.55509365, 0.52561504, 0.5570494, 0.5593818, 0.5263979, 0.50285643, 0.5069846, 0.51196736]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=1.75e-3)
@@ -233,19 +235,19 @@ def test_float16_inference(self):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
class KandinskyV22ControlnetImg2ImgPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_kandinsky_controlnet_img2img(self):
expected_image = load_numpy(
@@ -254,7 +256,7 @@ def test_kandinsky_controlnet_img2img(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)
init_image = init_image.resize((512, 512))
@@ -309,4 +311,4 @@ def test_kandinsky_controlnet_img2img(self):
assert image.shape == (512, 512, 3)
max_diff = numpy_cosine_similarity_distance(expected_image.flatten(), image.flatten())
- assert max_diff < 1e-4
+ assert max_diff < 5e-4
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py b/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py
index 4702f473a992..99f3fe0f40f1 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -28,7 +28,8 @@
UNet2DConditionModel,
VQModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -39,7 +40,6 @@
slow,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
@@ -228,12 +228,12 @@ def test_kandinsky_img2img(self):
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5712, 0.5443, 0.4725, 0.6195, 0.5184, 0.4651, 0.4473, 0.4590, 0.5016])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=2e-1)
@@ -261,7 +261,7 @@ def test_kandinsky_img2img(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)
prompt = "A red cartoon frog, 4k"
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
index 9a7f659e533c..8a693e9c2dd0 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -28,7 +28,8 @@
UNet2DConditionModel,
VQModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -36,11 +37,11 @@
load_image,
load_numpy,
numpy_cosine_similarity_distance,
+ require_accelerator,
require_torch_accelerator,
slow,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
@@ -234,12 +235,12 @@ def test_kandinsky_inpaint(self):
[0.50775903, 0.49527195, 0.48824543, 0.50192237, 0.48644906, 0.49373814, 0.4780598, 0.47234827, 0.48327848]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
@@ -254,6 +255,7 @@ def test_model_cpu_offload_forward_pass(self):
def test_save_load_optional_components(self):
super().test_save_load_optional_components(expected_max_difference=5e-4)
+ @require_accelerator
def test_sequential_cpu_offload_forward_pass(self):
super().test_sequential_cpu_offload_forward_pass(expected_max_diff=5e-4)
@@ -314,7 +316,7 @@ def test_kandinsky_inpaint(self):
)
init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
)
mask = np.zeros((768, 768), dtype=np.float32)
mask[:250, 250:-250] = 1
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py b/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py
index bdec6c132f80..adcc6cc2167c 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,8 +29,8 @@
)
from diffusers import KandinskyV22PriorPipeline, PriorTransformer, UnCLIPScheduler
-from diffusers.utils.testing_utils import enable_full_determinism, skip_mps, torch_device
+from ...testing_utils import enable_full_determinism, skip_mps, torch_device
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py b/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py
index 0ea32981d518..5377d917791a 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -30,13 +30,13 @@
)
from diffusers import KandinskyV22PriorEmb2EmbPipeline, PriorTransformer, UnCLIPScheduler
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
skip_mps,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/kandinsky3/test_kandinsky3.py b/tests/pipelines/kandinsky3/test_kandinsky3.py
index af1d45ff8975..55500f729bbb 100644
--- a/tests/pipelines/kandinsky3/test_kandinsky3.py
+++ b/tests/pipelines/kandinsky3/test_kandinsky3.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -30,7 +30,8 @@
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
load_image,
@@ -38,7 +39,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
@@ -157,9 +157,9 @@ def test_kandinsky3(self):
expected_slice = np.array([0.3768, 0.4373, 0.4865, 0.4890, 0.4299, 0.5122, 0.4921, 0.4924, 0.5599])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=1e-1)
diff --git a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
index e00948621a06..503fdb242dff 100644
--- a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
+++ b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -30,7 +30,8 @@
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -39,7 +40,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
@@ -181,9 +181,9 @@ def test_kandinsky3_img2img(self):
[0.576259, 0.6132097, 0.41703486, 0.603196, 0.62062526, 0.4655338, 0.5434324, 0.5660727, 0.65433365]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=1e-1)
diff --git a/tests/pipelines/stable_diffusion_sag/__init__.py b/tests/pipelines/kandinsky5/__init__.py
similarity index 100%
rename from tests/pipelines/stable_diffusion_sag/__init__.py
rename to tests/pipelines/kandinsky5/__init__.py
diff --git a/tests/pipelines/kandinsky5/test_kandinsky5.py b/tests/pipelines/kandinsky5/test_kandinsky5.py
new file mode 100644
index 000000000000..4101e7798dea
--- /dev/null
+++ b/tests/pipelines/kandinsky5/test_kandinsky5.py
@@ -0,0 +1,210 @@
+# Copyright 2025 The Kandinsky Team and The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from transformers import (
+ AutoProcessor,
+ CLIPTextConfig,
+ CLIPTextModel,
+ CLIPTokenizer,
+ Qwen2_5_VLConfig,
+ Qwen2_5_VLForConditionalGeneration,
+)
+
+from diffusers import (
+ AutoencoderKLHunyuanVideo,
+ FlowMatchEulerDiscreteScheduler,
+ Kandinsky5T2VPipeline,
+ Kandinsky5Transformer3DModel,
+)
+
+from ...testing_utils import (
+ enable_full_determinism,
+)
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class Kandinsky5T2VPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = Kandinsky5T2VPipeline
+
+ batch_params = ["prompt", "negative_prompt"]
+
+ params = frozenset(["prompt", "height", "width", "num_frames", "num_inference_steps", "guidance_scale"])
+
+ required_optional_params = {
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ "max_sequence_length",
+ }
+ test_xformers_attention = False
+ supports_optional_components = True
+ supports_dduf = False
+ test_attention_slicing = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLHunyuanVideo(
+ act_fn="silu",
+ block_out_channels=[32, 64],
+ down_block_types=[
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ ],
+ in_channels=3,
+ latent_channels=16,
+ layers_per_block=1,
+ mid_block_add_attention=False,
+ norm_num_groups=32,
+ out_channels=3,
+ scaling_factor=0.476986,
+ spatial_compression_ratio=8,
+ temporal_compression_ratio=4,
+ up_block_types=[
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ ],
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+
+ qwen_hidden_size = 32
+ torch.manual_seed(0)
+ qwen_config = Qwen2_5_VLConfig(
+ text_config={
+ "hidden_size": qwen_hidden_size,
+ "intermediate_size": qwen_hidden_size,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 2,
+ "rope_scaling": {
+ "mrope_section": [2, 2, 4],
+ "rope_type": "default",
+ "type": "default",
+ },
+ "rope_theta": 1000000.0,
+ },
+ vision_config={
+ "depth": 2,
+ "hidden_size": qwen_hidden_size,
+ "intermediate_size": qwen_hidden_size,
+ "num_heads": 2,
+ "out_hidden_size": qwen_hidden_size,
+ },
+ hidden_size=qwen_hidden_size,
+ vocab_size=152064,
+ vision_end_token_id=151653,
+ vision_start_token_id=151652,
+ vision_token_id=151654,
+ )
+ text_encoder = Qwen2_5_VLForConditionalGeneration(qwen_config)
+ tokenizer = AutoProcessor.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
+
+ clip_hidden_size = 16
+ torch.manual_seed(0)
+ clip_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=clip_hidden_size,
+ intermediate_size=16,
+ layer_norm_eps=1e-05,
+ num_attention_heads=2,
+ num_hidden_layers=2,
+ pad_token_id=1,
+ vocab_size=1000,
+ projection_dim=clip_hidden_size,
+ )
+ text_encoder_2 = CLIPTextModel(clip_config)
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ torch.manual_seed(0)
+ transformer = Kandinsky5Transformer3DModel(
+ in_visual_dim=16,
+ in_text_dim=qwen_hidden_size,
+ in_text_dim2=clip_hidden_size,
+ time_dim=16,
+ out_visual_dim=16,
+ patch_size=(1, 2, 2),
+ model_dim=16,
+ ff_dim=32,
+ num_text_blocks=1,
+ num_visual_blocks=2,
+ axes_dims=(1, 1, 2),
+ visual_cond=False,
+ attention_type="regular",
+ )
+
+ return {
+ "vae": vae,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer_2": tokenizer_2,
+ "transformer": transformer,
+ "scheduler": scheduler,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ return {
+ "prompt": "a red square",
+ "height": 32,
+ "width": 32,
+ "num_frames": 5,
+ "num_inference_steps": 2,
+ "guidance_scale": 4.0,
+ "generator": generator,
+ "output_type": "pt",
+ "max_sequence_length": 8,
+ }
+
+ def test_inference(self):
+ device = "cpu"
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ output = pipe(**inputs)
+ video = output.frames[0]
+
+ self.assertEqual(video.shape, (3, 3, 16, 16))
+
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip("Only SDPA or NABLA (flex)")
+ def test_xformers_memory_efficient_attention(self):
+ pass
+
+ @unittest.skip("TODO:Test does not work")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
+ @unittest.skip("TODO: revisit")
+ def test_inference_batch_single_identical(self):
+ pass
diff --git a/tests/pipelines/kandinsky5/test_kandinsky5_i2i.py b/tests/pipelines/kandinsky5/test_kandinsky5_i2i.py
new file mode 100644
index 000000000000..dc832990836f
--- /dev/null
+++ b/tests/pipelines/kandinsky5/test_kandinsky5_i2i.py
@@ -0,0 +1,213 @@
+# Copyright 2025 The Kandinsky Team and The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from PIL import Image
+from transformers import (
+ AutoProcessor,
+ CLIPTextConfig,
+ CLIPTextModel,
+ CLIPTokenizer,
+ Qwen2_5_VLConfig,
+ Qwen2_5_VLForConditionalGeneration,
+)
+
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ Kandinsky5I2IPipeline,
+ Kandinsky5Transformer3DModel,
+)
+from diffusers.utils.testing_utils import enable_full_determinism
+
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class Kandinsky5I2IPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = Kandinsky5I2IPipeline
+
+ batch_params = ["prompt", "negative_prompt"]
+ params = frozenset(["image", "prompt", "height", "width", "num_inference_steps", "guidance_scale"])
+
+ required_optional_params = {
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ "max_sequence_length",
+ }
+ test_xformers_attention = False
+ supports_optional_components = True
+ supports_dduf = False
+ test_attention_slicing = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ act_fn="silu",
+ block_out_channels=[32, 64, 64, 64],
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
+ force_upcast=True,
+ in_channels=3,
+ latent_channels=16,
+ layers_per_block=1,
+ mid_block_add_attention=False,
+ norm_num_groups=32,
+ out_channels=3,
+ sample_size=64,
+ scaling_factor=0.3611,
+ shift_factor=0.1159,
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
+ use_post_quant_conv=False,
+ use_quant_conv=False,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+
+ qwen_hidden_size = 32
+ torch.manual_seed(0)
+ qwen_config = Qwen2_5_VLConfig(
+ text_config={
+ "hidden_size": qwen_hidden_size,
+ "intermediate_size": qwen_hidden_size,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 2,
+ "rope_scaling": {
+ "mrope_section": [2, 2, 4],
+ "rope_type": "default",
+ "type": "default",
+ },
+ "rope_theta": 1000000.0,
+ },
+ vision_config={
+ "depth": 2,
+ "hidden_size": qwen_hidden_size,
+ "intermediate_size": qwen_hidden_size,
+ "num_heads": 2,
+ "out_hidden_size": qwen_hidden_size,
+ },
+ hidden_size=qwen_hidden_size,
+ vocab_size=152064,
+ vision_end_token_id=151653,
+ vision_start_token_id=151652,
+ vision_token_id=151654,
+ )
+ text_encoder = Qwen2_5_VLForConditionalGeneration(qwen_config)
+ tokenizer = AutoProcessor.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
+
+ clip_hidden_size = 16
+ torch.manual_seed(0)
+ clip_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=clip_hidden_size,
+ intermediate_size=16,
+ layer_norm_eps=1e-05,
+ num_attention_heads=2,
+ num_hidden_layers=2,
+ pad_token_id=1,
+ vocab_size=1000,
+ projection_dim=clip_hidden_size,
+ )
+ text_encoder_2 = CLIPTextModel(clip_config)
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ torch.manual_seed(0)
+ transformer = Kandinsky5Transformer3DModel(
+ in_visual_dim=16,
+ in_text_dim=qwen_hidden_size,
+ in_text_dim2=clip_hidden_size,
+ time_dim=16,
+ out_visual_dim=16,
+ patch_size=(1, 2, 2),
+ model_dim=16,
+ ff_dim=32,
+ num_text_blocks=1,
+ num_visual_blocks=2,
+ axes_dims=(1, 1, 2),
+ visual_cond=True,
+ attention_type="regular",
+ )
+
+ return {
+ "vae": vae,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer_2": tokenizer_2,
+ "transformer": transformer,
+ "scheduler": scheduler,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ image = Image.new("RGB", (64, 64), color="red")
+
+ return {
+ "image": image,
+ "prompt": "a red square",
+ "height": 64,
+ "width": 64,
+ "num_inference_steps": 2,
+ "guidance_scale": 4.0,
+ "generator": generator,
+ "output_type": "pt",
+ "max_sequence_length": 8,
+ }
+
+ def test_inference(self):
+ device = "cpu"
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.resolutions = [(64, 64)]
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ output = pipe(**inputs)
+ image = output.image
+
+ self.assertEqual(image.shape, (1, 3, 64, 64))
+
+ @unittest.skip("TODO: Test does not work")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
+ @unittest.skip("TODO: revisit, Batch isnot yet supported in this pipeline")
+ def test_num_images_per_prompt(self):
+ pass
+
+ @unittest.skip("TODO: revisit, Batch isnot yet supported in this pipeline")
+ def test_inference_batch_single_identical(self):
+ pass
+
+ @unittest.skip("TODO: revisit, Batch isnot yet supported in this pipeline")
+ def test_inference_batch_consistent(self):
+ pass
+
+ @unittest.skip("TODO: revisit, not working")
+ def test_float16_inference(self):
+ pass
diff --git a/tests/pipelines/kandinsky5/test_kandinsky5_i2v.py b/tests/pipelines/kandinsky5/test_kandinsky5_i2v.py
new file mode 100644
index 000000000000..483c7b66e07b
--- /dev/null
+++ b/tests/pipelines/kandinsky5/test_kandinsky5_i2v.py
@@ -0,0 +1,211 @@
+# Copyright 2025 The Kandinsky Team and The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from PIL import Image
+from transformers import (
+ AutoProcessor,
+ CLIPTextConfig,
+ CLIPTextModel,
+ CLIPTokenizer,
+ Qwen2_5_VLConfig,
+ Qwen2_5_VLForConditionalGeneration,
+)
+
+from diffusers import (
+ AutoencoderKLHunyuanVideo,
+ FlowMatchEulerDiscreteScheduler,
+ Kandinsky5I2VPipeline,
+ Kandinsky5Transformer3DModel,
+)
+from diffusers.utils.testing_utils import enable_full_determinism
+
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class Kandinsky5I2VPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = Kandinsky5I2VPipeline
+
+ batch_params = ["prompt", "negative_prompt"]
+ params = frozenset(["image", "prompt", "height", "width", "num_frames", "num_inference_steps", "guidance_scale"])
+
+ required_optional_params = {
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ "max_sequence_length",
+ }
+ test_xformers_attention = False
+ supports_optional_components = True
+ supports_dduf = False
+ test_attention_slicing = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLHunyuanVideo(
+ act_fn="silu",
+ block_out_channels=[32, 64, 64],
+ down_block_types=[
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ "HunyuanVideoDownBlock3D",
+ ],
+ in_channels=3,
+ latent_channels=16,
+ layers_per_block=1,
+ mid_block_add_attention=False,
+ norm_num_groups=32,
+ out_channels=3,
+ scaling_factor=0.476986,
+ spatial_compression_ratio=8,
+ temporal_compression_ratio=4,
+ up_block_types=[
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ "HunyuanVideoUpBlock3D",
+ ],
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+
+ qwen_hidden_size = 32
+ torch.manual_seed(0)
+ qwen_config = Qwen2_5_VLConfig(
+ text_config={
+ "hidden_size": qwen_hidden_size,
+ "intermediate_size": qwen_hidden_size,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 2,
+ "rope_scaling": {
+ "mrope_section": [2, 2, 4],
+ "rope_type": "default",
+ "type": "default",
+ },
+ "rope_theta": 1000000.0,
+ },
+ vision_config={
+ "depth": 2,
+ "hidden_size": qwen_hidden_size,
+ "intermediate_size": qwen_hidden_size,
+ "num_heads": 2,
+ "out_hidden_size": qwen_hidden_size,
+ },
+ hidden_size=qwen_hidden_size,
+ vocab_size=152064,
+ vision_end_token_id=151653,
+ vision_start_token_id=151652,
+ vision_token_id=151654,
+ )
+ text_encoder = Qwen2_5_VLForConditionalGeneration(qwen_config)
+ tokenizer = AutoProcessor.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
+
+ clip_hidden_size = 16
+ torch.manual_seed(0)
+ clip_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=clip_hidden_size,
+ intermediate_size=16,
+ layer_norm_eps=1e-05,
+ num_attention_heads=2,
+ num_hidden_layers=2,
+ pad_token_id=1,
+ vocab_size=1000,
+ projection_dim=clip_hidden_size,
+ )
+ text_encoder_2 = CLIPTextModel(clip_config)
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ torch.manual_seed(0)
+ transformer = Kandinsky5Transformer3DModel(
+ in_visual_dim=16,
+ in_text_dim=qwen_hidden_size,
+ in_text_dim2=clip_hidden_size,
+ time_dim=16,
+ out_visual_dim=16,
+ patch_size=(1, 2, 2),
+ model_dim=16,
+ ff_dim=32,
+ num_text_blocks=1,
+ num_visual_blocks=2,
+ axes_dims=(1, 1, 2),
+ visual_cond=True,
+ attention_type="regular",
+ )
+
+ return {
+ "vae": vae,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer_2": tokenizer_2,
+ "transformer": transformer,
+ "scheduler": scheduler,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ image = Image.new("RGB", (32, 32), color="red")
+
+ return {
+ "image": image,
+ "prompt": "a red square",
+ "height": 32,
+ "width": 32,
+ "num_frames": 17,
+ "num_inference_steps": 2,
+ "guidance_scale": 4.0,
+ "generator": generator,
+ "output_type": "pt",
+ "max_sequence_length": 8,
+ }
+
+ def test_inference(self):
+ device = "cpu"
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ output = pipe(**inputs)
+ video = output.frames[0]
+
+ # 17 frames, RGB, 32×32
+ self.assertEqual(video.shape, (17, 3, 32, 32))
+
+ @unittest.skip("TODO:Test does not work")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
+ @unittest.skip("TODO: revisit")
+ def test_callback_inputs(self):
+ pass
+
+ @unittest.skip("TODO: revisit")
+ def test_inference_batch_single_identical(self):
+ pass
diff --git a/tests/pipelines/kandinsky5/test_kandinsky5_t2i.py b/tests/pipelines/kandinsky5/test_kandinsky5_t2i.py
new file mode 100644
index 000000000000..e961103906a2
--- /dev/null
+++ b/tests/pipelines/kandinsky5/test_kandinsky5_t2i.py
@@ -0,0 +1,207 @@
+# Copyright 2025 The Kandinsky Team and The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from transformers import (
+ AutoProcessor,
+ CLIPTextConfig,
+ CLIPTextModel,
+ CLIPTokenizer,
+ Qwen2_5_VLConfig,
+ Qwen2_5_VLForConditionalGeneration,
+)
+
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ Kandinsky5T2IPipeline,
+ Kandinsky5Transformer3DModel,
+)
+from diffusers.utils.testing_utils import enable_full_determinism
+
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class Kandinsky5T2IPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = Kandinsky5T2IPipeline
+
+ batch_params = ["prompt", "negative_prompt"]
+ params = frozenset(["prompt", "height", "width", "num_inference_steps", "guidance_scale"])
+
+ required_optional_params = {
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ "max_sequence_length",
+ }
+ test_xformers_attention = False
+ supports_optional_components = True
+ supports_dduf = False
+ test_attention_slicing = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ act_fn="silu",
+ block_out_channels=[32, 64],
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ force_upcast=True,
+ in_channels=3,
+ latent_channels=16,
+ layers_per_block=1,
+ mid_block_add_attention=False,
+ norm_num_groups=32,
+ out_channels=3,
+ sample_size=128,
+ scaling_factor=0.3611,
+ shift_factor=0.1159,
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ use_post_quant_conv=False,
+ use_quant_conv=False,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+
+ qwen_hidden_size = 32
+ torch.manual_seed(0)
+ qwen_config = Qwen2_5_VLConfig(
+ text_config={
+ "hidden_size": qwen_hidden_size,
+ "intermediate_size": qwen_hidden_size,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 2,
+ "rope_scaling": {
+ "mrope_section": [2, 2, 4],
+ "rope_type": "default",
+ "type": "default",
+ },
+ "rope_theta": 1000000.0,
+ },
+ vision_config={
+ "depth": 2,
+ "hidden_size": qwen_hidden_size,
+ "intermediate_size": qwen_hidden_size,
+ "num_heads": 2,
+ "out_hidden_size": qwen_hidden_size,
+ },
+ hidden_size=qwen_hidden_size,
+ vocab_size=152064,
+ vision_end_token_id=151653,
+ vision_start_token_id=151652,
+ vision_token_id=151654,
+ )
+ text_encoder = Qwen2_5_VLForConditionalGeneration(qwen_config)
+ tokenizer = AutoProcessor.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
+
+ clip_hidden_size = 16
+ torch.manual_seed(0)
+ clip_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=clip_hidden_size,
+ intermediate_size=16,
+ layer_norm_eps=1e-05,
+ num_attention_heads=2,
+ num_hidden_layers=2,
+ pad_token_id=1,
+ vocab_size=1000,
+ projection_dim=clip_hidden_size,
+ )
+ text_encoder_2 = CLIPTextModel(clip_config)
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ torch.manual_seed(0)
+ transformer = Kandinsky5Transformer3DModel(
+ in_visual_dim=16,
+ in_text_dim=qwen_hidden_size,
+ in_text_dim2=clip_hidden_size,
+ time_dim=16,
+ out_visual_dim=16,
+ patch_size=(1, 2, 2),
+ model_dim=16,
+ ff_dim=32,
+ num_text_blocks=1,
+ num_visual_blocks=2,
+ axes_dims=(1, 1, 2),
+ visual_cond=False,
+ attention_type="regular",
+ )
+
+ return {
+ "vae": vae,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer_2": tokenizer_2,
+ "transformer": transformer,
+ "scheduler": scheduler,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ return {
+ "prompt": "a red square",
+ "height": 64,
+ "width": 64,
+ "num_inference_steps": 2,
+ "guidance_scale": 4.0,
+ "generator": generator,
+ "output_type": "pt",
+ "max_sequence_length": 8,
+ }
+
+ def test_inference(self):
+ device = "cpu"
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.resolutions = [(64, 64)]
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ output = pipe(**inputs)
+ image = output.image
+
+ self.assertEqual(image.shape, (1, 3, 16, 16))
+
+ def test_inference_batch_single_identical(self):
+ super().test_inference_batch_single_identical(expected_max_diff=5e-3)
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip("Only SDPA or NABLA (flex)")
+ def test_xformers_memory_efficient_attention(self):
+ pass
+
+ @unittest.skip("All encoders are needed")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
+ @unittest.skip("Meant for eiter FP32 or BF16 inference")
+ def test_float16_inference(self):
+ pass
diff --git a/tests/pipelines/kolors/test_kolors.py b/tests/pipelines/kolors/test_kolors.py
index 218de2897e66..f1d4982d4d74 100644
--- a/tests/pipelines/kolors/test_kolors.py
+++ b/tests/pipelines/kolors/test_kolors.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,8 +25,8 @@
UNet2DConditionModel,
)
from diffusers.pipelines.kolors import ChatGLMModel, ChatGLMTokenizer
-from diffusers.utils.testing_utils import enable_full_determinism
+from ...testing_utils import enable_full_determinism
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
@@ -145,4 +145,4 @@ def test_save_load_float16(self):
super().test_save_load_float16(expected_max_diff=2e-1)
def test_inference_batch_single_identical(self):
- self._test_inference_batch_single_identical(expected_max_diff=5e-4)
+ self._test_inference_batch_single_identical(expected_max_diff=5e-3)
diff --git a/tests/pipelines/kolors/test_kolors_img2img.py b/tests/pipelines/kolors/test_kolors_img2img.py
index 89da95753a14..5a5d31a46456 100644
--- a/tests/pipelines/kolors/test_kolors_img2img.py
+++ b/tests/pipelines/kolors/test_kolors_img2img.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -26,11 +26,11 @@
UNet2DConditionModel,
)
from diffusers.pipelines.kolors import ChatGLMModel, ChatGLMTokenizer
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
)
-
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
@@ -155,6 +155,6 @@ def test_inference_batch_single_identical(self):
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=7e-2)
- @unittest.skip("Test not supported because kolors img2img doesn't take pooled embeds as inputs unline kolors t2i.")
+ @unittest.skip("Test not supported because kolors img2img doesn't take pooled embeds as inputs unlike kolors t2i.")
def test_encode_prompt_works_in_isolation(self):
pass
diff --git a/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py b/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py
index 570fa8fadf39..c7666244b35f 100644
--- a/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py
+++ b/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py
@@ -12,14 +12,14 @@
LCMScheduler,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
diff --git a/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py b/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py
index 88e31a97aac5..d8e7745b7805 100644
--- a/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py
+++ b/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py
@@ -13,7 +13,8 @@
LCMScheduler,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -22,7 +23,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
diff --git a/tests/pipelines/latent_diffusion/test_latent_diffusion.py b/tests/pipelines/latent_diffusion/test_latent_diffusion.py
index e751240e43b0..21c5bcf5a5b9 100644
--- a/tests/pipelines/latent_diffusion/test_latent_diffusion.py
+++ b/tests/pipelines/latent_diffusion/test_latent_diffusion.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,14 +21,15 @@
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler, LDMTextToImagePipeline, UNet2DConditionModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
load_numpy,
nightly,
- require_torch_gpu,
+ require_torch_accelerator,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@@ -136,17 +137,17 @@ def test_inference_text2img(self):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
class LDMTextToImagePipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, dtype=torch.float32, seed=0):
generator = torch.manual_seed(seed)
@@ -177,17 +178,17 @@ def test_ldm_default_ddim(self):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
class LDMTextToImagePipelineNightlyTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, dtype=torch.float32, seed=0):
generator = torch.manual_seed(seed)
diff --git a/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py b/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py
index 38ac6a46ccca..b2cbdb9f5b45 100644
--- a/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py
+++ b/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,7 +21,8 @@
from diffusers import DDIMScheduler, LDMSuperResolutionPipeline, UNet2DModel, VQModel
from diffusers.utils import PIL_INTERPOLATION
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
load_image,
diff --git a/tests/pipelines/latte/test_latte.py b/tests/pipelines/latte/test_latte.py
index 80d370647f57..a40d4bf8eede 100644
--- a/tests/pipelines/latte/test_latte.py
+++ b/tests/pipelines/latte/test_latte.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 Latte Team and HuggingFace Inc.
+# Copyright 2025 Latte Team and HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,7 +31,8 @@
PyramidAttentionBroadcastConfig,
)
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -39,7 +40,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
FasterCacheTesterMixin,
diff --git a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py
index 342561d4f5e9..6db20a464f19 100644
--- a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py
+++ b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py
@@ -28,7 +28,9 @@
LEditsPPPipelineStableDiffusion,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ Expectations,
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -244,7 +246,35 @@ def test_ledits_pp_editing(self):
output_slice = reconstruction[150:153, 140:143, -1]
output_slice = output_slice.flatten()
- expected_slice = np.array(
- [0.9453125, 0.93310547, 0.84521484, 0.94628906, 0.9111328, 0.80859375, 0.93847656, 0.9042969, 0.8144531]
+ expected_slices = Expectations(
+ {
+ ("xpu", 3): np.array(
+ [
+ 0.9511719,
+ 0.94140625,
+ 0.87597656,
+ 0.9472656,
+ 0.9296875,
+ 0.8378906,
+ 0.94433594,
+ 0.91503906,
+ 0.8491211,
+ ]
+ ),
+ ("cuda", 7): np.array(
+ [
+ 0.9453125,
+ 0.93310547,
+ 0.84521484,
+ 0.94628906,
+ 0.9111328,
+ 0.80859375,
+ 0.93847656,
+ 0.9042969,
+ 0.8144531,
+ ]
+ ),
+ }
)
+ expected_slice = expected_slices.get_expectation()
assert np.abs(output_slice - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion_xl.py b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion_xl.py
index 75795a33422b..06c1ceb0cf5a 100644
--- a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion_xl.py
+++ b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion_xl.py
@@ -37,7 +37,7 @@
)
# from diffusers.image_processor import VaeImageProcessor
-from diffusers.utils.testing_utils import (
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
load_image,
diff --git a/tests/pipelines/ltx/test_ltx.py b/tests/pipelines/ltx/test_ltx.py
index 4f72729fc9ce..aaf4161b51fb 100644
--- a/tests/pipelines/ltx/test_ltx.py
+++ b/tests/pipelines/ltx/test_ltx.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
+# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,16 +20,16 @@
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, to_np
+from ..test_pipelines_common import FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
enable_full_determinism()
-class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class LTXPipelineFastTests(PipelineTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase):
pipeline_class = LTXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -49,7 +49,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
test_layerwise_casting = True
test_group_offloading = True
- def get_dummy_components(self):
+ def get_dummy_components(self, num_layers: int = 1):
torch.manual_seed(0)
transformer = LTXVideoTransformer3DModel(
in_channels=8,
@@ -59,7 +59,7 @@ def get_dummy_components(self):
num_attention_heads=4,
attention_head_dim=8,
cross_attention_dim=32,
- num_layers=1,
+ num_layers=num_layers,
caption_channels=32,
)
diff --git a/tests/pipelines/ltx/test_ltx_condition.py b/tests/pipelines/ltx/test_ltx_condition.py
index dbb9a740b433..f5dfb0186209 100644
--- a/tests/pipelines/ltx/test_ltx_condition.py
+++ b/tests/pipelines/ltx/test_ltx_condition.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
+# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -26,8 +26,8 @@
LTXVideoTransformer3DModel,
)
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
diff --git a/tests/pipelines/ltx/test_ltx_image2video.py b/tests/pipelines/ltx/test_ltx_image2video.py
index 1c3e018a8a4b..2702993d4a59 100644
--- a/tests/pipelines/ltx/test_ltx_image2video.py
+++ b/tests/pipelines/ltx/test_ltx_image2video.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
+# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,8 +25,8 @@
LTXImageToVideoPipeline,
LTXVideoTransformer3DModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -109,7 +109,7 @@ def get_dummy_inputs(self, device, seed=0):
else:
generator = torch.Generator(device=device).manual_seed(seed)
- image = torch.randn((1, 3, 32, 32), generator=generator, device=device)
+ image = torch.rand((1, 3, 32, 32), generator=generator, device=device)
inputs = {
"image": image,
@@ -142,7 +142,7 @@ def test_inference(self):
self.assertEqual(generated_video.shape, (9, 3, 32, 32))
expected_video = torch.randn(9, 3, 32, 32)
- max_diff = np.abs(generated_video - expected_video).max()
+ max_diff = torch.amax(torch.abs(generated_video - expected_video))
self.assertLessEqual(max_diff, 1e10)
def test_callback_inputs(self):
diff --git a/tests/pipelines/ltx/test_ltx_latent_upsample.py b/tests/pipelines/ltx/test_ltx_latent_upsample.py
new file mode 100644
index 000000000000..0044a85c644b
--- /dev/null
+++ b/tests/pipelines/ltx/test_ltx_latent_upsample.py
@@ -0,0 +1,159 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+
+from diffusers import AutoencoderKLLTXVideo, LTXLatentUpsamplePipeline
+from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel
+
+from ...testing_utils import enable_full_determinism
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class LTXLatentUpsamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = LTXLatentUpsamplePipeline
+ params = {"video", "generator"}
+ batch_params = {"video", "generator"}
+ required_optional_params = frozenset(["generator", "latents", "return_dict"])
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLLTXVideo(
+ in_channels=3,
+ out_channels=3,
+ latent_channels=8,
+ block_out_channels=(8, 8, 8, 8),
+ decoder_block_out_channels=(8, 8, 8, 8),
+ layers_per_block=(1, 1, 1, 1, 1),
+ decoder_layers_per_block=(1, 1, 1, 1, 1),
+ spatio_temporal_scaling=(True, True, False, False),
+ decoder_spatio_temporal_scaling=(True, True, False, False),
+ decoder_inject_noise=(False, False, False, False, False),
+ upsample_residual=(False, False, False, False),
+ upsample_factor=(1, 1, 1, 1),
+ timestep_conditioning=False,
+ patch_size=1,
+ patch_size_t=1,
+ encoder_causal=True,
+ decoder_causal=False,
+ )
+ vae.use_framewise_encoding = False
+ vae.use_framewise_decoding = False
+
+ torch.manual_seed(0)
+ latent_upsampler = LTXLatentUpsamplerModel(
+ in_channels=8,
+ mid_channels=32,
+ num_blocks_per_stage=1,
+ dims=3,
+ spatial_upsample=True,
+ temporal_upsample=False,
+ )
+
+ components = {
+ "vae": vae,
+ "latent_upsampler": latent_upsampler,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ video = torch.randn((5, 3, 32, 32), generator=generator, device=device)
+
+ inputs = {
+ "video": video,
+ "generator": generator,
+ "height": 16,
+ "width": 16,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (5, 3, 32, 32))
+ expected_video = torch.randn(5, 3, 32, 32)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.25):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ @unittest.skip("Test is not applicable.")
+ def test_callback_inputs(self):
+ pass
+
+ @unittest.skip("Test is not applicable.")
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ pass
+
+ @unittest.skip("Test is not applicable.")
+ def test_inference_batch_consistent(self):
+ pass
+
+ @unittest.skip("Test is not applicable.")
+ def test_inference_batch_single_identical(self):
+ pass
diff --git a/tests/pipelines/lumina/test_lumina_nextdit.py b/tests/pipelines/lumina/test_lumina_nextdit.py
index 0c1fe8eb2fcd..d2c114825d34 100644
--- a/tests/pipelines/lumina/test_lumina_nextdit.py
+++ b/tests/pipelines/lumina/test_lumina_nextdit.py
@@ -10,16 +10,15 @@
FlowMatchEulerDiscreteScheduler,
LuminaNextDiT2DModel,
LuminaPipeline,
- LuminaText2ImgPipeline,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
numpy_cosine_similarity_distance,
require_torch_accelerator,
slow,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
@@ -105,12 +104,6 @@ def get_dummy_inputs(self, device, seed=0):
def test_xformers_attention_forwardGenerator_pass(self):
pass
- def test_deprecation_raises_warning(self):
- with self.assertWarns(FutureWarning) as warning:
- _ = LuminaText2ImgPipeline(**self.get_dummy_components()).to(torch_device)
- warning_message = str(warning.warnings[0].message)
- assert "renamed to `LuminaPipeline`" in warning_message
-
@slow
@require_torch_accelerator
diff --git a/tests/pipelines/lumina2/test_pipeline_lumina2.py b/tests/pipelines/lumina2/test_pipeline_lumina2.py
index 33fc870bcd34..d6d21b72a4ce 100644
--- a/tests/pipelines/lumina2/test_pipeline_lumina2.py
+++ b/tests/pipelines/lumina2/test_pipeline_lumina2.py
@@ -7,10 +7,8 @@
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
Lumina2Pipeline,
- Lumina2Text2ImgPipeline,
Lumina2Transformer2DModel,
)
-from diffusers.utils.testing_utils import torch_device
from ..test_pipelines_common import PipelineTesterMixin
@@ -117,9 +115,3 @@ def get_dummy_inputs(self, device, seed=0):
"output_type": "np",
}
return inputs
-
- def test_deprecation_raises_warning(self):
- with self.assertWarns(FutureWarning) as warning:
- _ = Lumina2Text2ImgPipeline(**self.get_dummy_components()).to(torch_device)
- warning_message = str(warning.warnings[0].message)
- assert "renamed to `Lumina2Pipeline`" in warning_message
diff --git a/tests/pipelines/marigold/test_marigold_depth.py b/tests/pipelines/marigold/test_marigold_depth.py
index 13f9a421861b..3c853059921b 100644
--- a/tests/pipelines/marigold/test_marigold_depth.py
+++ b/tests/pipelines/marigold/test_marigold_depth.py
@@ -31,7 +31,9 @@
MarigoldDepthPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ Expectations,
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -41,7 +43,6 @@
slow,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
@@ -356,7 +357,7 @@ def test_marigold_depth_einstein_f32_cpu_G0_S1_P32_E1_B1_M1(self):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f32_cuda_G0_S1_P768_E1_B1_M1(self):
+ def test_marigold_depth_einstein_f32_accelerator_G0_S1_P768_E1_B1_M1(self):
self._test_marigold_depth(
is_fp16=False,
device=torch_device,
@@ -369,7 +370,7 @@ def test_marigold_depth_einstein_f32_cuda_G0_S1_P768_E1_B1_M1(self):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E1_B1_M1(self):
+ def test_marigold_depth_einstein_f16_accelerator_G0_S1_P768_E1_B1_M1(self):
self._test_marigold_depth(
is_fp16=True,
device=torch_device,
@@ -382,7 +383,7 @@ def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E1_B1_M1(self):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f16_cuda_G2024_S1_P768_E1_B1_M1(self):
+ def test_marigold_depth_einstein_f16_accelerator_G2024_S1_P768_E1_B1_M1(self):
self._test_marigold_depth(
is_fp16=True,
device=torch_device,
@@ -395,12 +396,23 @@ def test_marigold_depth_einstein_f16_cuda_G2024_S1_P768_E1_B1_M1(self):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f16_cuda_G0_S2_P768_E1_B1_M1(self):
+ def test_marigold_depth_einstein_f16_accelerator_G0_S2_P768_E1_B1_M1(self):
+ # fmt: off
+ expected_slices = Expectations(
+ {
+ ("cuda", 7): np.array([0.1085, 0.1098, 0.1110, 0.1081, 0.1085, 0.1082, 0.1085, 0.1057, 0.0996]),
+ ("xpu", 3): np.array([0.1084, 0.1096, 0.1108, 0.1080, 0.1083, 0.1080,
+ 0.1085, 0.1057, 0.0996]),
+ }
+ )
+ expected_slice = expected_slices.get_expectation()
+ # fmt: on
+
self._test_marigold_depth(
is_fp16=True,
device=torch_device,
generator_seed=0,
- expected_slice=np.array([0.1085, 0.1098, 0.1110, 0.1081, 0.1085, 0.1082, 0.1085, 0.1057, 0.0996]),
+ expected_slice=expected_slice,
num_inference_steps=2,
processing_resolution=768,
ensemble_size=1,
@@ -408,7 +420,7 @@ def test_marigold_depth_einstein_f16_cuda_G0_S2_P768_E1_B1_M1(self):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f16_cuda_G0_S1_P512_E1_B1_M1(self):
+ def test_marigold_depth_einstein_f16_accelerator_G0_S1_P512_E1_B1_M1(self):
self._test_marigold_depth(
is_fp16=True,
device=torch_device,
@@ -421,7 +433,7 @@ def test_marigold_depth_einstein_f16_cuda_G0_S1_P512_E1_B1_M1(self):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E3_B1_M1(self):
+ def test_marigold_depth_einstein_f16_accelerator_G0_S1_P768_E3_B1_M1(self):
self._test_marigold_depth(
is_fp16=True,
device=torch_device,
@@ -435,7 +447,7 @@ def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E3_B1_M1(self):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E4_B2_M1(self):
+ def test_marigold_depth_einstein_f16_accelerator_G0_S1_P768_E4_B2_M1(self):
self._test_marigold_depth(
is_fp16=True,
device=torch_device,
@@ -449,7 +461,7 @@ def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E4_B2_M1(self):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f16_cuda_G0_S1_P512_E1_B1_M0(self):
+ def test_marigold_depth_einstein_f16_accelerator_G0_S1_P512_E1_B1_M0(self):
self._test_marigold_depth(
is_fp16=True,
device=torch_device,
diff --git a/tests/pipelines/marigold/test_marigold_intrinsics.py b/tests/pipelines/marigold/test_marigold_intrinsics.py
index b24e686a4dfe..7db14b67cec9 100644
--- a/tests/pipelines/marigold/test_marigold_intrinsics.py
+++ b/tests/pipelines/marigold/test_marigold_intrinsics.py
@@ -32,15 +32,17 @@
MarigoldIntrinsicsPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ Expectations,
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -395,17 +397,17 @@ def test_marigold_depth_dummy_no_processing_resolution(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def _test_marigold_intrinsics(
self,
@@ -415,7 +417,7 @@ def _test_marigold_intrinsics(
expected_slice: np.ndarray = None,
model_id: str = "prs-eth/marigold-iid-appearance-v1-1",
image_url: str = "https://marigoldmonodepth.github.io/images/einstein.jpg",
- atol: float = 1e-4,
+ atol: float = 1e-3,
**pipe_kwargs,
):
from_pretrained_kwargs = {}
@@ -424,7 +426,7 @@ def _test_marigold_intrinsics(
from_pretrained_kwargs["torch_dtype"] = torch.float16
pipe = MarigoldIntrinsicsPipeline.from_pretrained(model_id, **from_pretrained_kwargs)
- if device == "cuda":
+ if device in ["cuda", "xpu"]:
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
@@ -464,10 +466,10 @@ def test_marigold_intrinsics_einstein_f32_cpu_G0_S1_P32_E1_B1_M1(self):
match_input_resolution=True,
)
- def test_marigold_intrinsics_einstein_f32_cuda_G0_S1_P768_E1_B1_M1(self):
+ def test_marigold_intrinsics_einstein_f32_accelerator_G0_S1_P768_E1_B1_M1(self):
self._test_marigold_intrinsics(
is_fp16=False,
- device="cuda",
+ device=torch_device,
generator_seed=0,
expected_slice=np.array([0.62127, 0.61906, 0.61687, 0.61946, 0.61903, 0.61961, 0.61808, 0.62099, 0.62894]),
num_inference_steps=1,
@@ -477,10 +479,10 @@ def test_marigold_intrinsics_einstein_f32_cuda_G0_S1_P768_E1_B1_M1(self):
match_input_resolution=True,
)
- def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P768_E1_B1_M1(self):
+ def test_marigold_intrinsics_einstein_f16_accelerator_G0_S1_P768_E1_B1_M1(self):
self._test_marigold_intrinsics(
is_fp16=True,
- device="cuda",
+ device=torch_device,
generator_seed=0,
expected_slice=np.array([0.62109, 0.61914, 0.61719, 0.61963, 0.61914, 0.61963, 0.61816, 0.62109, 0.62891]),
num_inference_steps=1,
@@ -490,10 +492,10 @@ def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P768_E1_B1_M1(self):
match_input_resolution=True,
)
- def test_marigold_intrinsics_einstein_f16_cuda_G2024_S1_P768_E1_B1_M1(self):
+ def test_marigold_intrinsics_einstein_f16_accelerator_G2024_S1_P768_E1_B1_M1(self):
self._test_marigold_intrinsics(
is_fp16=True,
- device="cuda",
+ device=torch_device,
generator_seed=2024,
expected_slice=np.array([0.64111, 0.63916, 0.63623, 0.63965, 0.63916, 0.63965, 0.6377, 0.64062, 0.64941]),
num_inference_steps=1,
@@ -503,10 +505,10 @@ def test_marigold_intrinsics_einstein_f16_cuda_G2024_S1_P768_E1_B1_M1(self):
match_input_resolution=True,
)
- def test_marigold_intrinsics_einstein_f16_cuda_G0_S2_P768_E1_B1_M1(self):
+ def test_marigold_intrinsics_einstein_f16_accelerator_G0_S2_P768_E1_B1_M1(self):
self._test_marigold_intrinsics(
is_fp16=True,
- device="cuda",
+ device=torch_device,
generator_seed=0,
expected_slice=np.array([0.60254, 0.60059, 0.59961, 0.60156, 0.60107, 0.60205, 0.60254, 0.60449, 0.61133]),
num_inference_steps=2,
@@ -516,10 +518,10 @@ def test_marigold_intrinsics_einstein_f16_cuda_G0_S2_P768_E1_B1_M1(self):
match_input_resolution=True,
)
- def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P512_E1_B1_M1(self):
+ def test_marigold_intrinsics_einstein_f16_accelerator_G0_S1_P512_E1_B1_M1(self):
self._test_marigold_intrinsics(
is_fp16=True,
- device="cuda",
+ device=torch_device,
generator_seed=0,
expected_slice=np.array([0.64551, 0.64453, 0.64404, 0.64502, 0.64844, 0.65039, 0.64502, 0.65039, 0.65332]),
num_inference_steps=1,
@@ -529,12 +531,42 @@ def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P512_E1_B1_M1(self):
match_input_resolution=True,
)
- def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P768_E3_B1_M1(self):
+ def test_marigold_intrinsics_einstein_f16_accelerator_G0_S1_P768_E3_B1_M1(self):
+ expected_slices = Expectations(
+ {
+ ("xpu", 3): np.array(
+ [
+ 0.62655,
+ 0.62477,
+ 0.62161,
+ 0.62452,
+ 0.62454,
+ 0.62454,
+ 0.62255,
+ 0.62647,
+ 0.63379,
+ ]
+ ),
+ ("cuda", 7): np.array(
+ [
+ 0.61572,
+ 0.1377,
+ 0.61182,
+ 0.61426,
+ 0.61377,
+ 0.61426,
+ 0.61279,
+ 0.61572,
+ 0.62354,
+ ]
+ ),
+ }
+ )
self._test_marigold_intrinsics(
is_fp16=True,
- device="cuda",
+ device=torch_device,
generator_seed=0,
- expected_slice=np.array([0.61572, 0.61377, 0.61182, 0.61426, 0.61377, 0.61426, 0.61279, 0.61572, 0.62354]),
+ expected_slice=expected_slices.get_expectation(),
num_inference_steps=1,
processing_resolution=768,
ensemble_size=3,
@@ -543,12 +575,42 @@ def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P768_E3_B1_M1(self):
match_input_resolution=True,
)
- def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P768_E4_B2_M1(self):
+ def test_marigold_intrinsics_einstein_f16_accelerator_G0_S1_P768_E4_B2_M1(self):
+ expected_slices = Expectations(
+ {
+ ("xpu", 3): np.array(
+ [
+ 0.62988,
+ 0.62792,
+ 0.62548,
+ 0.62841,
+ 0.62792,
+ 0.62792,
+ 0.62646,
+ 0.62939,
+ 0.63721,
+ ]
+ ),
+ ("cuda", 7): np.array(
+ [
+ 0.61914,
+ 0.6167,
+ 0.61475,
+ 0.61719,
+ 0.61719,
+ 0.61768,
+ 0.61572,
+ 0.61914,
+ 0.62695,
+ ]
+ ),
+ }
+ )
self._test_marigold_intrinsics(
is_fp16=True,
- device="cuda",
+ device=torch_device,
generator_seed=0,
- expected_slice=np.array([0.61914, 0.6167, 0.61475, 0.61719, 0.61719, 0.61768, 0.61572, 0.61914, 0.62695]),
+ expected_slice=expected_slices.get_expectation(),
num_inference_steps=1,
processing_resolution=768,
ensemble_size=4,
@@ -557,10 +619,10 @@ def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P768_E4_B2_M1(self):
match_input_resolution=True,
)
- def test_marigold_intrinsics_einstein_f16_cuda_G0_S1_P512_E1_B1_M0(self):
+ def test_marigold_intrinsics_einstein_f16_accelerator_G0_S1_P512_E1_B1_M0(self):
self._test_marigold_intrinsics(
is_fp16=True,
- device="cuda",
+ device=torch_device,
generator_seed=0,
expected_slice=np.array([0.65332, 0.64697, 0.64648, 0.64844, 0.64697, 0.64111, 0.64941, 0.64209, 0.65332]),
num_inference_steps=1,
diff --git a/tests/pipelines/marigold/test_marigold_normals.py b/tests/pipelines/marigold/test_marigold_normals.py
index 1797f99b213b..108163bf22ec 100644
--- a/tests/pipelines/marigold/test_marigold_normals.py
+++ b/tests/pipelines/marigold/test_marigold_normals.py
@@ -31,7 +31,8 @@
MarigoldNormalsPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -40,7 +41,6 @@
slow,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/mochi/test_mochi.py b/tests/pipelines/mochi/test_mochi.py
index ea2d015af52a..5615720a9343 100644
--- a/tests/pipelines/mochi/test_mochi.py
+++ b/tests/pipelines/mochi/test_mochi.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
+# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,29 +17,30 @@
import unittest
import numpy as np
-import pytest
import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
nightly,
numpy_cosine_similarity_distance,
- require_big_gpu_with_torch_cuda,
- require_torch_gpu,
+ require_big_accelerator,
+ require_torch_accelerator,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np
+from ..test_pipelines_common import FasterCacheTesterMixin, FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
enable_full_determinism()
-class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase):
+class MochiPipelineFastTests(
+ PipelineTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase
+):
pipeline_class = MochiPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -266,9 +267,8 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
@nightly
-@require_torch_gpu
-@require_big_gpu_with_torch_cuda
-@pytest.mark.big_gpu_with_torch_cuda
+@require_torch_accelerator
+@require_big_accelerator
class MochiPipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."
@@ -302,5 +302,5 @@ def test_mochi(self):
video = videos[0]
expected_video = torch.randn(1, 19, 480, 848, 3).numpy()
- max_diff = numpy_cosine_similarity_distance(video, expected_video)
+ max_diff = numpy_cosine_similarity_distance(video.cpu(), expected_video)
assert max_diff < 1e-3, f"Max diff is too high. got {video}"
diff --git a/tests/pipelines/musicldm/test_musicldm.py b/tests/pipelines/musicldm/test_musicldm.py
deleted file mode 100644
index bdd536b6ff86..000000000000
--- a/tests/pipelines/musicldm/test_musicldm.py
+++ /dev/null
@@ -1,472 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import (
- ClapAudioConfig,
- ClapConfig,
- ClapFeatureExtractor,
- ClapModel,
- ClapTextConfig,
- RobertaTokenizer,
- SpeechT5HifiGan,
- SpeechT5HifiGanConfig,
-)
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- LMSDiscreteScheduler,
- MusicLDMPipeline,
- PNDMScheduler,
- UNet2DConditionModel,
-)
-from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device
-
-from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class MusicLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = MusicLDMPipeline
- params = TEXT_TO_AUDIO_PARAMS
- batch_params = TEXT_TO_AUDIO_BATCH_PARAMS
- required_optional_params = frozenset(
- [
- "num_inference_steps",
- "num_waveforms_per_prompt",
- "generator",
- "latents",
- "output_type",
- "return_dict",
- "callback",
- "callback_steps",
- ]
- )
-
- supports_dduf = False
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=(32, 64),
- class_embed_type="simple_projection",
- projection_class_embeddings_input_dim=32,
- class_embeddings_concat=True,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=1,
- out_channels=1,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- torch.manual_seed(0)
- text_branch_config = ClapTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=16,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=2,
- num_hidden_layers=2,
- pad_token_id=1,
- vocab_size=1000,
- )
- audio_branch_config = ClapAudioConfig(
- spec_size=64,
- window_size=4,
- num_mel_bins=64,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- depths=[2, 2],
- num_attention_heads=[2, 2],
- num_hidden_layers=2,
- hidden_size=192,
- patch_size=2,
- patch_stride=2,
- patch_embed_input_channels=4,
- )
- text_encoder_config = ClapConfig.from_text_audio_configs(
- text_config=text_branch_config, audio_config=audio_branch_config, projection_dim=32
- )
- text_encoder = ClapModel(text_encoder_config)
- tokenizer = RobertaTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta", model_max_length=77)
- feature_extractor = ClapFeatureExtractor.from_pretrained(
- "hf-internal-testing/tiny-random-ClapModel", hop_length=7900
- )
-
- torch.manual_seed(0)
- vocoder_config = SpeechT5HifiGanConfig(
- model_in_dim=8,
- sampling_rate=16000,
- upsample_initial_channel=16,
- upsample_rates=[2, 2],
- upsample_kernel_sizes=[4, 4],
- resblock_kernel_sizes=[3, 7],
- resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5]],
- normalize_before=False,
- )
-
- vocoder = SpeechT5HifiGan(vocoder_config)
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "feature_extractor": feature_extractor,
- "vocoder": vocoder,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A hammer hitting a wooden surface",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- }
- return inputs
-
- def test_musicldm_ddim(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components()
- musicldm_pipe = MusicLDMPipeline(**components)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = musicldm_pipe(**inputs)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 256
-
- audio_slice = audio[:10]
- expected_slice = np.array(
- [-0.0027, -0.0036, -0.0037, -0.0020, -0.0035, -0.0019, -0.0037, -0.0020, -0.0038, -0.0019]
- )
-
- assert np.abs(audio_slice - expected_slice).max() < 1e-4
-
- def test_musicldm_prompt_embeds(self):
- components = self.get_dummy_components()
- musicldm_pipe = MusicLDMPipeline(**components)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt"] = 3 * [inputs["prompt"]]
-
- # forward
- output = musicldm_pipe(**inputs)
- audio_1 = output.audios[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = 3 * [inputs.pop("prompt")]
-
- text_inputs = musicldm_pipe.tokenizer(
- prompt,
- padding="max_length",
- max_length=musicldm_pipe.tokenizer.model_max_length,
- truncation=True,
- return_tensors="pt",
- )
- text_inputs = text_inputs["input_ids"].to(torch_device)
-
- prompt_embeds = musicldm_pipe.text_encoder.get_text_features(text_inputs)
-
- inputs["prompt_embeds"] = prompt_embeds
-
- # forward
- output = musicldm_pipe(**inputs)
- audio_2 = output.audios[0]
-
- assert np.abs(audio_1 - audio_2).max() < 1e-2
-
- def test_musicldm_negative_prompt_embeds(self):
- components = self.get_dummy_components()
- musicldm_pipe = MusicLDMPipeline(**components)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
- negative_prompt = 3 * ["this is a negative prompt"]
- inputs["negative_prompt"] = negative_prompt
- inputs["prompt"] = 3 * [inputs["prompt"]]
-
- # forward
- output = musicldm_pipe(**inputs)
- audio_1 = output.audios[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = 3 * [inputs.pop("prompt")]
-
- embeds = []
- for p in [prompt, negative_prompt]:
- text_inputs = musicldm_pipe.tokenizer(
- p,
- padding="max_length",
- max_length=musicldm_pipe.tokenizer.model_max_length,
- truncation=True,
- return_tensors="pt",
- )
- text_inputs = text_inputs["input_ids"].to(torch_device)
-
- text_embeds = musicldm_pipe.text_encoder.get_text_features(
- text_inputs,
- )
- embeds.append(text_embeds)
-
- inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = embeds
-
- # forward
- output = musicldm_pipe(**inputs)
- audio_2 = output.audios[0]
-
- assert np.abs(audio_1 - audio_2).max() < 1e-2
-
- def test_musicldm_negative_prompt(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
- musicldm_pipe = MusicLDMPipeline(**components)
- musicldm_pipe = musicldm_pipe.to(device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- negative_prompt = "egg cracking"
- output = musicldm_pipe(**inputs, negative_prompt=negative_prompt)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 256
-
- audio_slice = audio[:10]
- expected_slice = np.array(
- [-0.0027, -0.0036, -0.0037, -0.0019, -0.0035, -0.0018, -0.0037, -0.0021, -0.0038, -0.0018]
- )
-
- assert np.abs(audio_slice - expected_slice).max() < 1e-4
-
- def test_musicldm_num_waveforms_per_prompt(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
- musicldm_pipe = MusicLDMPipeline(**components)
- musicldm_pipe = musicldm_pipe.to(device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A hammer hitting a wooden surface"
-
- # test num_waveforms_per_prompt=1 (default)
- audios = musicldm_pipe(prompt, num_inference_steps=2).audios
-
- assert audios.shape == (1, 256)
-
- # test num_waveforms_per_prompt=1 (default) for batch of prompts
- batch_size = 2
- audios = musicldm_pipe([prompt] * batch_size, num_inference_steps=2).audios
-
- assert audios.shape == (batch_size, 256)
-
- # test num_waveforms_per_prompt for single prompt
- num_waveforms_per_prompt = 2
- audios = musicldm_pipe(prompt, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt).audios
-
- assert audios.shape == (num_waveforms_per_prompt, 256)
-
- # test num_waveforms_per_prompt for batch of prompts
- batch_size = 2
- audios = musicldm_pipe(
- [prompt] * batch_size, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt
- ).audios
-
- assert audios.shape == (batch_size * num_waveforms_per_prompt, 256)
-
- def test_musicldm_audio_length_in_s(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- musicldm_pipe = MusicLDMPipeline(**components)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe.set_progress_bar_config(disable=None)
- vocoder_sampling_rate = musicldm_pipe.vocoder.config.sampling_rate
-
- inputs = self.get_dummy_inputs(device)
- output = musicldm_pipe(audio_length_in_s=0.016, **inputs)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) / vocoder_sampling_rate == 0.016
-
- output = musicldm_pipe(audio_length_in_s=0.032, **inputs)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) / vocoder_sampling_rate == 0.032
-
- def test_musicldm_vocoder_model_in_dim(self):
- components = self.get_dummy_components()
- musicldm_pipe = MusicLDMPipeline(**components)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- prompt = ["hey"]
-
- output = musicldm_pipe(prompt, num_inference_steps=1)
- audio_shape = output.audios.shape
- assert audio_shape == (1, 256)
-
- config = musicldm_pipe.vocoder.config
- config.model_in_dim *= 2
- musicldm_pipe.vocoder = SpeechT5HifiGan(config).to(torch_device)
- output = musicldm_pipe(prompt, num_inference_steps=1)
- audio_shape = output.audios.shape
- # waveform shape is unchanged, we just have 2x the number of mel channels in the spectrogram
- assert audio_shape == (1, 256)
-
- def test_attention_slicing_forward_pass(self):
- self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)
-
- def test_inference_batch_single_identical(self):
- self._test_inference_batch_single_identical()
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)
-
- def test_to_dtype(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
-
- # The method component.dtype returns the dtype of the first parameter registered in the model, not the
- # dtype of the entire model. In the case of CLAP, the first parameter is a float64 constant (logit scale)
- model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
-
- # Without the logit scale parameters, everything is float32
- model_dtypes.pop("text_encoder")
- self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))
-
- # the CLAP sub-models are float32
- model_dtypes["clap_text_branch"] = components["text_encoder"].text_model.dtype
- self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))
-
- # Once we send to fp16, all params are in half-precision, including the logit scale
- pipe.to(dtype=torch.float16)
- model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
- self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))
-
-
-@nightly
-@require_torch_gpu
-class MusicLDMPipelineNightlyTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
- generator = torch.Generator(device=generator_device).manual_seed(seed)
- latents = np.random.RandomState(seed).standard_normal((1, 8, 128, 16))
- latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
- inputs = {
- "prompt": "A hammer hitting a wooden surface",
- "latents": latents,
- "generator": generator,
- "num_inference_steps": 3,
- "guidance_scale": 2.5,
- }
- return inputs
-
- def test_musicldm(self):
- musicldm_pipe = MusicLDMPipeline.from_pretrained("cvssp/musicldm")
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- inputs["num_inference_steps"] = 25
- audio = musicldm_pipe(**inputs).audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 81952
-
- # check the portion of the generated audio with the largest dynamic range (reduces flakiness)
- audio_slice = audio[8680:8690]
- expected_slice = np.array(
- [-0.1042, -0.1068, -0.1235, -0.1387, -0.1428, -0.136, -0.1213, -0.1097, -0.0967, -0.0945]
- )
- max_diff = np.abs(expected_slice - audio_slice).max()
- assert max_diff < 1e-3
-
- def test_musicldm_lms(self):
- musicldm_pipe = MusicLDMPipeline.from_pretrained("cvssp/musicldm")
- musicldm_pipe.scheduler = LMSDiscreteScheduler.from_config(musicldm_pipe.scheduler.config)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- audio = musicldm_pipe(**inputs).audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 81952
-
- # check the portion of the generated audio with the largest dynamic range (reduces flakiness)
- audio_slice = audio[58020:58030]
- expected_slice = np.array([0.3592, 0.3477, 0.4084, 0.4665, 0.5048, 0.5891, 0.6461, 0.5579, 0.4595, 0.4403])
- max_diff = np.abs(expected_slice - audio_slice).max()
- assert max_diff < 1e-3
diff --git a/tests/pipelines/omnigen/test_pipeline_omnigen.py b/tests/pipelines/omnigen/test_pipeline_omnigen.py
index 2f9c4d4e3f8e..1a758b705042 100644
--- a/tests/pipelines/omnigen/test_pipeline_omnigen.py
+++ b/tests/pipelines/omnigen/test_pipeline_omnigen.py
@@ -6,13 +6,15 @@
from transformers import AutoTokenizer
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenPipeline, OmniGenTransformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ Expectations,
+ backend_empty_cache,
numpy_cosine_similarity_distance,
- require_torch_gpu,
+ require_torch_accelerator,
slow,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
@@ -20,7 +22,7 @@ class OmniGenPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = OmniGenPipeline
params = frozenset(["prompt", "guidance_scale"])
batch_params = frozenset(["prompt"])
-
+ test_xformers_attention = False
test_layerwise_casting = True
def get_dummy_components(self):
@@ -87,7 +89,7 @@ def test_inference(self):
@slow
-@require_torch_gpu
+@require_torch_accelerator
class OmniGenPipelineSlowTests(unittest.TestCase):
pipeline_class = OmniGenPipeline
repo_id = "shitao/OmniGen-v1-diffusers"
@@ -95,12 +97,12 @@ class OmniGenPipelineSlowTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, seed=0):
if str(device).startswith("mps"):
@@ -125,21 +127,56 @@ def test_omnigen_inference(self):
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
- expected_slice = np.array(
- [
- [0.1783447, 0.16772744, 0.14339337],
- [0.17066911, 0.15521264, 0.13757327],
- [0.17072496, 0.15531206, 0.13524258],
- [0.16746324, 0.1564025, 0.13794944],
- [0.16490817, 0.15258026, 0.13697758],
- [0.16971767, 0.15826806, 0.13928896],
- [0.16782972, 0.15547255, 0.13783783],
- [0.16464645, 0.15281534, 0.13522372],
- [0.16535294, 0.15301755, 0.13526791],
- [0.16365296, 0.15092957, 0.13443318],
- ],
- dtype=np.float32,
+ expected_slices = Expectations(
+ {
+ ("xpu", 3): np.array(
+ [
+ [0.05859375, 0.05859375, 0.04492188],
+ [0.04882812, 0.04101562, 0.03320312],
+ [0.04882812, 0.04296875, 0.03125],
+ [0.04296875, 0.0390625, 0.03320312],
+ [0.04296875, 0.03710938, 0.03125],
+ [0.04492188, 0.0390625, 0.03320312],
+ [0.04296875, 0.03710938, 0.03125],
+ [0.04101562, 0.03710938, 0.02734375],
+ [0.04101562, 0.03515625, 0.02734375],
+ [0.04101562, 0.03515625, 0.02929688],
+ ],
+ dtype=np.float32,
+ ),
+ ("cuda", 7): np.array(
+ [
+ [0.1783447, 0.16772744, 0.14339337],
+ [0.17066911, 0.15521264, 0.13757327],
+ [0.17072496, 0.15531206, 0.13524258],
+ [0.16746324, 0.1564025, 0.13794944],
+ [0.16490817, 0.15258026, 0.13697758],
+ [0.16971767, 0.15826806, 0.13928896],
+ [0.16782972, 0.15547255, 0.13783783],
+ [0.16464645, 0.15281534, 0.13522372],
+ [0.16535294, 0.15301755, 0.13526791],
+ [0.16365296, 0.15092957, 0.13443318],
+ ],
+ dtype=np.float32,
+ ),
+ ("cuda", 8): np.array(
+ [
+ [0.0546875, 0.05664062, 0.04296875],
+ [0.046875, 0.04101562, 0.03320312],
+ [0.05078125, 0.04296875, 0.03125],
+ [0.04296875, 0.04101562, 0.03320312],
+ [0.0390625, 0.03710938, 0.02929688],
+ [0.04296875, 0.03710938, 0.03125],
+ [0.0390625, 0.03710938, 0.02929688],
+ [0.0390625, 0.03710938, 0.02734375],
+ [0.0390625, 0.03320312, 0.02734375],
+ [0.0390625, 0.03320312, 0.02734375],
+ ],
+ dtype=np.float32,
+ ),
+ }
)
+ expected_slice = expected_slices.get_expectation()
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
diff --git a/tests/pipelines/text_to_video_synthesis/__init__.py b/tests/pipelines/ovis_image/__init__.py
similarity index 100%
rename from tests/pipelines/text_to_video_synthesis/__init__.py
rename to tests/pipelines/ovis_image/__init__.py
diff --git a/tests/pipelines/pag/test_pag_animatediff.py b/tests/pipelines/pag/test_pag_animatediff.py
index 6fa96275406f..b1cbd82d7679 100644
--- a/tests/pipelines/pag/test_pag_animatediff.py
+++ b/tests/pipelines/pag/test_pag_animatediff.py
@@ -19,8 +19,8 @@
)
from diffusers.models.attention import FreeNoiseTransformerBlock
from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import require_accelerator, torch_device
+from ...testing_utils import require_accelerator, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
IPAdapterTesterMixin,
@@ -450,9 +450,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).frames[0, -3:, -3:, -1]
components = self.get_dummy_components()
diff --git a/tests/pipelines/pag/test_pag_controlnet_sd.py b/tests/pipelines/pag/test_pag_controlnet_sd.py
index ee97b0507a34..36d5ae100a58 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sd.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sd.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -28,9 +28,9 @@
StableDiffusionControlNetPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from diffusers.utils.torch_utils import randn_tensor
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
@@ -169,9 +169,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
diff --git a/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py b/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py
index 25ef5d253d68..948381f9769e 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -32,9 +32,9 @@
StableDiffusionControlNetPAGInpaintPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
from diffusers.utils.torch_utils import randn_tensor
+from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
@@ -165,9 +165,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__calss__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__calss__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
diff --git a/tests/pipelines/pag/test_pag_controlnet_sdxl.py b/tests/pipelines/pag/test_pag_controlnet_sdxl.py
index 0588e26286a8..51b00f6932bc 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sdxl.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sdxl.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -28,9 +28,9 @@
StableDiffusionXLControlNetPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism
from diffusers.utils.torch_utils import randn_tensor
+from ...testing_utils import enable_full_determinism
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
@@ -187,9 +187,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
diff --git a/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py b/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py
index 63c7d9fbee2d..3c1088adbcf2 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,8 +29,8 @@
StableDiffusionXLControlNetPAGImg2ImgPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor
+from ...testing_utils import enable_full_determinism, floats_tensor
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
@@ -189,9 +189,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
diff --git a/tests/pipelines/pag/test_pag_hunyuan_dit.py b/tests/pipelines/pag/test_pag_hunyuan_dit.py
index 31cd9aa666de..f268a614f85c 100644
--- a/tests/pipelines/pag/test_pag_hunyuan_dit.py
+++ b/tests/pipelines/pag/test_pag_hunyuan_dit.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -28,8 +28,8 @@
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
)
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -177,15 +177,15 @@ def test_fused_qkv_projections(self):
image_disabled = pipe(**inputs)[0]
image_slice_disabled = image_disabled[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
def test_pag_disable_enable(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -198,9 +198,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
components = self.get_dummy_components()
diff --git a/tests/pipelines/pag/test_pag_kolors.py b/tests/pipelines/pag/test_pag_kolors.py
index 9a4f1daa2c05..1bbb4e79e4bc 100644
--- a/tests/pipelines/pag/test_pag_kolors.py
+++ b/tests/pipelines/pag/test_pag_kolors.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -27,8 +27,8 @@
UNet2DConditionModel,
)
from diffusers.pipelines.kolors import ChatGLMModel, ChatGLMTokenizer
-from diffusers.utils.testing_utils import enable_full_determinism
+from ...testing_utils import enable_full_determinism
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
@@ -140,9 +140,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
diff --git a/tests/pipelines/pag/test_pag_pixart_sigma.py b/tests/pipelines/pag/test_pag_pixart_sigma.py
index 63f42416dbca..c04ebad08fdc 100644
--- a/tests/pipelines/pag/test_pag_pixart_sigma.py
+++ b/tests/pipelines/pag/test_pag_pixart_sigma.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -30,12 +30,12 @@
PixArtTransformer2DModel,
)
from diffusers.utils import logging
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
CaptureLogger,
enable_full_determinism,
torch_device,
)
-
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
@@ -120,9 +120,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe.__class__.__name__}."
+ )
out = pipe(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
@@ -254,7 +254,7 @@ def test_attention_slicing_forward_pass(
assert_mean_pixel_difference(to_np(output_with_slicing1[0]), to_np(output_without_slicing[0]))
assert_mean_pixel_difference(to_np(output_with_slicing2[0]), to_np(output_without_slicing[0]))
- # Because we have `pag_applied_layers` we cannot direcly apply
+ # Because we have `pag_applied_layers` we cannot directly apply
# `set_default_attn_processor`
def test_dict_tuple_outputs_equivalent(self, expected_slice=None, expected_max_difference=1e-4):
components = self.get_dummy_components()
diff --git a/tests/pipelines/pag/test_pag_sana.py b/tests/pipelines/pag/test_pag_sana.py
index a2c657297860..5408595c729d 100644
--- a/tests/pipelines/pag/test_pag_sana.py
+++ b/tests/pipelines/pag/test_pag_sana.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
+# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -26,8 +26,8 @@
SanaPipeline,
SanaTransformer2DModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -268,9 +268,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
components = self.get_dummy_components()
diff --git a/tests/pipelines/pag/test_pag_sd.py b/tests/pipelines/pag/test_pag_sd.py
index d4cf00b034ff..064815d13693 100644
--- a/tests/pipelines/pag/test_pag_sd.py
+++ b/tests/pipelines/pag/test_pag_sd.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,14 +29,14 @@
StableDiffusionPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
@@ -154,9 +154,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
@@ -328,9 +328,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.58251953, 0.5722656, 0.5683594, 0.55029297, 0.52001953, 0.52001953, 0.49951172, 0.45410156, 0.50146484]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
def test_pag_uncond(self):
pipeline = AutoPipelineForText2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -345,6 +345,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.5986328, 0.52441406, 0.3972168, 0.4741211, 0.34985352, 0.22705078, 0.4128418, 0.2866211, 0.31713867]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
diff --git a/tests/pipelines/pag/test_pag_sd3.py b/tests/pipelines/pag/test_pag_sd3.py
index 41ff0c3c09f4..26e6ca099286 100644
--- a/tests/pipelines/pag/test_pag_sd3.py
+++ b/tests/pipelines/pag/test_pag_sd3.py
@@ -12,10 +12,10 @@
StableDiffusion3PAGPipeline,
StableDiffusion3Pipeline,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
torch_device,
)
-
from ..test_pipelines_common import (
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
@@ -170,9 +170,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -186,15 +186,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
def test_pag_disable_enable(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -207,9 +207,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
components = self.get_dummy_components()
diff --git a/tests/pipelines/pag/test_pag_sd3_img2img.py b/tests/pipelines/pag/test_pag_sd3_img2img.py
index 2fe988929185..19a36e283de4 100644
--- a/tests/pipelines/pag/test_pag_sd3_img2img.py
+++ b/tests/pipelines/pag/test_pag_sd3_img2img.py
@@ -15,7 +15,8 @@
StableDiffusion3Img2ImgPipeline,
StableDiffusion3PAGImg2ImgPipeline,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -24,7 +25,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
@@ -149,9 +149,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
components = self.get_dummy_components()
@@ -254,9 +254,9 @@ def test_pag_cfg(self):
0.17822266,
]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
def test_pag_uncond(self):
pipeline = AutoPipelineForImage2Image.from_pretrained(
@@ -272,6 +272,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.1508789, 0.16210938, 0.17138672, 0.16210938, 0.17089844, 0.16137695, 0.16235352, 0.16430664, 0.16455078]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
diff --git a/tests/pipelines/pag/test_pag_sd_img2img.py b/tests/pipelines/pag/test_pag_sd_img2img.py
index d000493d6bd1..0b440d5ec9fc 100644
--- a/tests/pipelines/pag/test_pag_sd_img2img.py
+++ b/tests/pipelines/pag/test_pag_sd_img2img.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,7 +31,8 @@
StableDiffusionPAGImg2ImgPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -40,7 +41,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
@@ -161,9 +161,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
@@ -267,9 +267,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.58251953, 0.5722656, 0.5683594, 0.55029297, 0.52001953, 0.52001953, 0.49951172, 0.45410156, 0.50146484]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
def test_pag_uncond(self):
pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -285,6 +285,6 @@ def test_pag_uncond(self):
[0.5986328, 0.52441406, 0.3972168, 0.4741211, 0.34985352, 0.22705078, 0.4128418, 0.2866211, 0.31713867]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
diff --git a/tests/pipelines/pag/test_pag_sd_inpaint.py b/tests/pipelines/pag/test_pag_sd_inpaint.py
index 06682c111d37..754158bbf138 100644
--- a/tests/pipelines/pag/test_pag_sd_inpaint.py
+++ b/tests/pipelines/pag/test_pag_sd_inpaint.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,7 +29,8 @@
StableDiffusionPAGInpaintPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -38,7 +39,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
@@ -255,7 +255,7 @@ def test_encode_prompt_works_in_isolation(self):
@require_torch_accelerator
class StableDiffusionPAGPipelineIntegrationTests(unittest.TestCase):
pipeline_class = StableDiffusionPAGInpaintPipeline
- repo_id = "runwayml/stable-diffusion-v1-5"
+ repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
def setUp(self):
super().setUp()
@@ -302,9 +302,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.38793945, 0.4111328, 0.47924805, 0.39208984, 0.4165039, 0.41674805, 0.37060547, 0.36791992, 0.40625]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
def test_pag_uncond(self):
pipeline = AutoPipelineForInpainting.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -319,6 +319,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.3876953, 0.40356445, 0.4934082, 0.39697266, 0.41674805, 0.41015625, 0.375, 0.36914062, 0.40649414]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
diff --git a/tests/pipelines/pag/test_pag_sdxl.py b/tests/pipelines/pag/test_pag_sdxl.py
index b35b2b1d2f7e..cca5c61651b3 100644
--- a/tests/pipelines/pag/test_pag_sdxl.py
+++ b/tests/pipelines/pag/test_pag_sdxl.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,14 +29,14 @@
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
@@ -167,9 +167,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
@@ -331,9 +331,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.3123679, 0.31725878, 0.32026544, 0.327533, 0.3266391, 0.3303998, 0.33544615, 0.34181812, 0.34102726]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
def test_pag_uncond(self):
pipeline = AutoPipelineForText2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -348,6 +348,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.47400922, 0.48650584, 0.4839625, 0.4724013, 0.4890427, 0.49544555, 0.51707107, 0.54299414, 0.5224372]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
diff --git a/tests/pipelines/pag/test_pag_sdxl_img2img.py b/tests/pipelines/pag/test_pag_sdxl_img2img.py
index c94a6836de7f..d311500d3ca7 100644
--- a/tests/pipelines/pag/test_pag_sdxl_img2img.py
+++ b/tests/pipelines/pag/test_pag_sdxl_img2img.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -38,7 +38,8 @@
StableDiffusionXLPAGImg2ImgPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -47,7 +48,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
@@ -215,9 +215,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
@@ -316,9 +316,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.20301354, 0.21078318, 0.2021082, 0.20277798, 0.20681083, 0.19562206, 0.20121682, 0.21562952, 0.21277016]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
def test_pag_uncond(self):
pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -333,6 +333,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.21303111, 0.22188407, 0.2124992, 0.21365267, 0.18823743, 0.17569828, 0.21113116, 0.19419771, 0.18919235]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
diff --git a/tests/pipelines/pag/test_pag_sdxl_inpaint.py b/tests/pipelines/pag/test_pag_sdxl_inpaint.py
index cca5292288b0..00a07582e205 100644
--- a/tests/pipelines/pag/test_pag_sdxl_inpaint.py
+++ b/tests/pipelines/pag/test_pag_sdxl_inpaint.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -39,7 +39,8 @@
StableDiffusionXLPAGInpaintPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -48,7 +49,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
@@ -220,9 +220,9 @@ def test_pag_disable_enable(self):
inputs = self.get_dummy_inputs(device)
del inputs["pag_scale"]
- assert (
- "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
- ), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ assert "pag_scale" not in inspect.signature(pipe_sd.__call__).parameters, (
+ f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
+ )
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
# pag disabled with pag_scale=0.0
@@ -322,9 +322,9 @@ def test_pag_cfg(self):
expected_slice = np.array(
[0.41385046, 0.39608297, 0.4360491, 0.26872507, 0.32187328, 0.4242474, 0.2603805, 0.34167895, 0.46561807]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
def test_pag_uncond(self):
pipeline = AutoPipelineForInpainting.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
@@ -339,6 +339,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.41597816, 0.39302617, 0.44287828, 0.2687074, 0.28315824, 0.40582314, 0.20877528, 0.2380802, 0.39447647]
)
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
- ), f"output is different from expected, {image_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3, (
+ f"output is different from expected, {image_slice.flatten()}"
+ )
diff --git a/tests/pipelines/paint_by_example/test_paint_by_example.py b/tests/pipelines/paint_by_example/test_paint_by_example.py
deleted file mode 100644
index 6b668de2762a..000000000000
--- a/tests/pipelines/paint_by_example/test_paint_by_example.py
+++ /dev/null
@@ -1,228 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import random
-import unittest
-
-import numpy as np
-import torch
-from PIL import Image
-from transformers import CLIPImageProcessor, CLIPVisionConfig
-
-from diffusers import AutoencoderKL, PaintByExamplePipeline, PNDMScheduler, UNet2DConditionModel
-from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- floats_tensor,
- load_image,
- nightly,
- require_torch_gpu,
- torch_device,
-)
-
-from ..pipeline_params import IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = PaintByExamplePipeline
- params = IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS
- batch_params = IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
- image_params = frozenset([]) # TO_DO: update the image_prams once refactored VaeImageProcessor.preprocess
-
- supports_dduf = False
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=9,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- )
- scheduler = PNDMScheduler(skip_prk_steps=True)
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- torch.manual_seed(0)
- config = CLIPVisionConfig(
- hidden_size=32,
- projection_dim=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- image_size=32,
- patch_size=4,
- )
- image_encoder = PaintByExampleImageEncoder(config, proj_size=32)
- feature_extractor = CLIPImageProcessor(crop_size=32, size=32)
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "image_encoder": image_encoder,
- "safety_checker": None,
- "feature_extractor": feature_extractor,
- }
- return components
-
- def convert_to_pt(self, image):
- image = np.array(image.convert("RGB"))
- image = image[None].transpose(0, 3, 1, 2)
- image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
- return image
-
- def get_dummy_inputs(self, device="cpu", seed=0):
- # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched
- image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
- image = image.cpu().permute(0, 2, 3, 1)[0]
- init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
- mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
- example_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((32, 32))
-
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "example_image": example_image,
- "image": init_image,
- "mask_image": mask_image,
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "np",
- }
- return inputs
-
- def test_paint_by_example_inpaint(self):
- components = self.get_dummy_components()
-
- # make sure here that pndm scheduler skips prk
- pipe = PaintByExamplePipeline(**components)
- pipe = pipe.to("cpu")
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs()
- output = pipe(**inputs)
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.4686, 0.5687, 0.4007, 0.5218, 0.5741, 0.4482, 0.4940, 0.4629, 0.4503])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_paint_by_example_image_tensor(self):
- device = "cpu"
- inputs = self.get_dummy_inputs()
- inputs.pop("mask_image")
- image = self.convert_to_pt(inputs.pop("image"))
- mask_image = image.clamp(0, 1) / 2
-
- # make sure here that pndm scheduler skips prk
- pipe = PaintByExamplePipeline(**self.get_dummy_components())
- pipe = pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- output = pipe(image=image, mask_image=mask_image[:, 0], **inputs)
- out_1 = output.images
-
- image = image.cpu().permute(0, 2, 3, 1)[0]
- mask_image = mask_image.cpu().permute(0, 2, 3, 1)[0]
-
- image = Image.fromarray(np.uint8(image)).convert("RGB")
- mask_image = Image.fromarray(np.uint8(mask_image)).convert("RGB")
-
- output = pipe(**self.get_dummy_inputs())
- out_2 = output.images
-
- assert out_1.shape == (1, 64, 64, 3)
- assert np.abs(out_1.flatten() - out_2.flatten()).max() < 5e-2
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(expected_max_diff=3e-3)
-
-
-@nightly
-@require_torch_gpu
-class PaintByExamplePipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- def test_paint_by_example(self):
- # make sure here that pndm scheduler skips prk
- init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/paint_by_example/dog_in_bucket.png"
- )
- mask_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/paint_by_example/mask.png"
- )
- example_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/paint_by_example/panda.jpg"
- )
-
- pipe = PaintByExamplePipeline.from_pretrained("Fantasy-Studio/Paint-by-Example")
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.manual_seed(321)
- output = pipe(
- image=init_image,
- mask_image=mask_image,
- example_image=example_image,
- generator=generator,
- guidance_scale=5.0,
- num_inference_steps=50,
- output_type="np",
- )
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.4834, 0.4811, 0.4874, 0.5122, 0.5081, 0.5144, 0.5291, 0.5290, 0.5374])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/pia/test_pia.py b/tests/pipelines/pia/test_pia.py
deleted file mode 100644
index 1156bf32dafa..000000000000
--- a/tests/pipelines/pia/test_pia.py
+++ /dev/null
@@ -1,448 +0,0 @@
-import random
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-import diffusers
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- DPMSolverMultistepScheduler,
- LCMScheduler,
- MotionAdapter,
- PIAPipeline,
- StableDiffusionPipeline,
- UNet2DConditionModel,
- UNetMotionModel,
-)
-from diffusers.utils import is_xformers_available, logging
-from diffusers.utils.testing_utils import floats_tensor, require_accelerator, torch_device
-
-from ..test_pipelines_common import IPAdapterTesterMixin, PipelineFromPipeTesterMixin, PipelineTesterMixin
-
-
-def to_np(tensor):
- if isinstance(tensor, torch.Tensor):
- tensor = tensor.detach().cpu().numpy()
-
- return tensor
-
-
-class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase):
- pipeline_class = PIAPipeline
- params = frozenset(
- [
- "prompt",
- "height",
- "width",
- "guidance_scale",
- "negative_prompt",
- "prompt_embeds",
- "negative_prompt_embeds",
- "cross_attention_kwargs",
- ]
- )
- batch_params = frozenset(["prompt", "image", "generator"])
- required_optional_params = frozenset(
- [
- "num_inference_steps",
- "generator",
- "latents",
- "return_dict",
- "callback_on_step_end",
- "callback_on_step_end_tensor_inputs",
- ]
- )
- test_layerwise_casting = True
- test_group_offloading = True
-
- def get_dummy_components(self):
- cross_attention_dim = 8
- block_out_channels = (8, 8)
-
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=block_out_channels,
- layers_per_block=2,
- sample_size=8,
- in_channels=4,
- out_channels=4,
- down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=cross_attention_dim,
- norm_num_groups=2,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="linear",
- clip_sample=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=block_out_channels,
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- norm_num_groups=2,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=cross_attention_dim,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- torch.manual_seed(0)
- motion_adapter = MotionAdapter(
- block_out_channels=block_out_channels,
- motion_layers_per_block=2,
- motion_norm_num_groups=2,
- motion_num_attention_heads=4,
- conv_in_channels=9,
- )
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "motion_adapter": motion_adapter,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "feature_extractor": None,
- "image_encoder": None,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
-
- image = floats_tensor((1, 3, 8, 8), rng=random.Random(seed)).to(device)
- inputs = {
- "image": image,
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 7.5,
- "output_type": "pt",
- }
- return inputs
-
- def test_from_pipe_consistent_config(self):
- assert self.original_pipeline_class == StableDiffusionPipeline
- original_repo = "hf-internal-testing/tinier-stable-diffusion-pipe"
- original_kwargs = {"requires_safety_checker": False}
-
- # create original_pipeline_class(sd)
- pipe_original = self.original_pipeline_class.from_pretrained(original_repo, **original_kwargs)
-
- # original_pipeline_class(sd) -> pipeline_class
- pipe_components = self.get_dummy_components()
- pipe_additional_components = {}
- for name, component in pipe_components.items():
- if name not in pipe_original.components:
- pipe_additional_components[name] = component
-
- pipe = self.pipeline_class.from_pipe(pipe_original, **pipe_additional_components)
-
- # pipeline_class -> original_pipeline_class(sd)
- original_pipe_additional_components = {}
- for name, component in pipe_original.components.items():
- if name not in pipe.components or not isinstance(component, pipe.components[name].__class__):
- original_pipe_additional_components[name] = component
-
- pipe_original_2 = self.original_pipeline_class.from_pipe(pipe, **original_pipe_additional_components)
-
- # compare the config
- original_config = {k: v for k, v in pipe_original.config.items() if not k.startswith("_")}
- original_config_2 = {k: v for k, v in pipe_original_2.config.items() if not k.startswith("_")}
- assert original_config_2 == original_config
-
- def test_motion_unet_loading(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
-
- assert isinstance(pipe.unet, UNetMotionModel)
-
- def test_ip_adapter(self):
- expected_pipe_slice = None
-
- if torch_device == "cpu":
- expected_pipe_slice = np.array(
- [
- 0.5475,
- 0.5769,
- 0.4873,
- 0.5064,
- 0.4445,
- 0.5876,
- 0.5453,
- 0.4102,
- 0.5247,
- 0.5370,
- 0.3406,
- 0.4322,
- 0.3991,
- 0.3756,
- 0.5438,
- 0.4780,
- 0.5087,
- 0.5248,
- 0.6243,
- 0.5506,
- 0.3491,
- 0.5440,
- 0.6111,
- 0.5122,
- 0.5326,
- 0.5180,
- 0.5538,
- ]
- )
- return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
-
- def test_dict_tuple_outputs_equivalent(self):
- expected_slice = None
- if torch_device == "cpu":
- expected_slice = np.array([0.5476, 0.4092, 0.5289, 0.4755, 0.5092, 0.5186, 0.5403, 0.5287, 0.5467])
- return super().test_dict_tuple_outputs_equivalent(expected_slice=expected_slice)
-
- @unittest.skip("Attention slicing is not enabled in this pipeline")
- def test_attention_slicing_forward_pass(self):
- pass
-
- def test_inference_batch_single_identical(
- self,
- batch_size=2,
- expected_max_diff=1e-4,
- additional_params_copy_to_batched_inputs=["num_inference_steps"],
- ):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- for components in pipe.components.values():
- if hasattr(components, "set_default_attn_processor"):
- components.set_default_attn_processor()
-
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- inputs = self.get_dummy_inputs(torch_device)
- # Reset generator in case it is has been used in self.get_dummy_inputs
- inputs["generator"] = self.get_generator(0)
-
- logger = logging.get_logger(pipe.__module__)
- logger.setLevel(level=diffusers.logging.FATAL)
-
- # batchify inputs
- batched_inputs = {}
- batched_inputs.update(inputs)
-
- for name in self.batch_params:
- if name not in inputs:
- continue
-
- value = inputs[name]
- if name == "prompt":
- len_prompt = len(value)
- batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
- batched_inputs[name][-1] = 100 * "very long"
-
- else:
- batched_inputs[name] = batch_size * [value]
-
- if "generator" in inputs:
- batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
-
- if "batch_size" in inputs:
- batched_inputs["batch_size"] = batch_size
-
- for arg in additional_params_copy_to_batched_inputs:
- batched_inputs[arg] = inputs[arg]
-
- output = pipe(**inputs)
- output_batch = pipe(**batched_inputs)
-
- assert output_batch[0].shape[0] == batch_size
-
- max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
- assert max_diff < expected_max_diff
-
- @require_accelerator
- def test_to_device(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
-
- pipe.to("cpu")
- # pipeline creates a new motion UNet under the hood. So we need to check the device from pipe.components
- model_devices = [
- component.device.type for component in pipe.components.values() if hasattr(component, "device")
- ]
- self.assertTrue(all(device == "cpu" for device in model_devices))
-
- output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
- self.assertTrue(np.isnan(output_cpu).sum() == 0)
-
- pipe.to(torch_device)
- model_devices = [
- component.device.type for component in pipe.components.values() if hasattr(component, "device")
- ]
- self.assertTrue(all(device == torch_device for device in model_devices))
-
- output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
- self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
-
- def test_to_dtype(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
-
- # pipeline creates a new motion UNet under the hood. So we need to check the dtype from pipe.components
- model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
- self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
-
- pipe.to(dtype=torch.float16)
- model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
- self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
-
- def test_prompt_embeds(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
- pipe.to(torch_device)
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs.pop("prompt")
- inputs["prompt_embeds"] = torch.randn((1, 4, pipe.text_encoder.config.hidden_size), device=torch_device)
- pipe(**inputs)
-
- def test_free_init(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
- pipe.to(torch_device)
-
- inputs_normal = self.get_dummy_inputs(torch_device)
- frames_normal = pipe(**inputs_normal).frames[0]
-
- pipe.enable_free_init(
- num_iters=2,
- use_fast_sampling=True,
- method="butterworth",
- order=4,
- spatial_stop_frequency=0.25,
- temporal_stop_frequency=0.25,
- )
- inputs_enable_free_init = self.get_dummy_inputs(torch_device)
- frames_enable_free_init = pipe(**inputs_enable_free_init).frames[0]
-
- pipe.disable_free_init()
- inputs_disable_free_init = self.get_dummy_inputs(torch_device)
- frames_disable_free_init = pipe(**inputs_disable_free_init).frames[0]
-
- sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
- max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max()
- self.assertGreater(
- sum_enabled, 1e1, "Enabling of FreeInit should lead to results different from the default pipeline results"
- )
- self.assertLess(
- max_diff_disabled,
- 1e-4,
- "Disabling of FreeInit should lead to results similar to the default pipeline results",
- )
-
- def test_free_init_with_schedulers(self):
- components = self.get_dummy_components()
- pipe: PIAPipeline = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
- pipe.to(torch_device)
-
- inputs_normal = self.get_dummy_inputs(torch_device)
- frames_normal = pipe(**inputs_normal).frames[0]
-
- schedulers_to_test = [
- DPMSolverMultistepScheduler.from_config(
- components["scheduler"].config,
- timestep_spacing="linspace",
- beta_schedule="linear",
- algorithm_type="dpmsolver++",
- steps_offset=1,
- clip_sample=False,
- ),
- LCMScheduler.from_config(
- components["scheduler"].config,
- timestep_spacing="linspace",
- beta_schedule="linear",
- steps_offset=1,
- clip_sample=False,
- ),
- ]
- components.pop("scheduler")
-
- for scheduler in schedulers_to_test:
- components["scheduler"] = scheduler
- pipe: PIAPipeline = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
- pipe.to(torch_device)
-
- pipe.enable_free_init(num_iters=2, use_fast_sampling=False)
-
- inputs = self.get_dummy_inputs(torch_device)
- frames_enable_free_init = pipe(**inputs).frames[0]
- sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
-
- self.assertGreater(
- sum_enabled,
- 1e1,
- "Enabling of FreeInit should lead to results different from the default pipeline results",
- )
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- for component in pipe.components.values():
- if hasattr(component, "set_default_attn_processor"):
- component.set_default_attn_processor()
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
- output_without_offload = pipe(**inputs).frames[0]
- output_without_offload = (
- output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload
- )
-
- pipe.enable_xformers_memory_efficient_attention()
- inputs = self.get_dummy_inputs(torch_device)
- output_with_offload = pipe(**inputs).frames[0]
- output_with_offload = (
- output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_without_offload
- )
-
- max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
- self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results")
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "num_images_per_prompt": 1,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
diff --git a/tests/pipelines/pipeline_params.py b/tests/pipelines/pipeline_params.py
index 4e2c4dcdd9cb..3db7c9fa1b0c 100644
--- a/tests/pipelines/pipeline_params.py
+++ b/tests/pipelines/pipeline_params.py
@@ -20,12 +20,6 @@
]
)
-TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
-
-TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
-
-IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
-
IMAGE_VARIATION_PARAMS = frozenset(
[
"image",
@@ -35,8 +29,6 @@
]
)
-IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
-
TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
[
"prompt",
@@ -50,8 +42,6 @@
]
)
-TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
-
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
[
# Text guided image variation with an image mask
@@ -67,8 +57,6 @@
]
)
-TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
-
IMAGE_INPAINTING_PARAMS = frozenset(
[
# image variation with an image mask
@@ -80,8 +68,6 @@
]
)
-IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
-
IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
[
"example_image",
@@ -93,20 +79,12 @@
]
)
-IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
+UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS = frozenset(["class_labels"])
CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS = frozenset(["class_labels"])
-UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
-
-UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
-
-UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
-
-UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
-
TEXT_TO_AUDIO_PARAMS = frozenset(
[
"prompt",
@@ -119,11 +97,38 @@
]
)
-TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"])
-TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])
+UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
-TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
+# image params
+TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
+
+IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
+
+
+# batch params
+TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
+
+IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
+
+TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
+
+TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
+
+IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
+
+IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
+
+UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
+
+UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
+
+TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
+
+TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])
VIDEO_TO_VIDEO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt", "video"])
+
+# callback params
+TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
diff --git a/tests/pipelines/pixart_alpha/test_pixart.py b/tests/pipelines/pixart_alpha/test_pixart.py
index ea5cfcef86fd..fd41c9887dcc 100644
--- a/tests/pipelines/pixart_alpha/test_pixart.py
+++ b/tests/pipelines/pixart_alpha/test_pixart.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -27,7 +27,8 @@
PixArtAlphaPipeline,
PixArtTransformer2DModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -35,7 +36,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
diff --git a/tests/pipelines/pixart_sigma/test_pixart.py b/tests/pipelines/pixart_sigma/test_pixart.py
index b220afcfc25a..2cb80df81adf 100644
--- a/tests/pipelines/pixart_sigma/test_pixart.py
+++ b/tests/pipelines/pixart_sigma/test_pixart.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -27,7 +27,8 @@
PixArtSigmaPipeline,
PixArtTransformer2DModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -35,7 +36,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
@@ -260,9 +260,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -276,15 +276,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
@slow
diff --git a/tests/pipelines/pndm/test_pndm.py b/tests/pipelines/pndm/test_pndm.py
index 5efb244919da..61d6efe88ccd 100644
--- a/tests/pipelines/pndm/test_pndm.py
+++ b/tests/pipelines/pndm/test_pndm.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,7 +19,8 @@
import torch
from diffusers import PNDMPipeline, PNDMScheduler, UNet2DModel
-from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch, torch_device
+
+from ...testing_utils import enable_full_determinism, nightly, require_torch, torch_device
enable_full_determinism()
diff --git a/tests/pipelines/unclip/__init__.py b/tests/pipelines/prx/__init__.py
similarity index 100%
rename from tests/pipelines/unclip/__init__.py
rename to tests/pipelines/prx/__init__.py
diff --git a/tests/pipelines/prx/test_pipeline_prx.py b/tests/pipelines/prx/test_pipeline_prx.py
new file mode 100644
index 000000000000..46c6a5760e22
--- /dev/null
+++ b/tests/pipelines/prx/test_pipeline_prx.py
@@ -0,0 +1,265 @@
+import unittest
+
+import numpy as np
+import pytest
+import torch
+from transformers import AutoTokenizer
+from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig
+from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
+
+from diffusers.models import AutoencoderDC, AutoencoderKL
+from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
+from diffusers.pipelines.prx.pipeline_prx import PRXPipeline
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import is_transformers_version
+
+from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+@pytest.mark.xfail(
+ condition=is_transformers_version(">", "4.57.1"),
+ reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544",
+ strict=False,
+)
+class PRXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = PRXPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = frozenset(["prompt", "negative_prompt", "num_images_per_prompt"])
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ @classmethod
+ def setUpClass(cls):
+ # Ensure PRXPipeline has an _execution_device property expected by __call__
+ if not isinstance(getattr(PRXPipeline, "_execution_device", None), property):
+ try:
+ setattr(PRXPipeline, "_execution_device", property(lambda self: torch.device("cpu")))
+ except Exception:
+ pass
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = PRXTransformer2DModel(
+ patch_size=1,
+ in_channels=4,
+ context_in_dim=8,
+ hidden_size=8,
+ mlp_ratio=2.0,
+ num_heads=2,
+ depth=1,
+ axes_dim=[2, 2],
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=4,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ shift_factor=0.0,
+ scaling_factor=1.0,
+ ).eval()
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ torch.manual_seed(0)
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
+ tokenizer.model_max_length = 64
+
+ torch.manual_seed(0)
+
+ encoder_params = {
+ "vocab_size": tokenizer.vocab_size,
+ "hidden_size": 8,
+ "intermediate_size": 16,
+ "num_hidden_layers": 1,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 1,
+ "head_dim": 4,
+ "max_position_embeddings": 64,
+ "layer_types": ["full_attention"],
+ "attention_bias": False,
+ "attention_dropout": 0.0,
+ "dropout_rate": 0.0,
+ "hidden_activation": "gelu_pytorch_tanh",
+ "rms_norm_eps": 1e-06,
+ "attn_logit_softcapping": 50.0,
+ "final_logit_softcapping": 30.0,
+ "query_pre_attn_scalar": 4,
+ "rope_theta": 10000.0,
+ "sliding_window": 4096,
+ }
+ encoder_config = T5GemmaModuleConfig(**encoder_params)
+ text_encoder_config = T5GemmaConfig(encoder=encoder_config, is_encoder_decoder=False, **encoder_params)
+ text_encoder = T5GemmaEncoder(text_encoder_config)
+
+ return {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ return {
+ "prompt": "",
+ "negative_prompt": "",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 1.0,
+ "height": 32,
+ "width": 32,
+ "output_type": "pt",
+ "use_resolution_binning": False,
+ }
+
+ def test_inference(self):
+ device = "cpu"
+ components = self.get_dummy_components()
+ pipe = PRXPipeline(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+ try:
+ pipe.register_to_config(_execution_device="cpu")
+ except Exception:
+ pass
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs)[0]
+ generated_image = image[0]
+
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+ expected_image = torch.zeros(3, 32, 32)
+ max_diff = np.abs(generated_image - expected_image).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_callback_inputs(self):
+ components = self.get_dummy_components()
+ pipe = PRXPipeline(**components)
+ pipe = pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+ try:
+ pipe.register_to_config(_execution_device="cpu")
+ except Exception:
+ pass
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {PRXPipeline} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ for tensor_name in callback_kwargs.keys():
+ assert tensor_name in pipe._callback_tensor_inputs
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+ for tensor_name in callback_kwargs.keys():
+ assert tensor_name in pipe._callback_tensor_inputs
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs("cpu")
+
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ _ = pipe(**inputs)[0]
+
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ _ = pipe(**inputs)[0]
+
+ def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ def to_np_local(tensor):
+ if isinstance(tensor, torch.Tensor):
+ return tensor.detach().cpu().numpy()
+ return tensor
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ max_diff1 = np.abs(to_np_local(output_with_slicing1) - to_np_local(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np_local(output_with_slicing2) - to_np_local(output_without_slicing)).max()
+ self.assertLess(max(max_diff1, max_diff2), expected_max_diff)
+
+ def test_inference_with_autoencoder_dc(self):
+ """Test PRXPipeline with AutoencoderDC (DCAE) instead of AutoencoderKL."""
+ device = "cpu"
+
+ components = self.get_dummy_components()
+
+ torch.manual_seed(0)
+ vae_dc = AutoencoderDC(
+ in_channels=3,
+ latent_channels=4,
+ attention_head_dim=2,
+ encoder_block_types=(
+ "ResBlock",
+ "EfficientViTBlock",
+ ),
+ decoder_block_types=(
+ "ResBlock",
+ "EfficientViTBlock",
+ ),
+ encoder_block_out_channels=(8, 8),
+ decoder_block_out_channels=(8, 8),
+ encoder_qkv_multiscales=((), (5,)),
+ decoder_qkv_multiscales=((), (5,)),
+ encoder_layers_per_block=(1, 1),
+ decoder_layers_per_block=(1, 1),
+ upsample_block_type="interpolate",
+ downsample_block_type="stride_conv",
+ decoder_norm_types="rms_norm",
+ decoder_act_fns="silu",
+ ).eval()
+
+ components["vae"] = vae_dc
+
+ pipe = PRXPipeline(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ expected_scale_factor = vae_dc.spatial_compression_ratio
+ self.assertEqual(pipe.vae_scale_factor, expected_scale_factor)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs)[0]
+ generated_image = image[0]
+
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+ expected_image = torch.zeros(3, 32, 32)
+ max_diff = np.abs(generated_image - expected_image).max()
+ self.assertLessEqual(max_diff, 1e10)
diff --git a/tests/pipelines/unidiffuser/__init__.py b/tests/pipelines/qwenimage/__init__.py
similarity index 100%
rename from tests/pipelines/unidiffuser/__init__.py
rename to tests/pipelines/qwenimage/__init__.py
diff --git a/tests/pipelines/qwenimage/test_qwenimage.py b/tests/pipelines/qwenimage/test_qwenimage.py
new file mode 100644
index 000000000000..8ebfe7d08bc1
--- /dev/null
+++ b/tests/pipelines/qwenimage/test_qwenimage.py
@@ -0,0 +1,236 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
+
+from diffusers import (
+ AutoencoderKLQwenImage,
+ FlowMatchEulerDiscreteScheduler,
+ QwenImagePipeline,
+ QwenImageTransformer2DModel,
+)
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class QwenImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = QwenImagePipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ supports_dduf = False
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = QwenImageTransformer2DModel(
+ patch_size=2,
+ in_channels=16,
+ out_channels=4,
+ num_layers=2,
+ attention_head_dim=16,
+ num_attention_heads=3,
+ joint_attention_dim=16,
+ guidance_embeds=False,
+ axes_dims_rope=(8, 4, 4),
+ )
+
+ torch.manual_seed(0)
+ z_dim = 4
+ vae = AutoencoderKLQwenImage(
+ base_dim=z_dim * 6,
+ z_dim=z_dim,
+ dim_mult=[1, 2, 4],
+ num_res_blocks=1,
+ temperal_downsample=[False, True],
+ # fmt: off
+ latents_mean=[0.0] * 4,
+ latents_std=[1.0] * 4,
+ # fmt: on
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ torch.manual_seed(0)
+ config = Qwen2_5_VLConfig(
+ text_config={
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 2,
+ "rope_scaling": {
+ "mrope_section": [1, 1, 2],
+ "rope_type": "default",
+ "type": "default",
+ },
+ "rope_theta": 1000000.0,
+ },
+ vision_config={
+ "depth": 2,
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_heads": 2,
+ "out_hidden_size": 16,
+ },
+ hidden_size=16,
+ vocab_size=152064,
+ vision_end_token_id=151653,
+ vision_start_token_id=151652,
+ vision_token_id=151654,
+ )
+ text_encoder = Qwen2_5_VLForConditionalGeneration(config)
+ tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "bad quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 3.0,
+ "true_cfg_scale": 1.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+
+ # fmt: off
+ expected_slice = torch.tensor([0.56331, 0.63677, 0.6015, 0.56369, 0.58166, 0.55277, 0.57176, 0.63261, 0.41466, 0.35561, 0.56229, 0.48334, 0.49714, 0.52622, 0.40872, 0.50208])
+ # fmt: on
+
+ generated_slice = generated_image.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
diff --git a/tests/pipelines/qwenimage/test_qwenimage_controlnet.py b/tests/pipelines/qwenimage/test_qwenimage_controlnet.py
new file mode 100644
index 000000000000..188106b49b84
--- /dev/null
+++ b/tests/pipelines/qwenimage/test_qwenimage_controlnet.py
@@ -0,0 +1,338 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
+
+from diffusers import (
+ AutoencoderKLQwenImage,
+ FlowMatchEulerDiscreteScheduler,
+ QwenImageControlNetModel,
+ QwenImageControlNetPipeline,
+ QwenImageMultiControlNetModel,
+ QwenImageTransformer2DModel,
+)
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from diffusers.utils.torch_utils import randn_tensor
+
+from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class QwenControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = QwenImageControlNetPipeline
+ params = (TEXT_TO_IMAGE_PARAMS | frozenset(["control_image", "controlnet_conditioning_scale"])) - {
+ "cross_attention_kwargs"
+ }
+ batch_params = frozenset(["prompt", "negative_prompt", "control_image"])
+ image_params = frozenset(["control_image"])
+ image_latents_params = frozenset(["latents"])
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "control_image",
+ "controlnet_conditioning_scale",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+
+ supports_dduf = False
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = QwenImageTransformer2DModel(
+ patch_size=2,
+ in_channels=16,
+ out_channels=4,
+ num_layers=2,
+ attention_head_dim=16,
+ num_attention_heads=3,
+ joint_attention_dim=16,
+ guidance_embeds=False,
+ axes_dims_rope=(8, 4, 4),
+ )
+
+ torch.manual_seed(0)
+ controlnet = QwenImageControlNetModel(
+ patch_size=2,
+ in_channels=16,
+ out_channels=4,
+ num_layers=2,
+ attention_head_dim=16,
+ num_attention_heads=3,
+ joint_attention_dim=16,
+ axes_dims_rope=(8, 4, 4),
+ )
+
+ torch.manual_seed(0)
+ z_dim = 4
+ vae = AutoencoderKLQwenImage(
+ base_dim=z_dim * 6,
+ z_dim=z_dim,
+ dim_mult=[1, 2, 4],
+ num_res_blocks=1,
+ temperal_downsample=[False, True],
+ latents_mean=[0.0] * z_dim,
+ latents_std=[1.0] * z_dim,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ torch.manual_seed(0)
+ config = Qwen2_5_VLConfig(
+ text_config={
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 2,
+ "rope_scaling": {
+ "mrope_section": [1, 1, 2],
+ "rope_type": "default",
+ "type": "default",
+ },
+ "rope_theta": 1_000_000.0,
+ },
+ vision_config={
+ "depth": 2,
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_heads": 2,
+ "out_hidden_size": 16,
+ },
+ hidden_size=16,
+ vocab_size=152064,
+ vision_end_token_id=151653,
+ vision_start_token_id=151652,
+ vision_token_id=151654,
+ )
+
+ text_encoder = Qwen2_5_VLForConditionalGeneration(config)
+ tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "controlnet": controlnet,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ control_image = randn_tensor(
+ (1, 3, 32, 32),
+ generator=generator,
+ device=torch.device(device),
+ dtype=torch.float32,
+ )
+
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "bad quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 3.0,
+ "true_cfg_scale": 1.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 16,
+ "control_image": control_image,
+ "controlnet_conditioning_scale": 0.5,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_qwen_controlnet(self):
+ device = "cpu"
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+
+ # Expected slice from the generated image
+ expected_slice = torch.tensor(
+ [
+ 0.4726,
+ 0.5549,
+ 0.6324,
+ 0.6548,
+ 0.4968,
+ 0.4639,
+ 0.4749,
+ 0.4898,
+ 0.4725,
+ 0.4645,
+ 0.4435,
+ 0.3339,
+ 0.3400,
+ 0.4630,
+ 0.3879,
+ 0.4406,
+ ]
+ )
+
+ generated_slice = generated_image.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+ def test_qwen_controlnet_multicondition(self):
+ device = "cpu"
+ components = self.get_dummy_components()
+
+ components["controlnet"] = QwenImageMultiControlNetModel([components["controlnet"]])
+
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ control_image = inputs["control_image"]
+ inputs["control_image"] = [control_image, control_image]
+ inputs["controlnet_conditioning_scale"] = [0.5, 0.5]
+
+ image = pipe(**inputs).images
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+ # Expected slice from the generated image
+ expected_slice = torch.tensor(
+ [
+ 0.6239,
+ 0.6642,
+ 0.5768,
+ 0.6039,
+ 0.5270,
+ 0.5070,
+ 0.5006,
+ 0.5271,
+ 0.4506,
+ 0.3085,
+ 0.3435,
+ 0.5152,
+ 0.5096,
+ 0.5422,
+ 0.4286,
+ 0.5752,
+ ]
+ )
+
+ generated_slice = generated_image.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ inputs["control_image"] = randn_tensor(
+ (1, 3, 128, 128),
+ generator=inputs["generator"],
+ device=torch.device(generator_device),
+ dtype=torch.float32,
+ )
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ inputs["control_image"] = randn_tensor(
+ (1, 3, 128, 128),
+ generator=inputs["generator"],
+ device=torch.device(generator_device),
+ dtype=torch.float32,
+ )
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
diff --git a/tests/pipelines/qwenimage/test_qwenimage_edit.py b/tests/pipelines/qwenimage/test_qwenimage_edit.py
new file mode 100644
index 000000000000..058548cf5f1b
--- /dev/null
+++ b/tests/pipelines/qwenimage/test_qwenimage_edit.py
@@ -0,0 +1,243 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import pytest
+import torch
+from PIL import Image
+from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
+
+from diffusers import (
+ AutoencoderKLQwenImage,
+ FlowMatchEulerDiscreteScheduler,
+ QwenImageEditPipeline,
+ QwenImageTransformer2DModel,
+)
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class QwenImageEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = QwenImageEditPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = frozenset(["prompt", "image"])
+ image_params = frozenset(["image"])
+ image_latents_params = frozenset(["latents"])
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ supports_dduf = False
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ tiny_ckpt_id = "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
+
+ torch.manual_seed(0)
+ transformer = QwenImageTransformer2DModel(
+ patch_size=2,
+ in_channels=16,
+ out_channels=4,
+ num_layers=2,
+ attention_head_dim=16,
+ num_attention_heads=3,
+ joint_attention_dim=16,
+ guidance_embeds=False,
+ axes_dims_rope=(8, 4, 4),
+ )
+
+ torch.manual_seed(0)
+ z_dim = 4
+ vae = AutoencoderKLQwenImage(
+ base_dim=z_dim * 6,
+ z_dim=z_dim,
+ dim_mult=[1, 2, 4],
+ num_res_blocks=1,
+ temperal_downsample=[False, True],
+ latents_mean=[0.0] * z_dim,
+ latents_std=[1.0] * z_dim,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ torch.manual_seed(0)
+ config = Qwen2_5_VLConfig(
+ text_config={
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 2,
+ "rope_scaling": {
+ "mrope_section": [1, 1, 2],
+ "rope_type": "default",
+ "type": "default",
+ },
+ "rope_theta": 1000000.0,
+ },
+ vision_config={
+ "depth": 2,
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_heads": 2,
+ "out_hidden_size": 16,
+ },
+ hidden_size=16,
+ vocab_size=152064,
+ vision_end_token_id=151653,
+ vision_start_token_id=151652,
+ vision_token_id=151654,
+ )
+ text_encoder = Qwen2_5_VLForConditionalGeneration(config)
+ tokenizer = Qwen2Tokenizer.from_pretrained(tiny_ckpt_id)
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "processor": Qwen2VLProcessor.from_pretrained(tiny_ckpt_id),
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ inputs = {
+ "prompt": "dance monkey",
+ "image": Image.new("RGB", (32, 32)),
+ "negative_prompt": "bad quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "true_cfg_scale": 1.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+
+ # fmt: off
+ expected_slice = torch.tensor([[0.5637, 0.6341, 0.6001, 0.5620, 0.5794, 0.5498, 0.5757, 0.6389, 0.4174, 0.3597, 0.5649, 0.4894, 0.4969, 0.5255, 0.4083, 0.4986]])
+ # fmt: on
+
+ generated_slice = generated_image.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ @pytest.mark.xfail(condition=True, reason="Preconfigured embeddings need to be revisited.", strict=True)
+ def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
+ super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol)
diff --git a/tests/pipelines/qwenimage/test_qwenimage_edit_plus.py b/tests/pipelines/qwenimage/test_qwenimage_edit_plus.py
new file mode 100644
index 000000000000..6faf34728286
--- /dev/null
+++ b/tests/pipelines/qwenimage/test_qwenimage_edit_plus.py
@@ -0,0 +1,253 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import pytest
+import torch
+from PIL import Image
+from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
+
+from diffusers import (
+ AutoencoderKLQwenImage,
+ FlowMatchEulerDiscreteScheduler,
+ QwenImageEditPlusPipeline,
+ QwenImageTransformer2DModel,
+)
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class QwenImageEditPlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = QwenImageEditPlusPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = frozenset(["prompt", "image"])
+ image_params = frozenset(["image"])
+ image_latents_params = frozenset(["latents"])
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ supports_dduf = False
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ tiny_ckpt_id = "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
+
+ torch.manual_seed(0)
+ transformer = QwenImageTransformer2DModel(
+ patch_size=2,
+ in_channels=16,
+ out_channels=4,
+ num_layers=2,
+ attention_head_dim=16,
+ num_attention_heads=3,
+ joint_attention_dim=16,
+ guidance_embeds=False,
+ axes_dims_rope=(8, 4, 4),
+ )
+
+ torch.manual_seed(0)
+ z_dim = 4
+ vae = AutoencoderKLQwenImage(
+ base_dim=z_dim * 6,
+ z_dim=z_dim,
+ dim_mult=[1, 2, 4],
+ num_res_blocks=1,
+ temperal_downsample=[False, True],
+ latents_mean=[0.0] * z_dim,
+ latents_std=[1.0] * z_dim,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ torch.manual_seed(0)
+ config = Qwen2_5_VLConfig(
+ text_config={
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 2,
+ "rope_scaling": {
+ "mrope_section": [1, 1, 2],
+ "rope_type": "default",
+ "type": "default",
+ },
+ "rope_theta": 1000000.0,
+ },
+ vision_config={
+ "depth": 2,
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_heads": 2,
+ "out_hidden_size": 16,
+ },
+ hidden_size=16,
+ vocab_size=152064,
+ vision_end_token_id=151653,
+ vision_start_token_id=151652,
+ vision_token_id=151654,
+ )
+ text_encoder = Qwen2_5_VLForConditionalGeneration(config)
+ tokenizer = Qwen2Tokenizer.from_pretrained(tiny_ckpt_id)
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "processor": Qwen2VLProcessor.from_pretrained(tiny_ckpt_id),
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ image = Image.new("RGB", (32, 32))
+ inputs = {
+ "prompt": "dance monkey",
+ "image": [image, image],
+ "negative_prompt": "bad quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "true_cfg_scale": 1.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+
+ # fmt: off
+ expected_slice = torch.tensor([[0.5637, 0.6341, 0.6001, 0.5620, 0.5794, 0.5498, 0.5757, 0.6389, 0.4174, 0.3597, 0.5649, 0.4894, 0.4969, 0.5255, 0.4083, 0.4986]])
+ # fmt: on
+
+ generated_slice = generated_image.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ @pytest.mark.xfail(condition=True, reason="Preconfigured embeddings need to be revisited.", strict=True)
+ def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
+ super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol)
+
+ @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
+ def test_num_images_per_prompt():
+ super().test_num_images_per_prompt()
+
+ @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
+ def test_inference_batch_consistent():
+ super().test_inference_batch_consistent()
+
+ @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
+ def test_inference_batch_single_identical():
+ super().test_inference_batch_single_identical()
diff --git a/tests/pipelines/qwenimage/test_qwenimage_img2img.py b/tests/pipelines/qwenimage/test_qwenimage_img2img.py
new file mode 100644
index 000000000000..07e683ec7f5a
--- /dev/null
+++ b/tests/pipelines/qwenimage/test_qwenimage_img2img.py
@@ -0,0 +1,218 @@
+import random
+import unittest
+
+import numpy as np
+import torch
+from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
+
+from diffusers import (
+ AutoencoderKLQwenImage,
+ FlowMatchEulerDiscreteScheduler,
+ QwenImageImg2ImgPipeline,
+ QwenImageTransformer2DModel,
+)
+
+from ...testing_utils import (
+ enable_full_determinism,
+ floats_tensor,
+ torch_device,
+)
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class QwenImageImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+ pipeline_class = QwenImageImg2ImgPipeline
+ params = frozenset(["prompt", "image", "height", "width", "guidance_scale", "true_cfg_scale", "strength"])
+ batch_params = frozenset(["prompt", "image"])
+ image_params = frozenset(["image"])
+ image_latents_params = frozenset(["latents"])
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ supports_dduf = False
+ test_xformers_attention = False
+ test_attention_slicing = True
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = QwenImageTransformer2DModel(
+ patch_size=2,
+ in_channels=16,
+ out_channels=4,
+ num_layers=2,
+ attention_head_dim=16,
+ num_attention_heads=3,
+ joint_attention_dim=16,
+ guidance_embeds=False,
+ axes_dims_rope=(8, 4, 4),
+ )
+
+ torch.manual_seed(0)
+ z_dim = 4
+ vae = AutoencoderKLQwenImage(
+ base_dim=z_dim * 6,
+ z_dim=z_dim,
+ dim_mult=[1, 2, 4],
+ num_res_blocks=1,
+ temperal_downsample=[False, True],
+ latents_mean=[0.0] * 4,
+ latents_std=[1.0] * 4,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ torch.manual_seed(0)
+ config = Qwen2_5_VLConfig(
+ text_config={
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 2,
+ "rope_scaling": {
+ "mrope_section": [1, 1, 2],
+ "rope_type": "default",
+ "type": "default",
+ },
+ "rope_theta": 1000000.0,
+ },
+ vision_config={
+ "depth": 2,
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_heads": 2,
+ "out_hidden_size": 16,
+ },
+ hidden_size=16,
+ vocab_size=152064,
+ vision_end_token_id=151653,
+ vision_start_token_id=151652,
+ vision_token_id=151654,
+ )
+ text_encoder = Qwen2_5_VLForConditionalGeneration(config)
+ tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
+
+ return {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "image": image,
+ "prompt": "dance monkey",
+ "negative_prompt": "bad quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 3.0,
+ "true_cfg_scale": 1.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs).images[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs).images[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs).images[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
diff --git a/tests/pipelines/qwenimage/test_qwenimage_inpaint.py b/tests/pipelines/qwenimage/test_qwenimage_inpaint.py
new file mode 100644
index 000000000000..b564624540c3
--- /dev/null
+++ b/tests/pipelines/qwenimage/test_qwenimage_inpaint.py
@@ -0,0 +1,233 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import unittest
+
+import numpy as np
+import torch
+from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
+
+from diffusers import (
+ AutoencoderKLQwenImage,
+ FlowMatchEulerDiscreteScheduler,
+ QwenImageInpaintPipeline,
+ QwenImageTransformer2DModel,
+)
+
+from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class QwenImageInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = QwenImageInpaintPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ supports_dduf = False
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = QwenImageTransformer2DModel(
+ patch_size=2,
+ in_channels=16,
+ out_channels=4,
+ num_layers=2,
+ attention_head_dim=16,
+ num_attention_heads=3,
+ joint_attention_dim=16,
+ guidance_embeds=False,
+ axes_dims_rope=(8, 4, 4),
+ )
+
+ torch.manual_seed(0)
+ z_dim = 4
+ vae = AutoencoderKLQwenImage(
+ base_dim=z_dim * 6,
+ z_dim=z_dim,
+ dim_mult=[1, 2, 4],
+ num_res_blocks=1,
+ temperal_downsample=[False, True],
+ # fmt: off
+ latents_mean=[0.0] * 4,
+ latents_std=[1.0] * 4,
+ # fmt: on
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ torch.manual_seed(0)
+ config = Qwen2_5_VLConfig(
+ text_config={
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 2,
+ "rope_scaling": {
+ "mrope_section": [1, 1, 2],
+ "rope_type": "default",
+ "type": "default",
+ },
+ "rope_theta": 1000000.0,
+ },
+ vision_config={
+ "depth": 2,
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_heads": 2,
+ "out_hidden_size": 16,
+ },
+ hidden_size=16,
+ vocab_size=152064,
+ vision_end_token_id=151653,
+ vision_start_token_id=151652,
+ vision_token_id=151654,
+ )
+ text_encoder = Qwen2_5_VLForConditionalGeneration(config)
+ tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+ mask_image = torch.ones((1, 1, 32, 32)).to(device)
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "bad quality",
+ "image": image,
+ "mask_image": mask_image,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 3.0,
+ "true_cfg_scale": 1.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
diff --git a/tests/pipelines/sana/test_sana.py b/tests/pipelines/sana/test_sana.py
index aa5d5c7ce463..f23303c966e5 100644
--- a/tests/pipelines/sana/test_sana.py
+++ b/tests/pipelines/sana/test_sana.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
+# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,14 +21,15 @@
from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ IS_GITHUB_ACTIONS,
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -304,6 +305,10 @@ def test_float16_inference(self):
# Requires higher tolerance as model seems very sensitive to dtype
super().test_float16_inference(expected_max_diff=0.08)
+ @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
+ def test_layerwise_casting_inference(self):
+ super().test_layerwise_casting_inference()
+
@slow
@require_torch_accelerator
diff --git a/tests/pipelines/sana/test_sana_controlnet.py b/tests/pipelines/sana/test_sana_controlnet.py
new file mode 100644
index 000000000000..df14d935edf5
--- /dev/null
+++ b/tests/pipelines/sana/test_sana_controlnet.py
@@ -0,0 +1,329 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
+
+from diffusers import (
+ AutoencoderDC,
+ FlowMatchEulerDiscreteScheduler,
+ SanaControlNetModel,
+ SanaControlNetPipeline,
+ SanaTransformer2DModel,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, torch_device
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class SanaControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = SanaControlNetPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ controlnet = SanaControlNetModel(
+ patch_size=1,
+ in_channels=4,
+ out_channels=4,
+ num_layers=1,
+ num_attention_heads=2,
+ attention_head_dim=4,
+ num_cross_attention_heads=2,
+ cross_attention_head_dim=4,
+ cross_attention_dim=8,
+ caption_channels=8,
+ sample_size=32,
+ )
+
+ torch.manual_seed(0)
+ transformer = SanaTransformer2DModel(
+ patch_size=1,
+ in_channels=4,
+ out_channels=4,
+ num_layers=1,
+ num_attention_heads=2,
+ attention_head_dim=4,
+ num_cross_attention_heads=2,
+ cross_attention_head_dim=4,
+ cross_attention_dim=8,
+ caption_channels=8,
+ sample_size=32,
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderDC(
+ in_channels=3,
+ latent_channels=4,
+ attention_head_dim=2,
+ encoder_block_types=(
+ "ResBlock",
+ "EfficientViTBlock",
+ ),
+ decoder_block_types=(
+ "ResBlock",
+ "EfficientViTBlock",
+ ),
+ encoder_block_out_channels=(8, 8),
+ decoder_block_out_channels=(8, 8),
+ encoder_qkv_multiscales=((), (5,)),
+ decoder_qkv_multiscales=((), (5,)),
+ encoder_layers_per_block=(1, 1),
+ decoder_layers_per_block=[1, 1],
+ downsample_block_type="conv",
+ upsample_block_type="interpolate",
+ decoder_norm_types="rms_norm",
+ decoder_act_fns="silu",
+ scaling_factor=0.41407,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+
+ torch.manual_seed(0)
+ text_encoder_config = Gemma2Config(
+ head_dim=16,
+ hidden_size=8,
+ initializer_range=0.02,
+ intermediate_size=64,
+ max_position_embeddings=8192,
+ model_type="gemma2",
+ num_attention_heads=2,
+ num_hidden_layers=1,
+ num_key_value_heads=2,
+ vocab_size=8,
+ attn_implementation="eager",
+ )
+ text_encoder = Gemma2Model(text_encoder_config)
+ tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "controlnet": controlnet,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ control_image = randn_tensor((1, 3, 32, 32), generator=generator, device=device)
+ inputs = {
+ "prompt": "",
+ "negative_prompt": "",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ "complex_human_instruction": None,
+ "control_image": control_image,
+ "controlnet_conditioning_scale": 1.0,
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs)[0]
+ generated_image = image[0]
+
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+ expected_image = torch.randn(3, 32, 32)
+ max_diff = np.abs(generated_image - expected_image).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+ # Test passing in a everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ # TODO(aryan): Create a dummy gemma model with smol vocab size
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_consistent(self):
+ pass
+
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_single_identical(self):
+ pass
+
+ def test_float16_inference(self):
+ # Requires higher tolerance as model seems very sensitive to dtype
+ super().test_float16_inference(expected_max_diff=0.08)
+
+ @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
+ def test_layerwise_casting_inference(self):
+ super().test_layerwise_casting_inference()
diff --git a/tests/pipelines/sana/test_sana_sprint.py b/tests/pipelines/sana/test_sana_sprint.py
index d006c2b986ca..0d45205ea8c7 100644
--- a/tests/pipelines/sana/test_sana_sprint.py
+++ b/tests/pipelines/sana/test_sana_sprint.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
+# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,11 +20,8 @@
from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
from diffusers import AutoencoderDC, SanaSprintPipeline, SanaTransformer2DModel, SCMScheduler
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- torch_device,
-)
+from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -300,3 +297,7 @@ def test_inference_batch_single_identical(self):
def test_float16_inference(self):
# Requires higher tolerance as model seems very sensitive to dtype
super().test_float16_inference(expected_max_diff=0.08)
+
+ @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
+ def test_layerwise_casting_inference(self):
+ super().test_layerwise_casting_inference()
diff --git a/tests/pipelines/sana/test_sana_sprint_img2img.py b/tests/pipelines/sana/test_sana_sprint_img2img.py
new file mode 100644
index 000000000000..5de5c7f44606
--- /dev/null
+++ b/tests/pipelines/sana/test_sana_sprint_img2img.py
@@ -0,0 +1,315 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
+
+from diffusers import AutoencoderDC, SanaSprintImg2ImgPipeline, SanaTransformer2DModel, SCMScheduler
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, torch_device
+from ..pipeline_params import (
+ IMAGE_TO_IMAGE_IMAGE_PARAMS,
+ TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
+ TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
+)
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class SanaSprintImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = SanaSprintImg2ImgPipeline
+ params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {
+ "negative_prompt",
+ "negative_prompt_embeds",
+ }
+ batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS - {"negative_prompt"}
+ image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = SanaTransformer2DModel(
+ patch_size=1,
+ in_channels=4,
+ out_channels=4,
+ num_layers=1,
+ num_attention_heads=2,
+ attention_head_dim=4,
+ num_cross_attention_heads=2,
+ cross_attention_head_dim=4,
+ cross_attention_dim=8,
+ caption_channels=8,
+ sample_size=32,
+ qk_norm="rms_norm_across_heads",
+ guidance_embeds=True,
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderDC(
+ in_channels=3,
+ latent_channels=4,
+ attention_head_dim=2,
+ encoder_block_types=(
+ "ResBlock",
+ "EfficientViTBlock",
+ ),
+ decoder_block_types=(
+ "ResBlock",
+ "EfficientViTBlock",
+ ),
+ encoder_block_out_channels=(8, 8),
+ decoder_block_out_channels=(8, 8),
+ encoder_qkv_multiscales=((), (5,)),
+ decoder_qkv_multiscales=((), (5,)),
+ encoder_layers_per_block=(1, 1),
+ decoder_layers_per_block=[1, 1],
+ downsample_block_type="conv",
+ upsample_block_type="interpolate",
+ decoder_norm_types="rms_norm",
+ decoder_act_fns="silu",
+ scaling_factor=0.41407,
+ )
+
+ torch.manual_seed(0)
+ scheduler = SCMScheduler()
+
+ torch.manual_seed(0)
+ text_encoder_config = Gemma2Config(
+ head_dim=16,
+ hidden_size=8,
+ initializer_range=0.02,
+ intermediate_size=64,
+ max_position_embeddings=8192,
+ model_type="gemma2",
+ num_attention_heads=2,
+ num_hidden_layers=1,
+ num_key_value_heads=2,
+ vocab_size=8,
+ attn_implementation="eager",
+ )
+ text_encoder = Gemma2Model(text_encoder_config)
+ tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ image = randn_tensor((1, 3, 32, 32), generator=generator, device=device)
+ inputs = {
+ "prompt": "",
+ "image": image,
+ "strength": 0.5,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ "complex_human_instruction": None,
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs)[0]
+ generated_image = image[0]
+
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+ expected_image = torch.randn(3, 32, 32)
+ max_diff = np.abs(generated_image - expected_image).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+ # Test passing in a everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ @unittest.skip("vae tiling resulted in a small margin over the expected max diff, so skipping this test for now")
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ # TODO(aryan): Create a dummy gemma model with smol vocab size
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_consistent(self):
+ pass
+
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_single_identical(self):
+ pass
+
+ def test_float16_inference(self):
+ # Requires higher tolerance as model seems very sensitive to dtype
+ super().test_float16_inference(expected_max_diff=0.08)
+
+ @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
+ def test_layerwise_casting_inference(self):
+ super().test_layerwise_casting_inference()
diff --git a/tests/pipelines/wuerstchen/__init__.py b/tests/pipelines/sana_video/__init__.py
similarity index 100%
rename from tests/pipelines/wuerstchen/__init__.py
rename to tests/pipelines/sana_video/__init__.py
diff --git a/tests/pipelines/sana_video/test_sana_video.py b/tests/pipelines/sana_video/test_sana_video.py
new file mode 100644
index 000000000000..9f360a942a64
--- /dev/null
+++ b/tests/pipelines/sana_video/test_sana_video.py
@@ -0,0 +1,225 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
+
+from diffusers import AutoencoderKLWan, DPMSolverMultistepScheduler, SanaVideoPipeline, SanaVideoTransformer3DModel
+
+from ...testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class SanaVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = SanaVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = DPMSolverMultistepScheduler()
+
+ torch.manual_seed(0)
+ text_encoder_config = Gemma2Config(
+ head_dim=16,
+ hidden_size=8,
+ initializer_range=0.02,
+ intermediate_size=64,
+ max_position_embeddings=8192,
+ model_type="gemma2",
+ num_attention_heads=2,
+ num_hidden_layers=1,
+ num_key_value_heads=2,
+ vocab_size=8,
+ attn_implementation="eager",
+ )
+ text_encoder = Gemma2Model(text_encoder_config)
+ tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
+
+ torch.manual_seed(0)
+ transformer = SanaVideoTransformer3DModel(
+ in_channels=16,
+ out_channels=16,
+ num_attention_heads=2,
+ attention_head_dim=12,
+ num_layers=2,
+ num_cross_attention_heads=2,
+ cross_attention_head_dim=12,
+ cross_attention_dim=24,
+ caption_channels=8,
+ mlp_ratio=2.5,
+ dropout=0.0,
+ attention_bias=False,
+ sample_size=8,
+ patch_size=(1, 2, 2),
+ norm_elementwise_affine=False,
+ norm_eps=1e-6,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "",
+ "negative_prompt": "",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 32,
+ "width": 32,
+ "frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ "complex_human_instruction": [],
+ "use_resolution_binning": False,
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+ self.assertEqual(generated_video.shape, (9, 3, 32, 32))
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ def test_save_load_local(self, expected_max_difference=5e-4):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ torch.manual_seed(0)
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, safe_serialization=False)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ torch.manual_seed(0)
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+ self.assertLess(max_diff, expected_max_difference)
+
+ # TODO(aryan): Create a dummy gemma model with smol vocab size
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_consistent(self):
+ pass
+
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_single_identical(self):
+ pass
+
+ def test_float16_inference(self):
+ # Requires higher tolerance as model seems very sensitive to dtype
+ super().test_float16_inference(expected_max_diff=0.08)
+
+ def test_save_load_float16(self):
+ # Requires higher tolerance as model seems very sensitive to dtype
+ super().test_save_load_float16(expected_max_diff=0.2)
+
+
+@slow
+@require_torch_accelerator
+class SanaVideoPipelineIntegrationTests(unittest.TestCase):
+ prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest."
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ @unittest.skip("TODO: test needs to be implemented")
+ def test_sana_video_480p(self):
+ pass
diff --git a/tests/pipelines/sana_video/test_sana_video_i2v.py b/tests/pipelines/sana_video/test_sana_video_i2v.py
new file mode 100644
index 000000000000..36a646ca528f
--- /dev/null
+++ b/tests/pipelines/sana_video/test_sana_video_i2v.py
@@ -0,0 +1,238 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
+
+from diffusers import (
+ AutoencoderKLWan,
+ FlowMatchEulerDiscreteScheduler,
+ SanaImageToVideoPipeline,
+ SanaVideoTransformer3DModel,
+)
+
+from ...testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class SanaImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = SanaImageToVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ torch.manual_seed(0)
+ text_encoder_config = Gemma2Config(
+ head_dim=16,
+ hidden_size=8,
+ initializer_range=0.02,
+ intermediate_size=64,
+ max_position_embeddings=8192,
+ model_type="gemma2",
+ num_attention_heads=2,
+ num_hidden_layers=1,
+ num_key_value_heads=2,
+ vocab_size=8,
+ attn_implementation="eager",
+ )
+ text_encoder = Gemma2Model(text_encoder_config)
+ tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
+
+ torch.manual_seed(0)
+ transformer = SanaVideoTransformer3DModel(
+ in_channels=16,
+ out_channels=16,
+ num_attention_heads=2,
+ attention_head_dim=12,
+ num_layers=2,
+ num_cross_attention_heads=2,
+ cross_attention_head_dim=12,
+ cross_attention_dim=24,
+ caption_channels=8,
+ mlp_ratio=2.5,
+ dropout=0.0,
+ attention_bias=False,
+ sample_size=8,
+ patch_size=(1, 2, 2),
+ norm_elementwise_affine=False,
+ norm_eps=1e-6,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ # Create a dummy image input (PIL Image)
+ image = Image.new("RGB", (32, 32))
+
+ inputs = {
+ "image": image,
+ "prompt": "",
+ "negative_prompt": "",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 32,
+ "width": 32,
+ "frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ "complex_human_instruction": [],
+ "use_resolution_binning": False,
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+ self.assertEqual(generated_video.shape, (9, 3, 32, 32))
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ def test_save_load_local(self, expected_max_difference=5e-4):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ torch.manual_seed(0)
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, safe_serialization=False)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ torch.manual_seed(0)
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+ self.assertLess(max_diff, expected_max_difference)
+
+ # TODO(aryan): Create a dummy gemma model with smol vocab size
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_consistent(self):
+ pass
+
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_single_identical(self):
+ pass
+
+ @unittest.skip("Skipping fp16 test as model is trained with bf16")
+ def test_float16_inference(self):
+ # Requires higher tolerance as model seems very sensitive to dtype
+ super().test_float16_inference(expected_max_diff=0.08)
+
+ @unittest.skip("Skipping fp16 test as model is trained with bf16")
+ def test_save_load_float16(self):
+ # Requires higher tolerance as model seems very sensitive to dtype
+ super().test_save_load_float16(expected_max_diff=0.2)
+
+
+@slow
+@require_torch_accelerator
+class SanaVideoPipelineIntegrationTests(unittest.TestCase):
+ prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest."
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ @unittest.skip("TODO: test needs to be implemented")
+ def test_sana_video_480p(self):
+ pass
diff --git a/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py b/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py
deleted file mode 100644
index 6cd431f02d58..000000000000
--- a/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py
+++ /dev/null
@@ -1,621 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import random
-import tempfile
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
-from diffusers.pipelines.semantic_stable_diffusion import SemanticStableDiffusionPipeline as StableDiffusionPipeline
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- floats_tensor,
- nightly,
- require_accelerator,
- require_torch_gpu,
- torch_device,
-)
-
-
-enable_full_determinism()
-
-
-class SafeDiffusionPipelineFastTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- @property
- def dummy_image(self):
- batch_size = 1
- num_channels = 3
- sizes = (32, 32)
-
- image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
- return image
-
- @property
- def dummy_cond_unet(self):
- torch.manual_seed(0)
- model = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- )
- return model
-
- @property
- def dummy_vae(self):
- torch.manual_seed(0)
- model = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- return model
-
- @property
- def dummy_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModel(config)
-
- @property
- def dummy_extractor(self):
- def extract(*args, **kwargs):
- class Out:
- def __init__(self):
- self.pixel_values = torch.ones([0])
-
- def to(self, device):
- self.pixel_values.to(device)
- return self
-
- return Out()
-
- return extract
-
- def test_semantic_diffusion_ddim(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- unet = self.dummy_cond_unet
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
-
- vae = self.dummy_vae
- bert = self.dummy_text_encoder
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- # make sure here that pndm scheduler skips prk
- sd_pipe = StableDiffusionPipeline(
- unet=unet,
- scheduler=scheduler,
- vae=vae,
- text_encoder=bert,
- tokenizer=tokenizer,
- safety_checker=None,
- feature_extractor=self.dummy_extractor,
- )
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A painting of a squirrel eating a burger"
-
- generator = torch.Generator(device=device).manual_seed(0)
- output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
- image = output.images
-
- generator = torch.Generator(device=device).manual_seed(0)
- image_from_tuple = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=6.0,
- num_inference_steps=2,
- output_type="np",
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.5753, 0.6114, 0.5001, 0.5034, 0.5470, 0.4729, 0.4971, 0.4867, 0.4867])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_semantic_diffusion_pndm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- unet = self.dummy_cond_unet
- scheduler = PNDMScheduler(skip_prk_steps=True)
- vae = self.dummy_vae
- bert = self.dummy_text_encoder
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- # make sure here that pndm scheduler skips prk
- sd_pipe = StableDiffusionPipeline(
- unet=unet,
- scheduler=scheduler,
- vae=vae,
- text_encoder=bert,
- tokenizer=tokenizer,
- safety_checker=None,
- feature_extractor=self.dummy_extractor,
- )
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A painting of a squirrel eating a burger"
- generator = torch.Generator(device=device).manual_seed(0)
- output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
-
- image = output.images
-
- generator = torch.Generator(device=device).manual_seed(0)
- image_from_tuple = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=6.0,
- num_inference_steps=2,
- output_type="np",
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.5122, 0.5712, 0.4825, 0.5053, 0.5646, 0.4769, 0.5179, 0.4894, 0.4994])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_semantic_diffusion_no_safety_checker(self):
- pipe = StableDiffusionPipeline.from_pretrained(
- "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None
- )
- assert isinstance(pipe, StableDiffusionPipeline)
- assert isinstance(pipe.scheduler, LMSDiscreteScheduler)
- assert pipe.safety_checker is None
-
- image = pipe("example prompt", num_inference_steps=2).images[0]
- assert image is not None
-
- # check that there's no error when saving a pipeline with one of the models being None
- with tempfile.TemporaryDirectory() as tmpdirname:
- pipe.save_pretrained(tmpdirname)
- pipe = StableDiffusionPipeline.from_pretrained(tmpdirname)
-
- # sanity check that the pipeline still works
- assert pipe.safety_checker is None
- image = pipe("example prompt", num_inference_steps=2).images[0]
- assert image is not None
-
- @require_accelerator
- def test_semantic_diffusion_fp16(self):
- """Test that stable diffusion works with fp16"""
- unet = self.dummy_cond_unet
- scheduler = PNDMScheduler(skip_prk_steps=True)
- vae = self.dummy_vae
- bert = self.dummy_text_encoder
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- # put models in fp16
- unet = unet.half()
- vae = vae.half()
- bert = bert.half()
-
- # make sure here that pndm scheduler skips prk
- sd_pipe = StableDiffusionPipeline(
- unet=unet,
- scheduler=scheduler,
- vae=vae,
- text_encoder=bert,
- tokenizer=tokenizer,
- safety_checker=None,
- feature_extractor=self.dummy_extractor,
- )
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A painting of a squirrel eating a burger"
- image = sd_pipe([prompt], num_inference_steps=2, output_type="np").images
-
- assert image.shape == (1, 64, 64, 3)
-
-
-@nightly
-@require_torch_gpu
-class SemanticDiffusionPipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- def test_positive_guidance(self):
- torch_device = "cuda"
- pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- prompt = "a photo of a cat"
- edit = {
- "editing_prompt": ["sunglasses"],
- "reverse_editing_direction": [False],
- "edit_warmup_steps": 10,
- "edit_guidance_scale": 6,
- "edit_threshold": 0.95,
- "edit_momentum_scale": 0.5,
- "edit_mom_beta": 0.6,
- }
-
- seed = 3
- guidance_scale = 7
-
- # no sega enabled
- generator = torch.Generator(torch_device)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.34673113,
- 0.38492733,
- 0.37597352,
- 0.34086335,
- 0.35650748,
- 0.35579205,
- 0.3384763,
- 0.34340236,
- 0.3573271,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- # with sega enabled
- # generator = torch.manual_seed(seed)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- **edit,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.41887826,
- 0.37728766,
- 0.30138272,
- 0.41416335,
- 0.41664985,
- 0.36283392,
- 0.36191246,
- 0.43364465,
- 0.43001732,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_negative_guidance(self):
- torch_device = "cuda"
- pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- prompt = "an image of a crowded boulevard, realistic, 4k"
- edit = {
- "editing_prompt": "crowd, crowded, people",
- "reverse_editing_direction": True,
- "edit_warmup_steps": 10,
- "edit_guidance_scale": 8.3,
- "edit_threshold": 0.9,
- "edit_momentum_scale": 0.5,
- "edit_mom_beta": 0.6,
- }
-
- seed = 9
- guidance_scale = 7
-
- # no sega enabled
- generator = torch.Generator(torch_device)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.43497998,
- 0.91814065,
- 0.7540739,
- 0.55580205,
- 0.8467265,
- 0.5389691,
- 0.62574506,
- 0.58897763,
- 0.50926757,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- # with sega enabled
- # generator = torch.manual_seed(seed)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- **edit,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.3089719,
- 0.30500144,
- 0.29016042,
- 0.30630964,
- 0.325687,
- 0.29419225,
- 0.2908091,
- 0.28723598,
- 0.27696294,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_multi_cond_guidance(self):
- torch_device = "cuda"
- pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- prompt = "a castle next to a river"
- edit = {
- "editing_prompt": ["boat on a river, boat", "monet, impression, sunrise"],
- "reverse_editing_direction": False,
- "edit_warmup_steps": [15, 18],
- "edit_guidance_scale": 6,
- "edit_threshold": [0.9, 0.8],
- "edit_momentum_scale": 0.5,
- "edit_mom_beta": 0.6,
- }
-
- seed = 48
- guidance_scale = 7
-
- # no sega enabled
- generator = torch.Generator(torch_device)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.75163555,
- 0.76037145,
- 0.61785,
- 0.9189673,
- 0.8627701,
- 0.85189694,
- 0.8512813,
- 0.87012076,
- 0.8312857,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- # with sega enabled
- # generator = torch.manual_seed(seed)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- **edit,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.73553365,
- 0.7537271,
- 0.74341905,
- 0.66480356,
- 0.6472925,
- 0.63039416,
- 0.64812905,
- 0.6749717,
- 0.6517102,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_guidance_fp16(self):
- torch_device = "cuda"
- pipe = StableDiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
- )
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- prompt = "a photo of a cat"
- edit = {
- "editing_prompt": ["sunglasses"],
- "reverse_editing_direction": [False],
- "edit_warmup_steps": 10,
- "edit_guidance_scale": 6,
- "edit_threshold": 0.95,
- "edit_momentum_scale": 0.5,
- "edit_mom_beta": 0.6,
- }
-
- seed = 3
- guidance_scale = 7
-
- # no sega enabled
- generator = torch.Generator(torch_device)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.34887695,
- 0.3876953,
- 0.375,
- 0.34423828,
- 0.3581543,
- 0.35717773,
- 0.3383789,
- 0.34570312,
- 0.359375,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- # with sega enabled
- # generator = torch.manual_seed(seed)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- **edit,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.42285156,
- 0.36914062,
- 0.29077148,
- 0.42041016,
- 0.41918945,
- 0.35498047,
- 0.3618164,
- 0.4423828,
- 0.43115234,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/shap_e/test_shap_e.py b/tests/pipelines/shap_e/test_shap_e.py
index 6cf643fe47a6..99fd28692981 100644
--- a/tests/pipelines/shap_e/test_shap_e.py
+++ b/tests/pipelines/shap_e/test_shap_e.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,8 +21,14 @@
from diffusers import HeunDiscreteScheduler, PriorTransformer, ShapEPipeline
from diffusers.pipelines.shap_e import ShapERenderer
-from diffusers.utils.testing_utils import load_numpy, nightly, require_torch_gpu, torch_device
+from ...testing_utils import (
+ backend_empty_cache,
+ load_numpy,
+ nightly,
+ require_torch_accelerator,
+ torch_device,
+)
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
@@ -222,19 +228,19 @@ def test_sequential_cpu_offload_forward_pass(self):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
class ShapEPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_shap_e(self):
expected_image = load_numpy(
diff --git a/tests/pipelines/shap_e/test_shap_e_img2img.py b/tests/pipelines/shap_e/test_shap_e_img2img.py
index ac7096874b31..b1867db249ea 100644
--- a/tests/pipelines/shap_e/test_shap_e_img2img.py
+++ b/tests/pipelines/shap_e/test_shap_e_img2img.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,15 +22,16 @@
from diffusers import HeunDiscreteScheduler, PriorTransformer, ShapEImg2ImgPipeline
from diffusers.pipelines.shap_e import ShapERenderer
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ backend_empty_cache,
floats_tensor,
load_image,
load_numpy,
nightly,
- require_torch_gpu,
+ require_torch_accelerator,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
@@ -250,23 +251,23 @@ def test_sequential_cpu_offload_forward_pass(self):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
class ShapEImg2ImgPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_shap_e_img2img(self):
input_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/shap_e/corgi.png"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/shap_e/corgi.png"
)
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
diff --git a/tests/pipelines/skyreels_v2/__init__.py b/tests/pipelines/skyreels_v2/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2.py b/tests/pipelines/skyreels_v2/test_skyreels_v2.py
new file mode 100644
index 000000000000..1bcec877c30d
--- /dev/null
+++ b/tests/pipelines/skyreels_v2/test_skyreels_v2.py
@@ -0,0 +1,137 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKLWan,
+ SkyReelsV2Pipeline,
+ SkyReelsV2Transformer3DModel,
+ UniPCMultistepScheduler,
+)
+
+from ...testing_utils import (
+ enable_full_determinism,
+)
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class SkyReelsV2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = SkyReelsV2Pipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = UniPCMultistepScheduler(flow_shift=8.0, use_flow_sigmas=True)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = SkyReelsV2Transformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 16,
+ "width": 16,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+ expected_video = torch.randn(9, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py
new file mode 100644
index 000000000000..74235d59efd6
--- /dev/null
+++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py
@@ -0,0 +1,137 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKLWan,
+ SkyReelsV2DiffusionForcingPipeline,
+ SkyReelsV2Transformer3DModel,
+ UniPCMultistepScheduler,
+)
+
+from ...testing_utils import (
+ enable_full_determinism,
+)
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class SkyReelsV2DiffusionForcingPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = SkyReelsV2DiffusionForcingPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = UniPCMultistepScheduler(flow_shift=8.0, use_flow_sigmas=True)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = SkyReelsV2Transformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 16,
+ "width": 16,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+ expected_video = torch.randn(9, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py
new file mode 100644
index 000000000000..f0cbc710df05
--- /dev/null
+++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py
@@ -0,0 +1,215 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import (
+ AutoTokenizer,
+ T5EncoderModel,
+)
+
+from diffusers import (
+ AutoencoderKLWan,
+ SkyReelsV2DiffusionForcingImageToVideoPipeline,
+ SkyReelsV2Transformer3DModel,
+ UniPCMultistepScheduler,
+)
+
+from ...testing_utils import enable_full_determinism
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class SkyReelsV2DiffusionForcingImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = SkyReelsV2DiffusionForcingImageToVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = UniPCMultistepScheduler(flow_shift=5.0, use_flow_sigmas=True)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = SkyReelsV2Transformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ image_dim=4,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ image_height = 16
+ image_width = 16
+ image = Image.new("RGB", (image_width, image_height))
+ inputs = {
+ "image": image,
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "height": image_height,
+ "width": image_width,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+ expected_video = torch.randn(9, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
+ def test_inference_batch_single_identical(self):
+ pass
+
+
+class SkyReelsV2DiffusionForcingImageToVideoPipelineFastTests(SkyReelsV2DiffusionForcingImageToVideoPipelineFastTests):
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = UniPCMultistepScheduler(flow_shift=5.0, use_flow_sigmas=True)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = SkyReelsV2Transformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ image_dim=4,
+ pos_embed_seq_len=2 * (4 * 4 + 1),
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ image_height = 16
+ image_width = 16
+ image = Image.new("RGB", (image_width, image_height))
+ last_image = Image.new("RGB", (image_width, image_height))
+ inputs = {
+ "image": image,
+ "last_image": last_image,
+ "prompt": "dance monkey",
+ "negative_prompt": "negative",
+ "height": image_height,
+ "width": image_width,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py
new file mode 100644
index 000000000000..1b0b23318e63
--- /dev/null
+++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py
@@ -0,0 +1,201 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKLWan,
+ SkyReelsV2DiffusionForcingVideoToVideoPipeline,
+ SkyReelsV2Transformer3DModel,
+ UniPCMultistepScheduler,
+)
+
+from ...testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+from ..pipeline_params import TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class SkyReelsV2DiffusionForcingVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = SkyReelsV2DiffusionForcingVideoToVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = frozenset(["video", "prompt", "negative_prompt"])
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = UniPCMultistepScheduler(flow_shift=5.0, use_flow_sigmas=True)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = SkyReelsV2Transformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ video = [Image.new("RGB", (16, 16))] * 7
+ inputs = {
+ "video": video,
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "generator": generator,
+ "num_inference_steps": 4,
+ "guidance_scale": 6.0,
+ "height": 16,
+ "width": 16,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ "overlap_history": 3,
+ "num_frames": 17,
+ "base_num_frames": 5,
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ total_frames = len(inputs["video"]) + inputs["num_frames"]
+ expected_shape = (total_frames, 3, 16, 16)
+ self.assertEqual(generated_video.shape, expected_shape)
+ expected_video = torch.randn(*expected_shape)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_callback_cfg(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ if "guidance_scale" not in sig.parameters:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ # Track the number of callback calls for diffusion forcing pipelines
+ callback_call_count = [0] # Use list to make it mutable in closure
+
+ def callback_increase_guidance(pipe, i, t, callback_kwargs):
+ pipe._guidance_scale += 1.0
+ callback_call_count[0] += 1
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # use cfg guidance because some pipelines modify the shape of the latents
+ # outside of the denoising loop
+ inputs["guidance_scale"] = 2.0
+ inputs["callback_on_step_end"] = callback_increase_guidance
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ _ = pipe(**inputs)[0]
+
+ # For diffusion forcing pipelines, use the actual callback count
+ # since they run multiple iterations with nested denoising loops
+ expected_guidance_scale = inputs["guidance_scale"] + callback_call_count[0]
+
+ assert pipe.guidance_scale == expected_guidance_scale
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip(
+ "SkyReelsV2DiffusionForcingVideoToVideoPipeline has to run in mixed precision. Casting the entire pipeline will result in errors"
+ )
+ def test_float16_inference(self):
+ pass
+
+ @unittest.skip(
+ "SkyReelsV2DiffusionForcingVideoToVideoPipeline has to run in mixed precision. Save/Load the entire pipeline in FP16 will result in errors"
+ )
+ def test_save_load_float16(self):
+ pass
diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_image_to_video.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_image_to_video.py
new file mode 100644
index 000000000000..784f701a29d2
--- /dev/null
+++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_image_to_video.py
@@ -0,0 +1,220 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import (
+ AutoTokenizer,
+ CLIPImageProcessor,
+ CLIPVisionConfig,
+ CLIPVisionModelWithProjection,
+ T5EncoderModel,
+)
+
+from diffusers import (
+ AutoencoderKLWan,
+ SkyReelsV2ImageToVideoPipeline,
+ SkyReelsV2Transformer3DModel,
+ UniPCMultistepScheduler,
+)
+
+from ...testing_utils import enable_full_determinism
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class SkyReelsV2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = SkyReelsV2ImageToVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = UniPCMultistepScheduler(flow_shift=5.0, use_flow_sigmas=True)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = SkyReelsV2Transformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=36,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ image_dim=4,
+ )
+
+ torch.manual_seed(0)
+ image_encoder_config = CLIPVisionConfig(
+ hidden_size=4,
+ projection_dim=4,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ image_size=32,
+ intermediate_size=16,
+ patch_size=1,
+ )
+ image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
+
+ torch.manual_seed(0)
+ image_processor = CLIPImageProcessor(crop_size=32, size=32)
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "image_encoder": image_encoder,
+ "image_processor": image_processor,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ image_height = 16
+ image_width = 16
+ image = Image.new("RGB", (image_width, image_height))
+ inputs = {
+ "image": image,
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "height": image_height,
+ "width": image_width,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+ expected_video = torch.randn(9, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_inference_with_last_image(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ torch.manual_seed(0)
+ components["transformer"] = SkyReelsV2Transformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=36,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ pos_embed_seq_len=2 * (4 * 4 + 1),
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ image_dim=4,
+ )
+ torch.manual_seed(0)
+ image_encoder_config = CLIPVisionConfig(
+ hidden_size=4,
+ projection_dim=4,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ image_size=4,
+ intermediate_size=16,
+ patch_size=1,
+ )
+ components["image_encoder"] = CLIPVisionModelWithProjection(image_encoder_config)
+
+ torch.manual_seed(0)
+ components["image_processor"] = CLIPImageProcessor(crop_size=4, size=4)
+
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image_height = 16
+ image_width = 16
+ last_image = Image.new("RGB", (image_width, image_height))
+ inputs["last_image"] = last_image
+
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+ expected_video = torch.randn(9, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
+ def test_inference_batch_single_identical(self):
+ pass
diff --git a/tests/pipelines/stable_audio/test_stable_audio.py b/tests/pipelines/stable_audio/test_stable_audio.py
index 01df82056ce2..dd03f4d07f07 100644
--- a/tests/pipelines/stable_audio/test_stable_audio.py
+++ b/tests/pipelines/stable_audio/test_stable_audio.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -32,8 +32,15 @@
StableAudioProjectionModel,
)
from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device
+from ...testing_utils import (
+ Expectations,
+ backend_empty_cache,
+ enable_full_determinism,
+ nightly,
+ require_torch_accelerator,
+ torch_device,
+)
from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@@ -419,17 +426,17 @@ def test_encode_prompt_works_in_isolation(self):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
class StableAudioPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -459,9 +466,15 @@ def test_stable_audio(self):
# check the portion of the generated audio with the largest dynamic range (reduces flakiness)
audio_slice = audio[0, 447590:447600]
# fmt: off
- expected_slice = np.array(
- [-0.0278, 0.1096, 0.1877, 0.3178, 0.5329, 0.6990, 0.6972, 0.6186, 0.5608, 0.5060]
+ expected_slices = Expectations(
+ {
+ ("xpu", 3): np.array([-0.0285, 0.1083, 0.1863, 0.3165, 0.5312, 0.6971, 0.6958, 0.6177, 0.5598, 0.5048]),
+ ("cuda", 7): np.array([-0.0278, 0.1096, 0.1877, 0.3178, 0.5329, 0.6990, 0.6972, 0.6186, 0.5608, 0.5060]),
+ ("cuda", 8): np.array([-0.0285, 0.1082, 0.1862, 0.3163, 0.5306, 0.6964, 0.6953, 0.6172, 0.5593, 0.5044]),
+ }
)
- # fmt: one
+ # fmt: on
+
+ expected_slice = expected_slices.get_expectation()
max_diff = np.abs(expected_slice - audio_slice.detach().cpu().numpy()).max()
assert max_diff < 1.5e-3
diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_combined.py b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py
index 1765f3a02242..afa0db39f3fa 100644
--- a/tests/pipelines/stable_cascade/test_stable_cascade_combined.py
+++ b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,8 +22,8 @@
from diffusers import DDPMWuerstchenScheduler, StableCascadeCombinedPipeline
from diffusers.models import StableCascadeUNet
from diffusers.pipelines.wuerstchen import PaellaVQModel
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_accelerator, torch_device
+from ...testing_utils import enable_full_determinism, require_torch_accelerator, torch_device
from ..test_pipelines_common import PipelineTesterMixin
@@ -198,12 +198,12 @@ def test_stable_cascade(self):
assert image.shape == (1, 128, 128, 3)
expected_slice = np.array([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0])
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ )
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
+ f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+ )
@require_torch_accelerator
def test_offloads(self):
diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py b/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py
index afcd8fca71ca..5b3acb8705b3 100644
--- a/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py
+++ b/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -23,7 +23,9 @@
from diffusers import DDPMWuerstchenScheduler, StableCascadeDecoderPipeline
from diffusers.models import StableCascadeUNet
from diffusers.pipelines.wuerstchen import PaellaVQModel
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
load_numpy,
@@ -34,8 +36,6 @@
slow,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..test_pipelines_common import PipelineTesterMixin
@@ -304,7 +304,8 @@ def test_stable_cascade_decoder(self):
generator = torch.Generator(device="cpu").manual_seed(0)
image_embedding = load_pt(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/image_embedding.pt"
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/image_embedding.pt",
+ map_location=torch_device,
)
image = pipe(
@@ -320,4 +321,4 @@ def test_stable_cascade_decoder(self):
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/stable_cascade_decoder_image.npy"
)
max_diff = numpy_cosine_similarity_distance(image.flatten(), expected_image.flatten())
- assert max_diff < 1e-4
+ assert max_diff < 2e-4
diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_prior.py b/tests/pipelines/stable_cascade/test_stable_cascade_prior.py
index 0374de9b0219..0bc821b7e64f 100644
--- a/tests/pipelines/stable_cascade/test_stable_cascade_prior.py
+++ b/tests/pipelines/stable_cascade/test_stable_cascade_prior.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,13 +17,16 @@
import unittest
import numpy as np
+import pytest
import torch
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import DDPMWuerstchenScheduler, StableCascadePriorPipeline
from diffusers.models import StableCascadeUNet
+from diffusers.utils import is_transformers_version
from diffusers.utils.import_utils import is_peft_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
load_numpy,
@@ -153,6 +156,11 @@ def get_dummy_inputs(self, device, seed=0):
}
return inputs
+ @pytest.mark.xfail(
+ condition=is_transformers_version(">=", "4.57.1"),
+ reason="Test fails with the latest transformers version",
+ strict=False,
+ )
def test_wuerstchen_prior(self):
device = "cpu"
diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py
index f7036dee47f0..62414f3f1947 100644
--- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -27,8 +27,8 @@
OnnxStableDiffusionPipeline,
PNDMScheduler,
)
-from diffusers.utils.testing_utils import is_onnx_available, nightly, require_onnxruntime, require_torch_gpu
+from ...testing_utils import is_onnx_available, nightly, require_onnxruntime, require_torch_gpu
from ..test_pipelines_onnx_common import OnnxPipelineTesterMixin
diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py
index c73ed0f6afe8..28d1d0f37ff8 100644
--- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py
+++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -26,7 +26,8 @@
OnnxStableDiffusionImg2ImgPipeline,
PNDMScheduler,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
floats_tensor,
is_onnx_available,
load_image,
@@ -34,7 +35,6 @@
require_onnxruntime,
require_torch_gpu,
)
-
from ..test_pipelines_onnx_common import OnnxPipelineTesterMixin
diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py
index 09048b5c0e0f..1d46ff9a2f5f 100644
--- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,14 +18,14 @@
import numpy as np
from diffusers import LMSDiscreteScheduler, OnnxStableDiffusionInpaintPipeline
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
is_onnx_available,
load_image,
nightly,
require_onnxruntime,
require_torch_gpu,
)
-
from ..test_pipelines_onnx_common import OnnxPipelineTesterMixin
diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_upscale.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_upscale.py
index 2df64ad1d685..55d9d38d64bd 100644
--- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_upscale.py
+++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_upscale.py
@@ -26,7 +26,8 @@
OnnxStableDiffusionUpscalePipeline,
PNDMScheduler,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
floats_tensor,
is_onnx_available,
load_image,
@@ -34,7 +35,6 @@
require_onnxruntime,
require_torch_gpu,
)
-
from ..test_pipelines_onnx_common import OnnxPipelineTesterMixin
@@ -42,6 +42,10 @@
import onnxruntime as ort
+# TODO: (Dhruv) Update hub_checkpoint repo_id
+@unittest.skip(
+ "There is a potential backdoor vulnerability in the hub_checkpoint. Skip running this test until resolved"
+)
class OnnxStableDiffusionUpscalePipelineFastTests(OnnxPipelineTesterMixin, unittest.TestCase):
# TODO: is there an appropriate internal test set?
hub_checkpoint = "ssube/stable-diffusion-x4-upscaler-onnx"
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index 6e17b86639ea..c9d9525b2e45 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@
import gc
import tempfile
import time
-import traceback
import unittest
import numpy as np
@@ -42,28 +41,24 @@
UNet2DConditionModel,
logging,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
CaptureLogger,
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
backend_reset_peak_memory_stats,
enable_full_determinism,
- is_torch_compile,
- load_image,
load_numpy,
nightly,
numpy_cosine_similarity_distance,
require_accelerate_version_greater,
- require_torch_2,
require_torch_accelerator,
require_torch_multi_accelerator,
- run_test_in_subprocess,
skip_mps,
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
@@ -81,39 +76,6 @@
enable_full_determinism()
-# Will be run via run_test_in_subprocess
-def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
- error = None
- try:
- inputs = in_queue.get(timeout=timeout)
- torch_device = inputs.pop("torch_device")
- seed = inputs.pop("seed")
- inputs["generator"] = torch.Generator(device=torch_device).manual_seed(seed)
-
- sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
- sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(torch_device)
-
- sd_pipe.unet.to(memory_format=torch.channels_last)
- sd_pipe.unet = torch.compile(sd_pipe.unet, mode="reduce-overhead", fullgraph=True)
-
- sd_pipe.set_progress_bar_config(disable=None)
-
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1].flatten()
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.38019, 0.28647, 0.27321, 0.40377, 0.38290, 0.35446, 0.39218, 0.38165, 0.42239])
-
- assert np.abs(image_slice - expected_slice).max() < 5e-3
- except Exception:
- error = f"{traceback.format_exc()}"
-
- results = {"error": error}
- out_queue.put(results, timeout=timeout)
- out_queue.join()
-
-
class StableDiffusionPipelineFastTests(
IPAdapterTesterMixin,
PipelineLatentTesterMixin,
@@ -293,15 +255,15 @@ def test_stable_diffusion_ays(self):
inputs["sigmas"] = sigma_schedule
output_sigmas = sd_pipe(**inputs).images
- assert (
- np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3
- ), "ays timesteps and ays sigmas should have the same outputs"
- assert (
- np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3
- ), "use ays timesteps should have different outputs"
- assert (
- np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3
- ), "use ays sigmas should have different outputs"
+ assert np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3, (
+ "ays timesteps and ays sigmas should have the same outputs"
+ )
+ assert np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3, (
+ "use ays timesteps should have different outputs"
+ )
+ assert np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3, (
+ "use ays sigmas should have different outputs"
+ )
def test_stable_diffusion_prompt_embeds(self):
components = self.get_dummy_components()
@@ -656,9 +618,9 @@ def test_freeu_enabled(self):
sd_pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
output_freeu = sd_pipe(prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)).images
- assert not np.allclose(
- output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]
- ), "Enabling of FreeU should lead to different results."
+ assert not np.allclose(output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]), (
+ "Enabling of FreeU should lead to different results."
+ )
def test_freeu_disabled(self):
components = self.get_dummy_components()
@@ -681,9 +643,9 @@ def test_freeu_disabled(self):
prompt, num_inference_steps=1, output_type="np", generator=torch.manual_seed(0)
).images
- assert np.allclose(
- output[0, -3:, -3:, -1], output_no_freeu[0, -3:, -3:, -1]
- ), "Disabling of FreeU should lead to results similar to the default pipeline results."
+ assert np.allclose(output[0, -3:, -3:, -1], output_no_freeu[0, -3:, -3:, -1]), (
+ "Disabling of FreeU should lead to results similar to the default pipeline results."
+ )
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -706,15 +668,15 @@ def test_fused_qkv_projections(self):
image = sd_pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
def test_pipeline_interrupt(self):
components = self.get_dummy_components()
@@ -1224,40 +1186,6 @@ def test_stable_diffusion_textual_inversion_with_sequential_cpu_offload(self):
max_diff = np.abs(expected_image - image).max()
assert max_diff < 8e-1
- @is_torch_compile
- @require_torch_2
- def test_stable_diffusion_compile(self):
- seed = 0
- inputs = self.get_inputs(torch_device, seed=seed)
- # Can't pickle a Generator object
- del inputs["generator"]
- inputs["torch_device"] = torch_device
- inputs["seed"] = seed
- run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=inputs)
-
- def test_stable_diffusion_lcm(self):
- unet = UNet2DConditionModel.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", subfolder="unet")
- sd_pipe = StableDiffusionPipeline.from_pretrained("Lykon/dreamshaper-7", unet=unet).to(torch_device)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- inputs["num_inference_steps"] = 6
- inputs["output_type"] = "pil"
-
- image = sd_pipe(**inputs).images[0]
-
- expected_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/lcm_full/stable_diffusion_lcm.png"
- )
-
- image = sd_pipe.image_processor.pil_to_numpy(image)
- expected_image = sd_pipe.image_processor.pil_to_numpy(expected_image)
-
- max_diff = numpy_cosine_similarity_distance(image.flatten(), expected_image.flatten())
-
- assert max_diff < 1e-2
-
@slow
@require_torch_accelerator
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
index 82b01a74869a..a0b7268b9dd4 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +15,6 @@
import gc
import random
-import traceback
import unittest
import numpy as np
@@ -34,25 +33,22 @@
StableDiffusionImg2ImgPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
backend_reset_peak_memory_stats,
enable_full_determinism,
floats_tensor,
- is_torch_compile,
load_image,
load_numpy,
nightly,
- require_torch_2,
require_torch_accelerator,
- run_test_in_subprocess,
skip_mps,
slow,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
@@ -70,38 +66,6 @@
enable_full_determinism()
-# Will be run via run_test_in_subprocess
-def _test_img2img_compile(in_queue, out_queue, timeout):
- error = None
- try:
- inputs = in_queue.get(timeout=timeout)
- torch_device = inputs.pop("torch_device")
- seed = inputs.pop("seed")
- inputs["generator"] = torch.Generator(device=torch_device).manual_seed(seed)
-
- pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
- pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
- pipe.unet.set_default_attn_processor()
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.unet.to(memory_format=torch.channels_last)
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-
- image = pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1].flatten()
-
- assert image.shape == (1, 512, 768, 3)
- expected_slice = np.array([0.0606, 0.0570, 0.0805, 0.0579, 0.0628, 0.0623, 0.0843, 0.1115, 0.0806])
-
- assert np.abs(expected_slice - image_slice).max() < 1e-3
- except Exception:
- error = f"{traceback.format_exc()}"
-
- results = {"error": error}
- out_queue.put(results, timeout=timeout)
- out_queue.join()
-
-
class StableDiffusionImg2ImgPipelineFastTests(
IPAdapterTesterMixin,
PipelineLatentTesterMixin,
@@ -654,17 +618,6 @@ def test_img2img_safety_checker_works(self):
assert out.nsfw_content_detected[0], f"Safety checker should work for prompt: {inputs['prompt']}"
assert np.abs(out.images[0]).sum() < 1e-5 # should be all zeros
- @is_torch_compile
- @require_torch_2
- def test_img2img_compile(self):
- seed = 0
- inputs = self.get_inputs(torch_device, seed=seed)
- # Can't pickle a Generator object
- del inputs["generator"]
- inputs["torch_device"] = torch_device
- inputs["seed"] = seed
- run_test_in_subprocess(test_case=self, target_func=_test_img2img_compile, inputs=inputs)
-
@nightly
@require_torch_accelerator
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
index e21cf23b8cbf..259806a9479c 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,7 +15,6 @@
import gc
import random
-import traceback
import unittest
import numpy as np
@@ -36,24 +35,22 @@
StableDiffusionInpaintPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ Expectations,
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
backend_reset_peak_memory_stats,
enable_full_determinism,
floats_tensor,
- is_torch_compile,
load_image,
load_numpy,
nightly,
- require_torch_2,
require_torch_accelerator,
- run_test_in_subprocess,
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
@@ -70,40 +67,6 @@
enable_full_determinism()
-# Will be run via run_test_in_subprocess
-def _test_inpaint_compile(in_queue, out_queue, timeout):
- error = None
- try:
- inputs = in_queue.get(timeout=timeout)
- torch_device = inputs.pop("torch_device")
- seed = inputs.pop("seed")
- inputs["generator"] = torch.Generator(device=torch_device).manual_seed(seed)
-
- pipe = StableDiffusionInpaintPipeline.from_pretrained(
- "botp/stable-diffusion-v1-5-inpainting", safety_checker=None
- )
- pipe.unet.set_default_attn_processor()
- pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- pipe.unet.to(memory_format=torch.channels_last)
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-
- image = pipe(**inputs).images
- image_slice = image[0, 253:256, 253:256, -1].flatten()
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.0689, 0.0699, 0.0790, 0.0536, 0.0470, 0.0488, 0.041, 0.0508, 0.04179])
- assert np.abs(expected_slice - image_slice).max() < 3e-3
- except Exception:
- error = f"{traceback.format_exc()}"
-
- results = {"error": error}
- out_queue.put(results, timeout=timeout)
- out_queue.join()
-
-
class StableDiffusionInpaintPipelineFastTests(
IPAdapterTesterMixin,
PipelineLatentTesterMixin,
@@ -726,17 +689,6 @@ def test_stable_diffusion_inpaint_with_sequential_cpu_offloading(self):
# make sure that less than 2.2 GB is allocated
assert mem_bytes < 2.2 * 10**9
- @is_torch_compile
- @require_torch_2
- def test_inpaint_compile(self):
- seed = 0
- inputs = self.get_inputs(torch_device, seed=seed)
- # Can't pickle a Generator object
- del inputs["generator"]
- inputs["torch_device"] = torch_device
- inputs["seed"] = seed
- run_test_in_subprocess(test_case=self, target_func=_test_inpaint_compile, inputs=inputs)
-
def test_stable_diffusion_inpaint_pil_input_resolution_test(self):
pipe = StableDiffusionInpaintPipeline.from_pretrained(
"botp/stable-diffusion-v1-5-inpainting", safety_checker=None
@@ -866,7 +818,37 @@ def test_stable_diffusion_inpaint_fp16(self):
image_slice = image[0, 253:256, 253:256, -1].flatten()
assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.1343, 0.1406, 0.1440, 0.1504, 0.1729, 0.0989, 0.1807, 0.2822, 0.1179])
+ expected_slices = Expectations(
+ {
+ ("xpu", 3): np.array(
+ [
+ 0.2063,
+ 0.1731,
+ 0.1553,
+ 0.1741,
+ 0.1772,
+ 0.1077,
+ 0.2109,
+ 0.2407,
+ 0.1243,
+ ]
+ ),
+ ("cuda", 7): np.array(
+ [
+ 0.1343,
+ 0.1406,
+ 0.1440,
+ 0.1504,
+ 0.1729,
+ 0.0989,
+ 0.1807,
+ 0.2822,
+ 0.1179,
+ ]
+ ),
+ }
+ )
+ expected_slice = expected_slices.get_expectation()
assert np.abs(expected_slice - image_slice).max() < 5e-2
@@ -933,11 +915,6 @@ def test_stable_diffusion_inpaint_with_sequential_cpu_offloading(self):
# make sure that less than 2.45 GB is allocated
assert mem_bytes < 2.45 * 10**9
- @is_torch_compile
- @require_torch_2
- def test_inpaint_compile(self):
- pass
-
def test_stable_diffusion_inpaint_pil_input_resolution_test(self):
vae = AsymmetricAutoencoderKL.from_pretrained(
"cross-attention/asymmetric-autoencoder-kl-x-1-5",
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py
index 9721bb02ee3e..4758c5dab44b 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -32,7 +32,8 @@
UNet2DConditionModel,
)
from diffusers.image_processor import VaeImageProcessor
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
@@ -44,7 +45,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
index 3f9f7e965b40..3b2552b432d3 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,7 +31,8 @@
UNet2DConditionModel,
logging,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
CaptureLogger,
backend_empty_cache,
backend_max_memory_allocated,
@@ -45,7 +46,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py
deleted file mode 100644
index c66491b15c66..000000000000
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py
+++ /dev/null
@@ -1,266 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- StableDiffusionAttendAndExcitePipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.testing_utils import (
- load_numpy,
- nightly,
- numpy_cosine_similarity_distance,
- require_torch_accelerator,
- skip_mps,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import (
- PipelineFromPipeTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
-)
-
-
-torch.backends.cuda.matmul.allow_tf32 = False
-
-
-@skip_mps
-class StableDiffusionAttendAndExcitePipelineFastTests(
- PipelineLatentTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineTesterMixin,
- PipelineFromPipeTesterMixin,
- unittest.TestCase,
-):
- pipeline_class = StableDiffusionAttendAndExcitePipeline
- test_attention_slicing = False
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"token_indices"})
- image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- # Attend and excite requires being able to run a backward pass at
- # inference time. There's no deterministic backward operator for pad
-
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
- torch.use_deterministic_algorithms(False)
-
- @classmethod
- def tearDownClass(cls):
- super().tearDownClass()
- torch.use_deterministic_algorithms(True)
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=1,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- # SD2-specific config below
- attention_head_dim=(2, 4),
- use_linear_projection=True,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=128,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- # SD2-specific config below
- hidden_act="gelu",
- projection_dim=512,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- }
-
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "a cat and a frog",
- "token_indices": [2, 5],
- "generator": generator,
- "num_inference_steps": 1,
- "guidance_scale": 6.0,
- "output_type": "np",
- "max_iter_to_alter": 2,
- "thresholds": {0: 0.7},
- }
- return inputs
-
- def test_dict_tuple_outputs_equivalent(self):
- expected_slice = None
- if torch_device == "cpu":
- expected_slice = np.array([0.6391, 0.6290, 0.4860, 0.5134, 0.5550, 0.4577, 0.5033, 0.5023, 0.4538])
- super().test_dict_tuple_outputs_equivalent(expected_slice=expected_slice, expected_max_difference=3e-3)
-
- def test_inference(self):
- device = "cpu"
-
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- self.assertEqual(image.shape, (1, 64, 64, 3))
- expected_slice = np.array(
- [0.63905364, 0.62897307, 0.48599017, 0.5133624, 0.5550048, 0.45769516, 0.50326973, 0.5023139, 0.45384496]
- )
- max_diff = np.abs(image_slice.flatten() - expected_slice).max()
- self.assertLessEqual(max_diff, 1e-3)
-
- def test_sequential_cpu_offload_forward_pass(self):
- super().test_sequential_cpu_offload_forward_pass(expected_max_diff=5e-4)
-
- def test_inference_batch_consistent(self):
- # NOTE: Larger batch sizes cause this test to timeout, only test on smaller batches
- self._test_inference_batch_consistent(batch_sizes=[1, 2])
-
- def test_inference_batch_single_identical(self):
- self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=7e-4)
-
- def test_pt_np_pil_outputs_equivalent(self):
- super().test_pt_np_pil_outputs_equivalent(expected_max_diff=5e-4)
-
- def test_save_load_local(self):
- super().test_save_load_local(expected_max_difference=5e-4)
-
- def test_save_load_optional_components(self):
- super().test_save_load_optional_components(expected_max_difference=4e-4)
-
- def test_karras_schedulers_shape(self):
- super().test_karras_schedulers_shape(num_inference_steps_for_strength_for_iterations=3)
-
- def test_from_pipe_consistent_forward_pass_cpu_offload(self):
- super().test_from_pipe_consistent_forward_pass_cpu_offload(expected_max_diff=5e-3)
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
-
-
-@require_torch_accelerator
-@nightly
-class StableDiffusionAttendAndExcitePipelineIntegrationTests(unittest.TestCase):
- # Attend and excite requires being able to run a backward pass at
- # inference time. There's no deterministic backward operator for pad
-
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
- torch.use_deterministic_algorithms(False)
-
- @classmethod
- def tearDownClass(cls):
- super().tearDownClass()
- torch.use_deterministic_algorithms(True)
-
- def setUp(self):
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- def test_attend_and_excite_fp16(self):
- generator = torch.manual_seed(51)
-
- pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
- )
- pipe.to(torch_device)
-
- prompt = "a painting of an elephant with glasses"
- token_indices = [5, 7]
-
- image = pipe(
- prompt=prompt,
- token_indices=token_indices,
- guidance_scale=7.5,
- generator=generator,
- num_inference_steps=5,
- max_iter_to_alter=5,
- output_type="np",
- ).images[0]
-
- expected_image = load_numpy(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/attend-and-excite/elephant_glasses.npy"
- )
- max_diff = numpy_cosine_similarity_distance(image.flatten(), expected_image.flatten())
- assert max_diff < 5e-1
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py
index 0a0051816162..bea7c099046f 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -36,7 +36,8 @@
StableDiffusionDepth2ImgPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -50,7 +51,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py
deleted file mode 100644
index 34ea56664a95..000000000000
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py
+++ /dev/null
@@ -1,452 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import random
-import tempfile
-import unittest
-
-import numpy as np
-import torch
-from PIL import Image
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AutoencoderKL,
- DDIMInverseScheduler,
- DDIMScheduler,
- DPMSolverMultistepInverseScheduler,
- DPMSolverMultistepScheduler,
- StableDiffusionDiffEditPipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- floats_tensor,
- load_image,
- nightly,
- numpy_cosine_similarity_distance,
- require_torch_accelerator,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
-from ..test_pipelines_common import PipelineFromPipeTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class StableDiffusionDiffEditPipelineFastTests(
- PipelineLatentTesterMixin, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase
-):
- pipeline_class = StableDiffusionDiffEditPipeline
- params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"height", "width", "image"} | {"image_latents"}
- batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS - {"image"} | {"image_latents"}
- image_params = frozenset(
- []
- ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
- image_latents_params = frozenset([])
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- # SD2-specific config below
- attention_head_dim=(2, 4),
- use_linear_projection=True,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- inverse_scheduler = DDIMInverseScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_zero=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=128,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- # SD2-specific config below
- hidden_act="gelu",
- projection_dim=512,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "inverse_scheduler": inverse_scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- }
-
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- mask = floats_tensor((1, 16, 16), rng=random.Random(seed)).to(device)
- latents = floats_tensor((1, 2, 4, 16, 16), rng=random.Random(seed)).to(device)
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "a dog and a newt",
- "mask_image": mask,
- "image_latents": latents,
- "generator": generator,
- "num_inference_steps": 2,
- "inpaint_strength": 1.0,
- "guidance_scale": 6.0,
- "output_type": "np",
- }
-
- return inputs
-
- def get_dummy_mask_inputs(self, device, seed=0):
- image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
- image = image.cpu().permute(0, 2, 3, 1)[0]
- image = Image.fromarray(np.uint8(image)).convert("RGB")
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "image": image,
- "source_prompt": "a cat and a frog",
- "target_prompt": "a dog and a newt",
- "generator": generator,
- "num_inference_steps": 2,
- "num_maps_per_mask": 2,
- "mask_encode_strength": 1.0,
- "guidance_scale": 6.0,
- "output_type": "np",
- }
-
- return inputs
-
- def get_dummy_inversion_inputs(self, device, seed=0):
- image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
- image = image.cpu().permute(0, 2, 3, 1)[0]
- image = Image.fromarray(np.uint8(image)).convert("RGB")
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "image": image,
- "prompt": "a cat and a frog",
- "generator": generator,
- "num_inference_steps": 2,
- "inpaint_strength": 1.0,
- "guidance_scale": 6.0,
- "decode_latents": True,
- "output_type": "np",
- }
- return inputs
-
- def test_save_load_optional_components(self):
- if not hasattr(self.pipeline_class, "_optional_components"):
- return
-
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- # set all optional components to None and update pipeline config accordingly
- for optional_component in pipe._optional_components:
- setattr(pipe, optional_component, None)
- pipe.register_modules(**{optional_component: None for optional_component in pipe._optional_components})
-
- inputs = self.get_dummy_inputs(torch_device)
- output = pipe(**inputs)[0]
-
- with tempfile.TemporaryDirectory() as tmpdir:
- pipe.save_pretrained(tmpdir)
- pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
- pipe_loaded.to(torch_device)
- pipe_loaded.set_progress_bar_config(disable=None)
-
- for optional_component in pipe._optional_components:
- self.assertTrue(
- getattr(pipe_loaded, optional_component) is None,
- f"`{optional_component}` did not stay set to None after loading.",
- )
-
- inputs = self.get_dummy_inputs(torch_device)
- output_loaded = pipe_loaded(**inputs)[0]
-
- max_diff = np.abs(output - output_loaded).max()
- self.assertLess(max_diff, 1e-4)
-
- def test_mask(self):
- device = "cpu"
-
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_mask_inputs(device)
- mask = pipe.generate_mask(**inputs)
- mask_slice = mask[0, -3:, -3:]
-
- self.assertEqual(mask.shape, (1, 16, 16))
- expected_slice = np.array([0] * 9)
- max_diff = np.abs(mask_slice.flatten() - expected_slice).max()
- self.assertLessEqual(max_diff, 1e-3)
- self.assertEqual(mask[0, -3, -4], 0)
-
- def test_inversion(self):
- device = "cpu"
-
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inversion_inputs(device)
- image = pipe.invert(**inputs).images
- image_slice = image[0, -1, -3:, -3:]
-
- self.assertEqual(image.shape, (2, 32, 32, 3))
- expected_slice = np.array(
- [0.5160, 0.5115, 0.5060, 0.5456, 0.4704, 0.5060, 0.5019, 0.4405, 0.4726],
- )
- max_diff = np.abs(image_slice.flatten() - expected_slice).max()
- self.assertLessEqual(max_diff, 1e-3)
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(expected_max_diff=5e-3)
-
- def test_inversion_dpm(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- scheduler_args = {"beta_start": 0.00085, "beta_end": 0.012, "beta_schedule": "scaled_linear"}
- components["scheduler"] = DPMSolverMultistepScheduler(**scheduler_args)
- components["inverse_scheduler"] = DPMSolverMultistepInverseScheduler(**scheduler_args)
-
- pipe = self.pipeline_class(**components)
- pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inversion_inputs(device)
- image = pipe.invert(**inputs).images
- image_slice = image[0, -1, -3:, -3:]
-
- self.assertEqual(image.shape, (2, 32, 32, 3))
- expected_slice = np.array(
- [0.5305, 0.4673, 0.5314, 0.5308, 0.4886, 0.5279, 0.5142, 0.4724, 0.4892],
- )
- max_diff = np.abs(image_slice.flatten() - expected_slice).max()
- self.assertLessEqual(max_diff, 1e-3)
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
-
-
-@require_torch_accelerator
-@nightly
-class StableDiffusionDiffEditPipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- @classmethod
- def setUpClass(cls):
- raw_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/diffedit/fruit.png"
- )
- raw_image = raw_image.convert("RGB").resize((256, 256))
-
- cls.raw_image = raw_image
-
- def test_stable_diffusion_diffedit_full(self):
- generator = torch.manual_seed(0)
-
- pipe = StableDiffusionDiffEditPipeline.from_pretrained(
- "stabilityai/stable-diffusion-2-1-base", safety_checker=None, torch_dtype=torch.float16
- )
- pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
- pipe.scheduler.clip_sample = True
-
- pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
- pipe.enable_model_cpu_offload(device=torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- source_prompt = "a bowl of fruit"
- target_prompt = "a bowl of pears"
-
- mask_image = pipe.generate_mask(
- image=self.raw_image,
- source_prompt=source_prompt,
- target_prompt=target_prompt,
- generator=generator,
- )
-
- inv_latents = pipe.invert(
- prompt=source_prompt,
- image=self.raw_image,
- inpaint_strength=0.7,
- generator=generator,
- num_inference_steps=5,
- ).latents
-
- image = pipe(
- prompt=target_prompt,
- mask_image=mask_image,
- image_latents=inv_latents,
- generator=generator,
- negative_prompt=source_prompt,
- inpaint_strength=0.7,
- num_inference_steps=5,
- output_type="np",
- ).images[0]
-
- expected_image = (
- np.array(
- load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/diffedit/pears.png"
- ).resize((256, 256))
- )
- / 255
- )
-
- assert numpy_cosine_similarity_distance(expected_image.flatten(), image.flatten()) < 2e-1
-
-
-@nightly
-@require_torch_accelerator
-class StableDiffusionDiffEditPipelineNightlyTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- @classmethod
- def setUpClass(cls):
- raw_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/diffedit/fruit.png"
- )
-
- raw_image = raw_image.convert("RGB").resize((768, 768))
-
- cls.raw_image = raw_image
-
- def test_stable_diffusion_diffedit_dpm(self):
- generator = torch.manual_seed(0)
-
- pipe = StableDiffusionDiffEditPipeline.from_pretrained(
- "stabilityai/stable-diffusion-2-1", safety_checker=None, torch_dtype=torch.float16
- )
- pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
- pipe.inverse_scheduler = DPMSolverMultistepInverseScheduler.from_config(pipe.scheduler.config)
- pipe.enable_model_cpu_offload()
- pipe.set_progress_bar_config(disable=None)
-
- source_prompt = "a bowl of fruit"
- target_prompt = "a bowl of pears"
-
- mask_image = pipe.generate_mask(
- image=self.raw_image,
- source_prompt=source_prompt,
- target_prompt=target_prompt,
- generator=generator,
- )
-
- inv_latents = pipe.invert(
- prompt=source_prompt,
- image=self.raw_image,
- inpaint_strength=0.7,
- generator=generator,
- num_inference_steps=25,
- ).latents
-
- image = pipe(
- prompt=target_prompt,
- mask_image=mask_image,
- image_latents=inv_latents,
- generator=generator,
- negative_prompt=source_prompt,
- inpaint_strength=0.7,
- num_inference_steps=25,
- output_type="np",
- ).images[0]
-
- expected_image = (
- np.array(
- load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/diffedit/pears.png"
- ).resize((768, 768))
- )
- / 255
- )
- assert np.abs((expected_image - image).max()) < 5e-1
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py
deleted file mode 100644
index 9e4fa767085f..000000000000
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py
+++ /dev/null
@@ -1,108 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-from diffusers import FlaxDPMSolverMultistepScheduler, FlaxStableDiffusionPipeline
-from diffusers.utils import is_flax_available
-from diffusers.utils.testing_utils import nightly, require_flax
-
-
-if is_flax_available():
- import jax
- import jax.numpy as jnp
- from flax.jax_utils import replicate
- from flax.training.common_utils import shard
-
-
-@nightly
-@require_flax
-class FlaxStableDiffusion2PipelineIntegrationTests(unittest.TestCase):
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
-
- def test_stable_diffusion_flax(self):
- sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-2",
- variant="bf16",
- dtype=jnp.bfloat16,
- )
-
- prompt = "A painting of a squirrel eating a burger"
- num_samples = jax.device_count()
- prompt = num_samples * [prompt]
- prompt_ids = sd_pipe.prepare_inputs(prompt)
-
- params = replicate(params)
- prompt_ids = shard(prompt_ids)
-
- prng_seed = jax.random.PRNGKey(0)
- prng_seed = jax.random.split(prng_seed, jax.device_count())
-
- images = sd_pipe(prompt_ids, params, prng_seed, num_inference_steps=25, jit=True)[0]
- assert images.shape == (jax.device_count(), 1, 768, 768, 3)
-
- images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
- image_slice = images[0, 253:256, 253:256, -1]
-
- output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
- expected_slice = jnp.array([0.4238, 0.4414, 0.4395, 0.4453, 0.4629, 0.4590, 0.4531, 0.45508, 0.4512])
-
- assert jnp.abs(output_slice - expected_slice).max() < 1e-2
-
-
-@nightly
-@require_flax
-class FlaxStableDiffusion2PipelineNightlyTests(unittest.TestCase):
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
-
- def test_stable_diffusion_dpm_flax(self):
- model_id = "stabilityai/stable-diffusion-2"
- scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
- sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
- model_id,
- scheduler=scheduler,
- variant="bf16",
- dtype=jnp.bfloat16,
- )
- params["scheduler"] = scheduler_params
-
- prompt = "A painting of a squirrel eating a burger"
- num_samples = jax.device_count()
- prompt = num_samples * [prompt]
- prompt_ids = sd_pipe.prepare_inputs(prompt)
-
- params = replicate(params)
- prompt_ids = shard(prompt_ids)
-
- prng_seed = jax.random.PRNGKey(0)
- prng_seed = jax.random.split(prng_seed, jax.device_count())
-
- images = sd_pipe(prompt_ids, params, prng_seed, num_inference_steps=25, jit=True)[0]
- assert images.shape == (jax.device_count(), 1, 768, 768, 3)
-
- images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
- image_slice = images[0, 253:256, 253:256, -1]
-
- output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
- expected_slice = jnp.array([0.4336, 0.42969, 0.4453, 0.4199, 0.4297, 0.4531, 0.4434, 0.4434, 0.4297])
-
- assert jnp.abs(output_slice - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax_inpaint.py
deleted file mode 100644
index eeec52dab51d..000000000000
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax_inpaint.py
+++ /dev/null
@@ -1,82 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-from diffusers import FlaxStableDiffusionInpaintPipeline
-from diffusers.utils import is_flax_available, load_image
-from diffusers.utils.testing_utils import require_flax, slow
-
-
-if is_flax_available():
- import jax
- import jax.numpy as jnp
- from flax.jax_utils import replicate
- from flax.training.common_utils import shard
-
-
-@slow
-@require_flax
-class FlaxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
-
- def test_stable_diffusion_inpaint_pipeline(self):
- init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/sd2-inpaint/init_image.png"
- )
- mask_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png"
- )
-
- model_id = "xvjiarui/stable-diffusion-2-inpainting"
- pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
-
- prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
-
- prng_seed = jax.random.PRNGKey(0)
- num_inference_steps = 50
-
- num_samples = jax.device_count()
- prompt = num_samples * [prompt]
- init_image = num_samples * [init_image]
- mask_image = num_samples * [mask_image]
- prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs(prompt, init_image, mask_image)
-
- # shard inputs and rng
- params = replicate(params)
- prng_seed = jax.random.split(prng_seed, jax.device_count())
- prompt_ids = shard(prompt_ids)
- processed_masked_images = shard(processed_masked_images)
- processed_masks = shard(processed_masks)
-
- output = pipeline(
- prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True
- )
-
- images = output.images.reshape(num_samples, 512, 512, 3)
-
- image_slice = images[0, 253:256, 253:256, -1]
-
- output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
- expected_slice = jnp.array(
- [0.3611307, 0.37649736, 0.3757408, 0.38213953, 0.39295167, 0.3841631, 0.41554978, 0.4137475, 0.4217084]
- )
-
- assert jnp.abs(output_slice - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py
index 2feeaaf11c12..f010c1b03fe3 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -23,8 +23,10 @@
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, PNDMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
+ backend_max_memory_allocated,
backend_reset_max_memory_allocated,
backend_reset_peak_memory_stats,
enable_full_determinism,
@@ -35,7 +37,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
@@ -287,6 +288,6 @@ def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
output_type="np",
)
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
# make sure that less than 2.65 GB is allocated
assert mem_bytes < 2.65 * 10**9
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py
index 22e588a9327b..285c2fea7ebc 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -30,17 +30,18 @@
UNet2DConditionModel,
)
from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
load_numpy,
+ require_accelerator,
require_torch_accelerator,
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
@@ -222,6 +223,7 @@ def test_stable_diffusion_latent_upscaler_multiple_init_images(self):
def test_attention_slicing_forward_pass(self):
super().test_attention_slicing_forward_pass(expected_max_diff=7e-3)
+ @require_accelerator
def test_sequential_cpu_offload_forward_pass(self):
super().test_sequential_cpu_offload_forward_pass(expected_max_diff=3e-3)
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py
index 5400c21c9f87..481ac7f2d10f 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,7 +24,8 @@
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableDiffusionUpscalePipeline, UNet2DConditionModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
index 1953017c0ee8..37b309c4cac4 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -30,7 +30,8 @@
StableDiffusionPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
index 38ef6143f4c0..3ccefe3de35d 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
@@ -2,19 +2,18 @@
import unittest
import numpy as np
-import pytest
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, StableDiffusion3Pipeline
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
numpy_cosine_similarity_distance,
require_big_accelerator,
slow,
torch_device,
)
-
from ..test_pipelines_common import (
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
@@ -125,37 +124,22 @@ def get_dummy_inputs(self, device, seed=0):
}
return inputs
- def test_stable_diffusion_3_different_prompts(self):
- pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
-
- inputs = self.get_dummy_inputs(torch_device)
- output_same_prompt = pipe(**inputs).images[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt_2"] = "a different prompt"
- inputs["prompt_3"] = "another different prompt"
- output_different_prompts = pipe(**inputs).images[0]
-
- max_diff = np.abs(output_same_prompt - output_different_prompts).max()
-
- # Outputs should be different here
- assert max_diff > 1e-2
-
- def test_stable_diffusion_3_different_negative_prompts(self):
- pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
-
- inputs = self.get_dummy_inputs(torch_device)
- output_same_prompt = pipe(**inputs).images[0]
+ def test_inference(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
inputs = self.get_dummy_inputs(torch_device)
- inputs["negative_prompt_2"] = "deformed"
- inputs["negative_prompt_3"] = "blurry"
- output_different_prompts = pipe(**inputs).images[0]
+ image = pipe(**inputs).images[0]
+ generated_slice = image.flatten()
+ generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]])
- max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+ # fmt: off
+ expected_slice = np.array([0.5112, 0.5228, 0.5235, 0.5524, 0.3188, 0.5017, 0.5574, 0.4899, 0.6812, 0.5991, 0.3908, 0.5213, 0.5582, 0.4457, 0.4204, 0.5616])
+ # fmt: on
- # Outputs should be different here
- assert max_diff > 1e-2
+ self.assertTrue(
+ np.allclose(generated_slice, expected_slice, atol=1e-3), "Output does not match expected slice."
+ )
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -171,9 +155,9 @@ def test_fused_qkv_projections(self):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(
- pipe.transformer
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_processors_exist(pipe.transformer), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
@@ -187,15 +171,15 @@ def test_fused_qkv_projections(self):
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
def test_skip_guidance_layers(self):
components = self.get_dummy_components()
@@ -233,7 +217,6 @@ def test_skip_guidance_layers(self):
@slow
@require_big_accelerator
-@pytest.mark.big_gpu_with_torch_cuda
class StableDiffusion3PipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3Pipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
@@ -270,40 +253,9 @@ def test_sd3_inference(self):
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
- expected_slice = np.array(
- [
- 0.4648,
- 0.4404,
- 0.4177,
- 0.5063,
- 0.4800,
- 0.4287,
- 0.5425,
- 0.5190,
- 0.4717,
- 0.5430,
- 0.5195,
- 0.4766,
- 0.5361,
- 0.5122,
- 0.4612,
- 0.4871,
- 0.4749,
- 0.4058,
- 0.4756,
- 0.4678,
- 0.3804,
- 0.4832,
- 0.4822,
- 0.3799,
- 0.5103,
- 0.5034,
- 0.3953,
- 0.5073,
- 0.4839,
- 0.3884,
- ]
- )
+ # fmt: off
+ expected_slice = np.array([0.4648, 0.4404, 0.4177, 0.5063, 0.4800, 0.4287, 0.5425, 0.5190, 0.4717, 0.5430, 0.5195, 0.4766, 0.5361, 0.5122, 0.4612, 0.4871, 0.4749, 0.4058, 0.4756, 0.4678, 0.3804, 0.4832, 0.4822, 0.3799, 0.5103, 0.5034, 0.3953, 0.5073, 0.4839, 0.3884])
+ # fmt: on
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
index f7c450aab93e..9025b1060c9e 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
@@ -3,7 +3,6 @@
import unittest
import numpy as np
-import pytest
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -14,7 +13,9 @@
StableDiffusion3Img2ImgPipeline,
)
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ Expectations,
backend_empty_cache,
floats_tensor,
numpy_cosine_similarity_distance,
@@ -22,7 +23,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
@@ -128,37 +128,22 @@ def get_dummy_inputs(self, device, seed=0):
}
return inputs
- def test_stable_diffusion_3_img2img_different_prompts(self):
- pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
-
- inputs = self.get_dummy_inputs(torch_device)
- output_same_prompt = pipe(**inputs).images[0]
+ def test_inference(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt_2"] = "a different prompt"
- inputs["prompt_3"] = "another different prompt"
- output_different_prompts = pipe(**inputs).images[0]
-
- max_diff = np.abs(output_same_prompt - output_different_prompts).max()
-
- # Outputs should be different here
- assert max_diff > 1e-2
-
- def test_stable_diffusion_3_img2img_different_negative_prompts(self):
- pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
-
- inputs = self.get_dummy_inputs(torch_device)
- output_same_prompt = pipe(**inputs).images[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["negative_prompt_2"] = "deformed"
- inputs["negative_prompt_3"] = "blurry"
- output_different_prompts = pipe(**inputs).images[0]
+ image = pipe(**inputs).images[0]
+ generated_slice = image.flatten()
+ generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]])
- max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+ # fmt: off
+ expected_slice = np.array([0.4564, 0.5486, 0.4868, 0.5923, 0.3775, 0.5543, 0.4807, 0.4177, 0.3778, 0.5957, 0.5726, 0.4333, 0.6312, 0.5062, 0.4838, 0.5984])
+ # fmt: on
- # Outputs should be different here
- assert max_diff > 1e-2
+ self.assertTrue(
+ np.allclose(generated_slice, expected_slice, atol=1e-3), "Output does not match expected slice."
+ )
@unittest.skip("Skip for now.")
def test_multi_vae(self):
@@ -167,7 +152,6 @@ def test_multi_vae(self):
@slow
@require_big_accelerator
-@pytest.mark.big_gpu_with_torch_cuda
class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3Img2ImgPipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
@@ -208,40 +192,18 @@ def test_sd3_img2img_inference(self):
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
- expected_slice = np.array(
- [
- 0.5435,
- 0.4673,
- 0.5732,
- 0.4438,
- 0.3557,
- 0.4912,
- 0.4331,
- 0.3491,
- 0.4915,
- 0.4287,
- 0.3477,
- 0.4849,
- 0.4355,
- 0.3469,
- 0.4871,
- 0.4431,
- 0.3538,
- 0.4912,
- 0.4521,
- 0.3643,
- 0.5059,
- 0.4587,
- 0.3730,
- 0.5166,
- 0.4685,
- 0.3845,
- 0.5264,
- 0.4746,
- 0.3914,
- 0.5342,
- ]
+
+ # fmt: off
+ expected_slices = Expectations(
+ {
+ ("xpu", 3): np.array([0.5117, 0.4421, 0.3852, 0.5044, 0.4219, 0.3262, 0.5024, 0.4329, 0.3276, 0.4978, 0.4412, 0.3355, 0.4983, 0.4338, 0.3279, 0.4893, 0.4241, 0.3129, 0.4875, 0.4253, 0.3030, 0.4961, 0.4267, 0.2988, 0.5029, 0.4255, 0.3054, 0.5132, 0.4248, 0.3222]),
+ ("cuda", 7): np.array([0.5435, 0.4673, 0.5732, 0.4438, 0.3557, 0.4912, 0.4331, 0.3491, 0.4915, 0.4287, 0.347, 0.4849, 0.4355, 0.3469, 0.4871, 0.4431, 0.3538, 0.4912, 0.4521, 0.3643, 0.5059, 0.4587, 0.373, 0.5166, 0.4685, 0.3845, 0.5264, 0.4746, 0.3914, 0.5342]),
+ ("cuda", 8): np.array([0.5146, 0.4385, 0.3826, 0.5098, 0.4150, 0.3218, 0.5142, 0.4312, 0.3298, 0.5127, 0.4431, 0.3411, 0.5171, 0.4424, 0.3374, 0.5088, 0.4348, 0.3242, 0.5073, 0.4380, 0.3174, 0.5132, 0.4397, 0.3115, 0.5132, 0.4343, 0.3118, 0.5219, 0.4328, 0.3256]),
+ }
)
+ # fmt: on
+
+ expected_slice = expected_slices.get_expectation()
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py
index 4090306dec72..628930340294 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py
@@ -11,12 +11,12 @@
SD3Transformer2DModel,
StableDiffusion3InpaintPipeline,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
-
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
@@ -132,37 +132,23 @@ def get_dummy_inputs(self, device, seed=0):
}
return inputs
- def test_stable_diffusion_3_inpaint_different_prompts(self):
- pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
-
- inputs = self.get_dummy_inputs(torch_device)
- output_same_prompt = pipe(**inputs).images[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt_2"] = "a different prompt"
- inputs["prompt_3"] = "another different prompt"
- output_different_prompts = pipe(**inputs).images[0]
-
- max_diff = np.abs(output_same_prompt - output_different_prompts).max()
-
- # Outputs should be different here
- assert max_diff > 1e-2
-
- def test_stable_diffusion_3_inpaint_different_negative_prompts(self):
- pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ def test_inference(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
inputs = self.get_dummy_inputs(torch_device)
- output_same_prompt = pipe(**inputs).images[0]
+ image = pipe(**inputs).images[0]
+ generated_slice = image.flatten()
+ generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]])
- inputs = self.get_dummy_inputs(torch_device)
- inputs["negative_prompt_2"] = "deformed"
- inputs["negative_prompt_3"] = "blurry"
- output_different_prompts = pipe(**inputs).images[0]
-
- max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+ # fmt: off
+ expected_slice = np.array([0.5035, 0.6661, 0.5859, 0.413, 0.4224, 0.4234, 0.7181, 0.5062, 0.5183, 0.6877, 0.5074, 0.585, 0.6111, 0.5422, 0.5306, 0.5891])
+ # fmt: on
- # Outputs should be different here
- assert max_diff > 1e-2
+ self.assertTrue(
+ np.allclose(generated_slice, expected_slice, atol=1e-3), "Output does not match expected slice."
+ )
+ @unittest.skip("Skip for now.")
def test_multi_vae(self):
pass
diff --git a/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py b/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py
index 009c75df4249..79b38d1cad1c 100644
--- a/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py
+++ b/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py
@@ -34,7 +34,8 @@
)
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -45,7 +46,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineFromPipeTesterMixin, PipelineTesterMixin, assert_mean_pixel_difference
diff --git a/tests/pipelines/stable_diffusion_gligen/test_stable_diffusion_gligen.py b/tests/pipelines/stable_diffusion_gligen/test_stable_diffusion_gligen.py
deleted file mode 100644
index b3ac507f768e..000000000000
--- a/tests/pipelines/stable_diffusion_gligen/test_stable_diffusion_gligen.py
+++ /dev/null
@@ -1,175 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- EulerAncestralDiscreteScheduler,
- StableDiffusionGLIGENPipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.testing_utils import enable_full_determinism
-
-from ..pipeline_params import (
- TEXT_TO_IMAGE_BATCH_PARAMS,
- TEXT_TO_IMAGE_IMAGE_PARAMS,
- TEXT_TO_IMAGE_PARAMS,
-)
-from ..test_pipelines_common import (
- PipelineFromPipeTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
-)
-
-
-enable_full_determinism()
-
-
-class GligenPipelineFastTests(
- PipelineLatentTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineTesterMixin,
- PipelineFromPipeTesterMixin,
- unittest.TestCase,
-):
- pipeline_class = StableDiffusionGLIGENPipeline
- params = TEXT_TO_IMAGE_PARAMS | {"gligen_phrases", "gligen_boxes"}
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- attention_type="gated",
- )
- # unet.position_net = PositionNet(32,32)
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=128,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A modern livingroom",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "gligen_phrases": ["a birthday cake"],
- "gligen_boxes": [[0.2676, 0.6088, 0.4773, 0.7183]],
- "output_type": "np",
- }
- return inputs
-
- def test_stable_diffusion_gligen_default_case(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionGLIGENPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.5069, 0.5561, 0.4577, 0.4792, 0.5203, 0.4089, 0.5039, 0.4919, 0.4499])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_gligen_k_euler_ancestral(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionGLIGENPipeline(**components)
- sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = sd_pipe(**inputs)
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.425, 0.494, 0.429, 0.469, 0.525, 0.417, 0.533, 0.5, 0.47])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_attention_slicing_forward_pass(self):
- super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(batch_size=3, expected_max_diff=3e-3)
-
- @unittest.skip("Test not supported as tokenizer is used for parsing bounding boxes.")
- def test_encode_prompt_works_in_isolation(self):
- pass
diff --git a/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py b/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py
deleted file mode 100644
index b080bb987e13..000000000000
--- a/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py
+++ /dev/null
@@ -1,215 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import (
- CLIPProcessor,
- CLIPTextConfig,
- CLIPTextModel,
- CLIPTokenizer,
- CLIPVisionConfig,
- CLIPVisionModelWithProjection,
-)
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- EulerAncestralDiscreteScheduler,
- StableDiffusionGLIGENTextImagePipeline,
- UNet2DConditionModel,
-)
-from diffusers.pipelines.stable_diffusion import CLIPImageProjection
-from diffusers.utils import load_image
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
-
-from ..pipeline_params import (
- TEXT_TO_IMAGE_BATCH_PARAMS,
- TEXT_TO_IMAGE_IMAGE_PARAMS,
- TEXT_TO_IMAGE_PARAMS,
-)
-from ..test_pipelines_common import (
- PipelineFromPipeTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
-)
-
-
-enable_full_determinism()
-
-
-class GligenTextImagePipelineFastTests(
- PipelineLatentTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineTesterMixin,
- PipelineFromPipeTesterMixin,
- unittest.TestCase,
-):
- pipeline_class = StableDiffusionGLIGENTextImagePipeline
- params = TEXT_TO_IMAGE_PARAMS | {"gligen_phrases", "gligen_images", "gligen_boxes"}
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- supports_dduf = False
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- attention_type="gated-text-image",
- )
- # unet.position_net = PositionNet(32,32)
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=128,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- image_encoder_config = CLIPVisionConfig(
- hidden_size=32,
- projection_dim=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- )
- image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
- processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
-
- image_project = CLIPImageProjection(hidden_size=32)
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- "image_encoder": image_encoder,
- "image_project": image_project,
- "processor": processor,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
-
- gligen_images = load_image(
- "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/gligen/livingroom_modern.png"
- )
- inputs = {
- "prompt": "A modern livingroom",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "gligen_phrases": ["a birthday cake"],
- "gligen_images": [gligen_images],
- "gligen_boxes": [[0.2676, 0.6088, 0.4773, 0.7183]],
- "output_type": "np",
- }
- return inputs
-
- def test_dict_tuple_outputs_equivalent(self):
- expected_slice = None
- if torch_device == "cpu":
- expected_slice = np.array([0.5052, 0.5546, 0.4567, 0.4770, 0.5195, 0.4085, 0.5026, 0.4909, 0.4495])
- super().test_dict_tuple_outputs_equivalent(expected_slice=expected_slice)
-
- def test_stable_diffusion_gligen_text_image_default_case(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionGLIGENTextImagePipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.5069, 0.5561, 0.4577, 0.4792, 0.5203, 0.4089, 0.5039, 0.4919, 0.4499])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_gligen_k_euler_ancestral(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionGLIGENTextImagePipeline(**components)
- sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.425, 0.494, 0.429, 0.469, 0.525, 0.417, 0.533, 0.5, 0.47])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_attention_slicing_forward_pass(self):
- super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(batch_size=3, expected_max_diff=3e-3)
-
- @unittest.skip(
- "Test not supported because of the use of `text_encoder` in `get_cross_attention_kwargs_with_grounded()`."
- )
- def test_encode_prompt_works_in_isolation(self):
- pass
diff --git a/tests/pipelines/stable_diffusion_image_variation/test_stable_diffusion_image_variation.py b/tests/pipelines/stable_diffusion_image_variation/test_stable_diffusion_image_variation.py
index f706e7000b28..dbf5a7b68eae 100644
--- a/tests/pipelines/stable_diffusion_image_variation/test_stable_diffusion_image_variation.py
+++ b/tests/pipelines/stable_diffusion_image_variation/test_stable_diffusion_image_variation.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,7 +29,8 @@
StableDiffusionImageVariationPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
@@ -44,7 +45,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import IMAGE_VARIATION_BATCH_PARAMS, IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
diff --git a/tests/pipelines/stable_diffusion_k_diffusion/test_stable_diffusion_k_diffusion.py b/tests/pipelines/stable_diffusion_k_diffusion/test_stable_diffusion_k_diffusion.py
deleted file mode 100644
index fe78f0ec3c1a..000000000000
--- a/tests/pipelines/stable_diffusion_k_diffusion/test_stable_diffusion_k_diffusion.py
+++ /dev/null
@@ -1,141 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-
-from diffusers import StableDiffusionKDiffusionPipeline
-from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device
-
-
-enable_full_determinism()
-
-
-@nightly
-@require_torch_gpu
-class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- def test_stable_diffusion_1(self):
- sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- sd_pipe.set_scheduler("sample_euler")
-
- prompt = "A painting of a squirrel eating a burger"
- generator = torch.manual_seed(0)
- output = sd_pipe([prompt], generator=generator, guidance_scale=9.0, num_inference_steps=20, output_type="np")
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.0447, 0.0492, 0.0468, 0.0408, 0.0383, 0.0408, 0.0354, 0.0380, 0.0339])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_2(self):
- sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- sd_pipe.set_scheduler("sample_euler")
-
- prompt = "A painting of a squirrel eating a burger"
- generator = torch.manual_seed(0)
- output = sd_pipe([prompt], generator=generator, guidance_scale=9.0, num_inference_steps=20, output_type="np")
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.1237, 0.1320, 0.1438, 0.1359, 0.1390, 0.1132, 0.1277, 0.1175, 0.1112])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-1
-
- def test_stable_diffusion_karras_sigmas(self):
- sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- sd_pipe.set_scheduler("sample_dpmpp_2m")
-
- prompt = "A painting of a squirrel eating a burger"
- generator = torch.manual_seed(0)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=7.5,
- num_inference_steps=15,
- output_type="np",
- use_karras_sigmas=True,
- )
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array(
- [0.11381689, 0.12112921, 0.1389457, 0.12549606, 0.1244964, 0.10831517, 0.11562866, 0.10867816, 0.10499048]
- )
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_noise_sampler_seed(self):
- sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- sd_pipe.set_scheduler("sample_dpmpp_sde")
-
- prompt = "A painting of a squirrel eating a burger"
- seed = 0
- images1 = sd_pipe(
- [prompt],
- generator=torch.manual_seed(seed),
- noise_sampler_seed=seed,
- guidance_scale=9.0,
- num_inference_steps=20,
- output_type="np",
- ).images
- images2 = sd_pipe(
- [prompt],
- generator=torch.manual_seed(seed),
- noise_sampler_seed=seed,
- guidance_scale=9.0,
- num_inference_steps=20,
- output_type="np",
- ).images
-
- assert images1.shape == (1, 512, 512, 3)
- assert images2.shape == (1, 512, 512, 3)
- assert np.abs(images1.flatten() - images2.flatten()).max() < 1e-2
diff --git a/tests/pipelines/stable_diffusion_ldm3d/test_stable_diffusion_ldm3d.py b/tests/pipelines/stable_diffusion_ldm3d/test_stable_diffusion_ldm3d.py
deleted file mode 100644
index 8f07d02aad5e..000000000000
--- a/tests/pipelines/stable_diffusion_ldm3d/test_stable_diffusion_ldm3d.py
+++ /dev/null
@@ -1,320 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- PNDMScheduler,
- StableDiffusionLDM3DPipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-
-
-enable_full_determinism()
-
-
-class StableDiffusionLDM3DPipelineFastTests(unittest.TestCase):
- pipeline_class = StableDiffusionLDM3DPipeline
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=6,
- out_channels=6,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- "image_encoder": None,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "np",
- }
- return inputs
-
- def test_stable_diffusion_ddim(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components()
- ldm3d_pipe = StableDiffusionLDM3DPipeline(**components)
- ldm3d_pipe = ldm3d_pipe.to(torch_device)
- ldm3d_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = ldm3d_pipe(**inputs)
- rgb, depth = output.rgb, output.depth
-
- image_slice_rgb = rgb[0, -3:, -3:, -1]
- image_slice_depth = depth[0, -3:, -1]
-
- assert rgb.shape == (1, 64, 64, 3)
- assert depth.shape == (1, 64, 64)
-
- expected_slice_rgb = np.array(
- [0.37338176, 0.70247, 0.74203193, 0.51643604, 0.58256793, 0.60932136, 0.4181095, 0.48355877, 0.46535262]
- )
- expected_slice_depth = np.array([103.46727, 85.812004, 87.849236])
-
- assert np.abs(image_slice_rgb.flatten() - expected_slice_rgb).max() < 1e-2
- assert np.abs(image_slice_depth.flatten() - expected_slice_depth).max() < 1e-2
-
- def test_stable_diffusion_prompt_embeds(self):
- components = self.get_dummy_components()
- ldm3d_pipe = StableDiffusionLDM3DPipeline(**components)
- ldm3d_pipe = ldm3d_pipe.to(torch_device)
- ldm3d_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt"] = 3 * [inputs["prompt"]]
-
- # forward
- output = ldm3d_pipe(**inputs)
- rgb_slice_1, depth_slice_1 = output.rgb, output.depth
- rgb_slice_1 = rgb_slice_1[0, -3:, -3:, -1]
- depth_slice_1 = depth_slice_1[0, -3:, -1]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = 3 * [inputs.pop("prompt")]
-
- text_inputs = ldm3d_pipe.tokenizer(
- prompt,
- padding="max_length",
- max_length=ldm3d_pipe.tokenizer.model_max_length,
- truncation=True,
- return_tensors="pt",
- )
- text_inputs = text_inputs["input_ids"].to(torch_device)
-
- prompt_embeds = ldm3d_pipe.text_encoder(text_inputs)[0]
-
- inputs["prompt_embeds"] = prompt_embeds
-
- # forward
- output = ldm3d_pipe(**inputs)
- rgb_slice_2, depth_slice_2 = output.rgb, output.depth
- rgb_slice_2 = rgb_slice_2[0, -3:, -3:, -1]
- depth_slice_2 = depth_slice_2[0, -3:, -1]
-
- assert np.abs(rgb_slice_1.flatten() - rgb_slice_2.flatten()).max() < 1e-4
- assert np.abs(depth_slice_1.flatten() - depth_slice_2.flatten()).max() < 1e-4
-
- def test_stable_diffusion_negative_prompt(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
- ldm3d_pipe = StableDiffusionLDM3DPipeline(**components)
- ldm3d_pipe = ldm3d_pipe.to(device)
- ldm3d_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- negative_prompt = "french fries"
- output = ldm3d_pipe(**inputs, negative_prompt=negative_prompt)
-
- rgb, depth = output.rgb, output.depth
- rgb_slice = rgb[0, -3:, -3:, -1]
- depth_slice = depth[0, -3:, -1]
-
- assert rgb.shape == (1, 64, 64, 3)
- assert depth.shape == (1, 64, 64)
-
- expected_slice_rgb = np.array(
- [0.37044, 0.71811503, 0.7223251, 0.48603675, 0.5638391, 0.6364948, 0.42833704, 0.4901315, 0.47926217]
- )
- expected_slice_depth = np.array([107.84738, 84.62802, 89.962135])
- assert np.abs(rgb_slice.flatten() - expected_slice_rgb).max() < 1e-2
- assert np.abs(depth_slice.flatten() - expected_slice_depth).max() < 1e-2
-
-
-@nightly
-@require_torch_gpu
-class StableDiffusionLDM3DPipelineSlowTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
- generator = torch.Generator(device=generator_device).manual_seed(seed)
- latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
- latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
- inputs = {
- "prompt": "a photograph of an astronaut riding a horse",
- "latents": latents,
- "generator": generator,
- "num_inference_steps": 3,
- "guidance_scale": 7.5,
- "output_type": "np",
- }
- return inputs
-
- def test_ldm3d_stable_diffusion(self):
- ldm3d_pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d")
- ldm3d_pipe = ldm3d_pipe.to(torch_device)
- ldm3d_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- output = ldm3d_pipe(**inputs)
- rgb, depth = output.rgb, output.depth
- rgb_slice = rgb[0, -3:, -3:, -1].flatten()
- depth_slice = rgb[0, -3:, -1].flatten()
-
- assert rgb.shape == (1, 512, 512, 3)
- assert depth.shape == (1, 512, 512)
-
- expected_slice_rgb = np.array(
- [0.53805465, 0.56707305, 0.5486515, 0.57012236, 0.5814511, 0.56253487, 0.54843014, 0.55092263, 0.6459706]
- )
- expected_slice_depth = np.array(
- [0.9263781, 0.6678672, 0.5486515, 0.92202145, 0.67831135, 0.56253487, 0.9241694, 0.7551478, 0.6459706]
- )
- assert np.abs(rgb_slice - expected_slice_rgb).max() < 3e-3
- assert np.abs(depth_slice - expected_slice_depth).max() < 3e-3
-
-
-@nightly
-@require_torch_gpu
-class StableDiffusionPipelineNightlyTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
- generator = torch.Generator(device=generator_device).manual_seed(seed)
- latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
- latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
- inputs = {
- "prompt": "a photograph of an astronaut riding a horse",
- "latents": latents,
- "generator": generator,
- "num_inference_steps": 50,
- "guidance_scale": 7.5,
- "output_type": "np",
- }
- return inputs
-
- def test_ldm3d(self):
- ldm3d_pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d").to(torch_device)
- ldm3d_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- output = ldm3d_pipe(**inputs)
- rgb, depth = output.rgb, output.depth
-
- expected_rgb_mean = 0.495586
- expected_rgb_std = 0.33795515
- expected_depth_mean = 112.48518
- expected_depth_std = 98.489746
- assert np.abs(expected_rgb_mean - rgb.mean()) < 1e-3
- assert np.abs(expected_rgb_std - rgb.std()) < 1e-3
- assert np.abs(expected_depth_mean - depth.mean()) < 1e-3
- assert np.abs(expected_depth_std - depth.std()) < 1e-3
-
- def test_ldm3d_v2(self):
- ldm3d_pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d-4c").to(torch_device)
- ldm3d_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- output = ldm3d_pipe(**inputs)
- rgb, depth = output.rgb, output.depth
-
- expected_rgb_mean = 0.4194127
- expected_rgb_std = 0.35375586
- expected_depth_mean = 0.5638502
- expected_depth_std = 0.34686103
-
- assert rgb.shape == (1, 512, 512, 3)
- assert depth.shape == (1, 512, 512, 1)
- assert np.abs(expected_rgb_mean - rgb.mean()) < 1e-3
- assert np.abs(expected_rgb_std - rgb.std()) < 1e-3
- assert np.abs(expected_depth_mean - depth.mean()) < 1e-3
- assert np.abs(expected_depth_std - depth.std()) < 1e-3
diff --git a/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py b/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py
deleted file mode 100644
index 4734af259921..000000000000
--- a/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py
+++ /dev/null
@@ -1,434 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- EulerAncestralDiscreteScheduler,
- LMSDiscreteScheduler,
- PNDMScheduler,
- StableDiffusionPanoramaPipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, skip_mps, torch_device
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import (
- IPAdapterTesterMixin,
- PipelineFromPipeTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
-)
-
-
-enable_full_determinism()
-
-
-@skip_mps
-class StableDiffusionPanoramaPipelineFastTests(
- IPAdapterTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
- PipelineFromPipeTesterMixin,
- unittest.TestCase,
-):
- pipeline_class = StableDiffusionPanoramaPipeline
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=1,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- )
- scheduler = DDIMScheduler()
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- "image_encoder": None,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "a photo of the dolomites",
- "generator": generator,
- # Setting height and width to None to prevent OOMs on CPU.
- "height": None,
- "width": None,
- "num_inference_steps": 1,
- "guidance_scale": 6.0,
- "output_type": "np",
- }
- return inputs
-
- def test_stable_diffusion_panorama_default_case(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionPanoramaPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.6186, 0.5374, 0.4915, 0.4135, 0.4114, 0.4563, 0.5128, 0.4977, 0.4757])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_panorama_circular_padding_case(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionPanoramaPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs, circular_padding=True).images
- image_slice = image[0, -3:, -3:, -1]
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.6127, 0.6299, 0.4595, 0.4051, 0.4543, 0.3925, 0.5510, 0.5693, 0.5031])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- # override to speed the overall test timing up.
- def test_inference_batch_consistent(self):
- super().test_inference_batch_consistent(batch_sizes=[1, 2])
-
- # override to speed the overall test timing up.
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=5.0e-3)
-
- def test_float16_inference(self):
- super().test_float16_inference(expected_max_diff=1e-1)
-
- def test_stable_diffusion_panorama_negative_prompt(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionPanoramaPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- negative_prompt = "french fries"
- output = sd_pipe(**inputs, negative_prompt=negative_prompt)
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.6187, 0.5375, 0.4915, 0.4136, 0.4114, 0.4563, 0.5128, 0.4976, 0.4757])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_panorama_views_batch(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionPanoramaPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = sd_pipe(**inputs, view_batch_size=2)
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.6187, 0.5375, 0.4915, 0.4136, 0.4114, 0.4563, 0.5128, 0.4976, 0.4757])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_panorama_views_batch_circular_padding(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionPanoramaPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = sd_pipe(**inputs, circular_padding=True, view_batch_size=2)
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.6127, 0.6299, 0.4595, 0.4051, 0.4543, 0.3925, 0.5510, 0.5693, 0.5031])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_panorama_euler(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- components["scheduler"] = EulerAncestralDiscreteScheduler(
- beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
- )
- sd_pipe = StableDiffusionPanoramaPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.4024, 0.6510, 0.4901, 0.5378, 0.5813, 0.5622, 0.4795, 0.4467, 0.4952])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_panorama_pndm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- components["scheduler"] = PNDMScheduler(
- beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
- )
- sd_pipe = StableDiffusionPanoramaPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.6391, 0.6291, 0.4861, 0.5134, 0.5552, 0.4578, 0.5032, 0.5023, 0.4539])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
-
-
-@nightly
-@require_torch_gpu
-class StableDiffusionPanoramaNightlyTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- def get_inputs(self, seed=0):
- generator = torch.manual_seed(seed)
- inputs = {
- "prompt": "a photo of the dolomites",
- "generator": generator,
- "num_inference_steps": 3,
- "guidance_scale": 7.5,
- "output_type": "np",
- }
- return inputs
-
- def test_stable_diffusion_panorama_default(self):
- model_ckpt = "stabilityai/stable-diffusion-2-base"
- scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
- pipe = StableDiffusionPanoramaPipeline.from_pretrained(model_ckpt, scheduler=scheduler, safety_checker=None)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- inputs = self.get_inputs()
- image = pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1].flatten()
-
- assert image.shape == (1, 512, 2048, 3)
-
- expected_slice = np.array(
- [
- 0.36968392,
- 0.27025372,
- 0.32446766,
- 0.28379387,
- 0.36363274,
- 0.30733347,
- 0.27100027,
- 0.27054125,
- 0.25536096,
- ]
- )
-
- assert np.abs(expected_slice - image_slice).max() < 1e-2
-
- def test_stable_diffusion_panorama_k_lms(self):
- pipe = StableDiffusionPanoramaPipeline.from_pretrained(
- "stabilityai/stable-diffusion-2-base", safety_checker=None
- )
- pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
- pipe.unet.set_default_attn_processor()
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- inputs = self.get_inputs()
- image = pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 512, 2048, 3)
-
- expected_slice = np.array(
- [
- [
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- ]
- ]
- )
-
- assert np.abs(expected_slice - image_slice).max() < 1e-2
-
- def test_stable_diffusion_panorama_intermediate_state(self):
- number_of_steps = 0
-
- def callback_fn(step: int, timestep: int, latents: torch.Tensor) -> None:
- callback_fn.has_been_called = True
- nonlocal number_of_steps
- number_of_steps += 1
- if step == 1:
- latents = latents.detach().cpu().numpy()
- assert latents.shape == (1, 4, 64, 256)
- latents_slice = latents[0, -3:, -3:, -1]
-
- expected_slice = np.array(
- [
- 0.18681869,
- 0.33907816,
- 0.5361276,
- 0.14432865,
- -0.02856611,
- -0.73941123,
- 0.23397987,
- 0.47322682,
- -0.37823164,
- ]
- )
- assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
- elif step == 2:
- latents = latents.detach().cpu().numpy()
- assert latents.shape == (1, 4, 64, 256)
- latents_slice = latents[0, -3:, -3:, -1]
-
- expected_slice = np.array(
- [
- 0.18539645,
- 0.33987248,
- 0.5378559,
- 0.14437142,
- -0.02455261,
- -0.7338317,
- 0.23990755,
- 0.47356272,
- -0.3786505,
- ]
- )
-
- assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
-
- callback_fn.has_been_called = False
-
- model_ckpt = "stabilityai/stable-diffusion-2-base"
- scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
- pipe = StableDiffusionPanoramaPipeline.from_pretrained(model_ckpt, scheduler=scheduler, safety_checker=None)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- inputs = self.get_inputs()
- pipe(**inputs, callback=callback_fn, callback_steps=1)
- assert callback_fn.has_been_called
- assert number_of_steps == 3
-
- def test_stable_diffusion_panorama_pipeline_with_sequential_cpu_offloading(self):
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
-
- model_ckpt = "stabilityai/stable-diffusion-2-base"
- scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
- pipe = StableDiffusionPanoramaPipeline.from_pretrained(model_ckpt, scheduler=scheduler, safety_checker=None)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing(1)
- pipe.enable_sequential_cpu_offload()
-
- inputs = self.get_inputs()
- _ = pipe(**inputs)
-
- mem_bytes = torch.cuda.max_memory_allocated()
- # make sure that less than 5.2 GB is allocated
- assert mem_bytes < 5.5 * 10**9
diff --git a/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py b/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py
deleted file mode 100644
index 269677c08345..000000000000
--- a/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py
+++ /dev/null
@@ -1,451 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import random
-import tempfile
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
-from diffusers.pipelines.stable_diffusion_safe import StableDiffusionPipelineSafe as StableDiffusionPipeline
-from diffusers.utils.testing_utils import floats_tensor, nightly, require_accelerator, require_torch_gpu, torch_device
-
-
-class SafeDiffusionPipelineFastTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- @property
- def dummy_image(self):
- batch_size = 1
- num_channels = 3
- sizes = (32, 32)
-
- image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
- return image
-
- @property
- def dummy_cond_unet(self):
- torch.manual_seed(0)
- model = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- )
- return model
-
- @property
- def dummy_vae(self):
- torch.manual_seed(0)
- model = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- return model
-
- @property
- def dummy_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModel(config)
-
- @property
- def dummy_extractor(self):
- def extract(*args, **kwargs):
- class Out:
- def __init__(self):
- self.pixel_values = torch.ones([0])
-
- def to(self, device):
- self.pixel_values.to(device)
- return self
-
- return Out()
-
- return extract
-
- def test_safe_diffusion_ddim(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- unet = self.dummy_cond_unet
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
-
- vae = self.dummy_vae
- bert = self.dummy_text_encoder
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- # make sure here that pndm scheduler skips prk
- sd_pipe = StableDiffusionPipeline(
- unet=unet,
- scheduler=scheduler,
- vae=vae,
- text_encoder=bert,
- tokenizer=tokenizer,
- safety_checker=None,
- feature_extractor=self.dummy_extractor,
- )
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A painting of a squirrel eating a burger"
-
- generator = torch.Generator(device=device).manual_seed(0)
- output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
- image = output.images
-
- generator = torch.Generator(device=device).manual_seed(0)
- image_from_tuple = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=6.0,
- num_inference_steps=2,
- output_type="np",
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.5756, 0.6118, 0.5005, 0.5041, 0.5471, 0.4726, 0.4976, 0.4865, 0.4864])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_pndm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- unet = self.dummy_cond_unet
- scheduler = PNDMScheduler(skip_prk_steps=True)
- vae = self.dummy_vae
- bert = self.dummy_text_encoder
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- # make sure here that pndm scheduler skips prk
- sd_pipe = StableDiffusionPipeline(
- unet=unet,
- scheduler=scheduler,
- vae=vae,
- text_encoder=bert,
- tokenizer=tokenizer,
- safety_checker=None,
- feature_extractor=self.dummy_extractor,
- )
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A painting of a squirrel eating a burger"
- generator = torch.Generator(device=device).manual_seed(0)
- output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
-
- image = output.images
-
- generator = torch.Generator(device=device).manual_seed(0)
- image_from_tuple = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=6.0,
- num_inference_steps=2,
- output_type="np",
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.5125, 0.5716, 0.4828, 0.5060, 0.5650, 0.4768, 0.5185, 0.4895, 0.4993])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_no_safety_checker(self):
- pipe = StableDiffusionPipeline.from_pretrained(
- "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None
- )
- assert isinstance(pipe, StableDiffusionPipeline)
- assert isinstance(pipe.scheduler, LMSDiscreteScheduler)
- assert pipe.safety_checker is None
-
- image = pipe("example prompt", num_inference_steps=2).images[0]
- assert image is not None
-
- # check that there's no error when saving a pipeline with one of the models being None
- with tempfile.TemporaryDirectory() as tmpdirname:
- pipe.save_pretrained(tmpdirname)
- pipe = StableDiffusionPipeline.from_pretrained(tmpdirname)
-
- # sanity check that the pipeline still works
- assert pipe.safety_checker is None
- image = pipe("example prompt", num_inference_steps=2).images[0]
- assert image is not None
-
- @require_accelerator
- def test_stable_diffusion_fp16(self):
- """Test that stable diffusion works with fp16"""
- unet = self.dummy_cond_unet
- scheduler = PNDMScheduler(skip_prk_steps=True)
- vae = self.dummy_vae
- bert = self.dummy_text_encoder
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- # put models in fp16
- unet = unet.half()
- vae = vae.half()
- bert = bert.half()
-
- # make sure here that pndm scheduler skips prk
- sd_pipe = StableDiffusionPipeline(
- unet=unet,
- scheduler=scheduler,
- vae=vae,
- text_encoder=bert,
- tokenizer=tokenizer,
- safety_checker=None,
- feature_extractor=self.dummy_extractor,
- )
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A painting of a squirrel eating a burger"
- image = sd_pipe([prompt], num_inference_steps=2, output_type="np").images
-
- assert image.shape == (1, 64, 64, 3)
-
-
-@nightly
-@require_torch_gpu
-class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- def test_harm_safe_stable_diffusion(self):
- sd_pipe = StableDiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None
- )
- sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = (
- "portrait of girl with smokey eyes makeup in abandoned hotel, grange clothes, redshift, wide high angle"
- " coloured polaroid photograph with flash, kodak film, hyper real, stunning moody cinematography, with"
- " anamorphic lenses, by maripol, fallen angels by wong kar - wai, style of suspiria and neon demon and"
- " children from bahnhof zoo, detailed "
- )
- seed = 4003660346
- guidance_scale = 7
-
- # without safety guidance (sld_guidance_scale = 0)
- generator = torch.manual_seed(seed)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- sld_guidance_scale=0,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [0.2278, 0.2231, 0.2249, 0.2333, 0.2303, 0.1885, 0.2273, 0.2144, 0.2176]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- # without safety guidance (strong configuration)
- generator = torch.manual_seed(seed)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- sld_guidance_scale=2000,
- sld_warmup_steps=7,
- sld_threshold=0.025,
- sld_momentum_scale=0.5,
- sld_mom_beta=0.7,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [0.2383, 0.2276, 0.236, 0.2192, 0.2186, 0.2053, 0.1971, 0.1901, 0.1719]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_nudity_safe_stable_diffusion(self):
- sd_pipe = StableDiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None
- )
- sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = "padme amidala taking a bath artwork, safe for work, no nudity"
- seed = 2734971755
- guidance_scale = 7
-
- generator = torch.manual_seed(seed)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- sld_guidance_scale=0,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [0.3502, 0.3622, 0.3396, 0.3642, 0.3478, 0.3318, 0.35, 0.3348, 0.3297]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- generator = torch.manual_seed(seed)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- sld_guidance_scale=2000,
- sld_warmup_steps=7,
- sld_threshold=0.025,
- sld_momentum_scale=0.5,
- sld_mom_beta=0.7,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [0.5531, 0.5206, 0.4895, 0.5156, 0.5182, 0.4751, 0.4802, 0.4803, 0.4443]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_nudity_safetychecker_safe_stable_diffusion(self):
- sd_pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = (
- "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c."
- " leyendecker"
- )
- seed = 1044355234
- guidance_scale = 12
-
- generator = torch.manual_seed(seed)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- sld_guidance_scale=0,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-7
-
- generator = torch.manual_seed(seed)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- sld_guidance_scale=2000,
- sld_warmup_steps=7,
- sld_threshold=0.025,
- sld_momentum_scale=0.5,
- sld_mom_beta=0.7,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.5818, 0.6285, 0.6835, 0.6019, 0.625, 0.6754, 0.6096, 0.6334, 0.6561])
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/stable_diffusion_sag/test_stable_diffusion_sag.py b/tests/pipelines/stable_diffusion_sag/test_stable_diffusion_sag.py
deleted file mode 100644
index bd1ba268d2d9..000000000000
--- a/tests/pipelines/stable_diffusion_sag/test_stable_diffusion_sag.py
+++ /dev/null
@@ -1,239 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- DEISMultistepScheduler,
- DPMSolverMultistepScheduler,
- EulerDiscreteScheduler,
- StableDiffusionSAGPipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import (
- IPAdapterTesterMixin,
- PipelineFromPipeTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
-)
-
-
-enable_full_determinism()
-
-
-class StableDiffusionSAGPipelineFastTests(
- IPAdapterTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
- PipelineFromPipeTesterMixin,
- unittest.TestCase,
-):
- pipeline_class = StableDiffusionSAGPipeline
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(4, 8),
- layers_per_block=2,
- sample_size=8,
- norm_num_groups=1,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=8,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[4, 8],
- norm_num_groups=1,
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=8,
- num_hidden_layers=2,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- pad_token_id=1,
- vocab_size=1000,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- "image_encoder": None,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": ".",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 1.0,
- "sag_scale": 1.0,
- "output_type": "np",
- }
- return inputs
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(expected_max_diff=3e-3)
-
- @unittest.skip("Not necessary to test here.")
- def test_xformers_attention_forwardGenerator_pass(self):
- pass
-
- def test_pipeline_different_schedulers(self):
- pipeline = self.pipeline_class(**self.get_dummy_components())
- inputs = self.get_dummy_inputs("cpu")
-
- expected_image_size = (16, 16, 3)
- for scheduler_cls in [DDIMScheduler, DEISMultistepScheduler, DPMSolverMultistepScheduler]:
- pipeline.scheduler = scheduler_cls.from_config(pipeline.scheduler.config)
- image = pipeline(**inputs).images[0]
-
- shape = image.shape
- assert shape == expected_image_size
-
- pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
-
- with self.assertRaises(ValueError):
- # Karras schedulers are not supported
- image = pipeline(**inputs).images[0]
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
-
-
-@nightly
-@require_torch_gpu
-class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- def test_stable_diffusion_1(self):
- sag_pipe = StableDiffusionSAGPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
- sag_pipe = sag_pipe.to(torch_device)
- sag_pipe.set_progress_bar_config(disable=None)
-
- prompt = "."
- generator = torch.manual_seed(0)
- output = sag_pipe(
- [prompt], generator=generator, guidance_scale=7.5, sag_scale=1.0, num_inference_steps=20, output_type="np"
- )
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.1568, 0.1738, 0.1695, 0.1693, 0.1507, 0.1705, 0.1547, 0.1751, 0.1949])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
-
- def test_stable_diffusion_2(self):
- sag_pipe = StableDiffusionSAGPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
- sag_pipe = sag_pipe.to(torch_device)
- sag_pipe.set_progress_bar_config(disable=None)
-
- prompt = "."
- generator = torch.manual_seed(0)
- output = sag_pipe(
- [prompt], generator=generator, guidance_scale=7.5, sag_scale=1.0, num_inference_steps=20, output_type="np"
- )
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.3459, 0.2876, 0.2537, 0.3002, 0.2671, 0.2160, 0.3026, 0.2262, 0.2371])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
-
- def test_stable_diffusion_2_non_square(self):
- sag_pipe = StableDiffusionSAGPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
- sag_pipe = sag_pipe.to(torch_device)
- sag_pipe.set_progress_bar_config(disable=None)
-
- prompt = "."
- generator = torch.manual_seed(0)
- output = sag_pipe(
- [prompt],
- width=768,
- height=512,
- generator=generator,
- guidance_scale=7.5,
- sag_scale=1.0,
- num_inference_steps=20,
- output_type="np",
- )
-
- image = output.images
-
- assert image.shape == (1, 512, 768, 3)
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
index c68cdf67036a..b318a505e9db 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -34,7 +34,9 @@
UNet2DConditionModel,
UniPCMultistepScheduler,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
load_image,
numpy_cosine_similarity_distance,
@@ -42,7 +44,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
@@ -242,15 +243,15 @@ def test_stable_diffusion_ays(self):
inputs["sigmas"] = sigma_schedule
output_sigmas = sd_pipe(**inputs).images
- assert (
- np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3
- ), "ays timesteps and ays sigmas should have the same outputs"
- assert (
- np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3
- ), "use ays timesteps should have different outputs"
- assert (
- np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3
- ), "use ays sigmas should have different outputs"
+ assert np.abs(output_sigmas.flatten() - output_ts.flatten()).max() < 1e-3, (
+ "ays timesteps and ays sigmas should have the same outputs"
+ )
+ assert np.abs(output.flatten() - output_ts.flatten()).max() > 1e-3, (
+ "use ays timesteps should have different outputs"
+ )
+ assert np.abs(output.flatten() - output_sigmas.flatten()).max() > 1e-3, (
+ "use ays sigmas should have different outputs"
+ )
def test_ip_adapter(self):
expected_pipe_slice = None
@@ -742,9 +743,9 @@ def new_step(self, *args, **kwargs):
inputs_1 = {**inputs, **{"denoising_end": split_1, "output_type": "latent"}}
latents = pipe_1(**inputs_1).images[0]
- assert (
- expected_steps_1 == done_steps
- ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ assert expected_steps_1 == done_steps, (
+ f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ )
with self.assertRaises(ValueError) as cm:
inputs_2 = {
@@ -771,9 +772,9 @@ def new_step(self, *args, **kwargs):
pipe_3(**inputs_3).images[0]
assert expected_steps_3 == done_steps[len(expected_steps_1) + len(expected_steps_2) :]
- assert (
- expected_steps == done_steps
- ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ assert expected_steps == done_steps, (
+ f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ )
for steps in [7, 11, 20]:
for split_1, split_2 in zip([0.19, 0.32], [0.81, 0.68]):
@@ -940,12 +941,12 @@ class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_stable_diffusion_lcm(self):
torch.manual_seed(0)
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
index 07333623867e..3d72270dda5c 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -32,12 +32,12 @@
UNet2DConditionModel,
)
from diffusers.utils import logging
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
-
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import (
IPAdapterTesterMixin,
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
index 9a141634a364..c5499847069f 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -38,7 +38,9 @@
StableDiffusionXLImg2ImgPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_image,
@@ -46,7 +48,6 @@
slow,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
@@ -670,12 +671,12 @@ class StableDiffusionXLImg2ImgPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_stable_diffusion_xl_img2img_playground(self):
torch.manual_seed(0)
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
index 66ae581a0529..d3f5779c7633 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -41,14 +41,14 @@
UNet2DConditionModel,
UniPCMultistepScheduler,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
require_torch_accelerator,
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
@@ -585,9 +585,9 @@ def new_step(self, *args, **kwargs):
inputs_1 = {**inputs, **{"denoising_end": split_1, "output_type": "latent"}}
latents = pipe_1(**inputs_1).images[0]
- assert (
- expected_steps_1 == done_steps
- ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ assert expected_steps_1 == done_steps, (
+ f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ )
inputs_2 = {
**inputs,
@@ -601,9 +601,9 @@ def new_step(self, *args, **kwargs):
pipe_3(**inputs_3).images[0]
assert expected_steps_3 == done_steps[len(expected_steps_1) + len(expected_steps_2) :]
- assert (
- expected_steps == done_steps
- ), f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ assert expected_steps == done_steps, (
+ f"Failure with {scheduler_cls.__name__} and {num_steps} and {split_1} and {split_2}"
+ )
for steps in [7, 11, 20]:
for split_1, split_2 in zip([0.19, 0.32], [0.81, 0.68]):
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py
index 79d38c4a7b43..20a03583e7a9 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 Harutatsu Akiyama and HuggingFace Inc.
+# Copyright 2025 Harutatsu Akiyama and HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,8 +29,8 @@
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instruct_pix2pix import (
StableDiffusionXLInstructPix2PixPipeline,
)
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
+from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_k_diffusion.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_k_diffusion.py
deleted file mode 100644
index 46f7d0e7b0b4..000000000000
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_k_diffusion.py
+++ /dev/null
@@ -1,146 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-
-from diffusers import StableDiffusionXLKDiffusionPipeline
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- require_torch_accelerator,
- slow,
- torch_device,
-)
-
-
-enable_full_determinism()
-
-
-@slow
-@require_torch_accelerator
-class StableDiffusionXLKPipelineIntegrationTests(unittest.TestCase):
- dtype = torch.float16
-
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_stable_diffusion_xl(self):
- sd_pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=self.dtype
- )
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- sd_pipe.set_scheduler("sample_euler")
-
- prompt = "A painting of a squirrel eating a burger"
- generator = torch.manual_seed(0)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=9.0,
- num_inference_steps=2,
- height=512,
- width=512,
- output_type="np",
- )
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.5420, 0.5038, 0.2439, 0.5371, 0.4660, 0.1906, 0.5221, 0.4290, 0.2566])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_karras_sigmas(self):
- sd_pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=self.dtype
- )
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- sd_pipe.set_scheduler("sample_dpmpp_2m")
-
- prompt = "A painting of a squirrel eating a burger"
- generator = torch.manual_seed(0)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=7.5,
- num_inference_steps=2,
- output_type="np",
- use_karras_sigmas=True,
- height=512,
- width=512,
- )
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.6418, 0.6424, 0.6462, 0.6271, 0.6314, 0.6295, 0.6249, 0.6339, 0.6335])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_noise_sampler_seed(self):
- sd_pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=self.dtype
- )
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- sd_pipe.set_scheduler("sample_dpmpp_sde")
-
- prompt = "A painting of a squirrel eating a burger"
- seed = 0
- images1 = sd_pipe(
- [prompt],
- generator=torch.manual_seed(seed),
- noise_sampler_seed=seed,
- guidance_scale=9.0,
- num_inference_steps=2,
- output_type="np",
- height=512,
- width=512,
- ).images
- images2 = sd_pipe(
- [prompt],
- generator=torch.manual_seed(seed),
- noise_sampler_seed=seed,
- guidance_scale=9.0,
- num_inference_steps=2,
- output_type="np",
- height=512,
- width=512,
- ).images
- assert images1.shape == (1, 512, 512, 3)
- assert images2.shape == (1, 512, 512, 3)
- assert np.abs(images1.flatten() - images2.flatten()).max() < 1e-2
diff --git a/tests/pipelines/stable_unclip/test_stable_unclip.py b/tests/pipelines/stable_unclip/test_stable_unclip.py
index 8cf103dffd56..8923c2f63cee 100644
--- a/tests/pipelines/stable_unclip/test_stable_unclip.py
+++ b/tests/pipelines/stable_unclip/test_stable_unclip.py
@@ -13,8 +13,18 @@
UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
-from diffusers.utils.testing_utils import enable_full_determinism, load_numpy, nightly, require_torch_gpu, torch_device
+from ...testing_utils import (
+ backend_empty_cache,
+ backend_max_memory_allocated,
+ backend_reset_max_memory_allocated,
+ backend_reset_peak_memory_stats,
+ enable_full_determinism,
+ load_numpy,
+ nightly,
+ require_torch_accelerator,
+ torch_device,
+)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineKarrasSchedulerTesterMixin,
@@ -190,19 +200,19 @@ def test_encode_prompt_works_in_isolation(self):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
class StableUnCLIPPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_stable_unclip(self):
expected_image = load_numpy(
@@ -217,7 +227,7 @@ def test_stable_unclip(self):
pipe.enable_sequential_cpu_offload()
generator = torch.Generator(device="cpu").manual_seed(0)
- output = pipe("anime turle", generator=generator, output_type="np")
+ output = pipe("anime turtle", generator=generator, output_type="np")
image = output.images[0]
@@ -226,9 +236,9 @@ def test_stable_unclip(self):
assert_mean_pixel_difference(image, expected_image)
def test_stable_unclip_pipeline_with_sequential_cpu_offloading(self):
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
pipe = StableUnCLIPPipeline.from_pretrained("fusing/stable-unclip-2-1-l", torch_dtype=torch.float16)
pipe.set_progress_bar_config(disable=None)
@@ -242,6 +252,6 @@ def test_stable_unclip_pipeline_with_sequential_cpu_offloading(self):
output_type="np",
)
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
# make sure that less than 7 GB is allocated
assert mem_bytes < 7 * 10**9
diff --git a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py
index 176b6954d616..e7a0fbccef67 100644
--- a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py
+++ b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py
@@ -17,17 +17,21 @@
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ backend_empty_cache,
+ backend_max_memory_allocated,
+ backend_reset_max_memory_allocated,
+ backend_reset_peak_memory_stats,
enable_full_determinism,
floats_tensor,
load_image,
load_numpy,
nightly,
- require_torch_gpu,
+ require_torch_accelerator,
skip_mps,
torch_device,
)
-
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import (
PipelineKarrasSchedulerTesterMixin,
@@ -213,19 +217,19 @@ def test_encode_prompt_works_in_isolation(self):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
class StableUnCLIPImg2ImgPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_stable_unclip_l_img2img(self):
input_image = load_image(
@@ -246,7 +250,7 @@ def test_stable_unclip_l_img2img(self):
pipe.enable_sequential_cpu_offload()
generator = torch.Generator(device="cpu").manual_seed(0)
- output = pipe(input_image, "anime turle", generator=generator, output_type="np")
+ output = pipe(input_image, "anime turtle", generator=generator, output_type="np")
image = output.images[0]
@@ -273,7 +277,7 @@ def test_stable_unclip_h_img2img(self):
pipe.enable_sequential_cpu_offload()
generator = torch.Generator(device="cpu").manual_seed(0)
- output = pipe(input_image, "anime turle", generator=generator, output_type="np")
+ output = pipe(input_image, "anime turtle", generator=generator, output_type="np")
image = output.images[0]
@@ -286,9 +290,9 @@ def test_stable_unclip_img2img_pipeline_with_sequential_cpu_offloading(self):
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_unclip/turtle.png"
)
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
+ backend_empty_cache(torch_device)
+ backend_reset_max_memory_allocated(torch_device)
+ backend_reset_peak_memory_stats(torch_device)
pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(
"fusing/stable-unclip-2-1-h-img2img", torch_dtype=torch.float16
@@ -304,6 +308,6 @@ def test_stable_unclip_img2img_pipeline_with_sequential_cpu_offloading(self):
output_type="np",
)
- mem_bytes = torch.cuda.max_memory_allocated()
+ mem_bytes = backend_max_memory_allocated(torch_device)
# make sure that less than 7 GB is allocated
assert mem_bytes < 7 * 10**9
diff --git a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py
index f77a5b1620d2..52595f7a8cd9 100644
--- a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py
+++ b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py
@@ -20,7 +20,8 @@
)
from diffusers.utils import load_image, logging
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
CaptureLogger,
backend_empty_cache,
enable_full_determinism,
@@ -32,7 +33,6 @@
slow,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py
index 423c2b8ab146..6d9e68197976 100644
--- a/tests/pipelines/test_pipeline_utils.py
+++ b/tests/pipelines/test_pipeline_utils.py
@@ -19,7 +19,8 @@
UNet2DConditionModel,
)
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings
-from diffusers.utils.testing_utils import require_torch_gpu, torch_device
+
+from ..testing_utils import require_torch_accelerator, torch_device
class IsSafetensorsCompatibleTests(unittest.TestCase):
@@ -87,21 +88,24 @@ def test_all_is_compatible_variant(self):
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
- self.assertTrue(is_safetensors_compatible(filenames))
+ self.assertFalse(is_safetensors_compatible(filenames))
+ self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_model_is_compatible_variant(self):
filenames = [
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
- self.assertTrue(is_safetensors_compatible(filenames))
+ self.assertFalse(is_safetensors_compatible(filenames))
+ self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_model_is_compatible_variant_mixed(self):
filenames = [
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
- self.assertTrue(is_safetensors_compatible(filenames))
+ self.assertFalse(is_safetensors_compatible(filenames))
+ self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_model_is_not_compatible_variant(self):
filenames = [
@@ -121,7 +125,8 @@ def test_transformer_model_is_compatible_variant(self):
"text_encoder/pytorch_model.fp16.bin",
"text_encoder/model.fp16.safetensors",
]
- self.assertTrue(is_safetensors_compatible(filenames))
+ self.assertFalse(is_safetensors_compatible(filenames))
+ self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_transformer_model_is_not_compatible_variant(self):
filenames = [
@@ -145,7 +150,8 @@ def test_transformer_model_is_compatible_variant_extra_folder(self):
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
- self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}))
+ self.assertFalse(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}))
+ self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}, variant="fp16"))
def test_transformer_model_is_not_compatible_variant_extra_folder(self):
filenames = [
@@ -173,7 +179,8 @@ def test_transformers_is_compatible_variant_sharded(self):
"text_encoder/model.fp16-00001-of-00002.safetensors",
"text_encoder/model.fp16-00001-of-00002.safetensors",
]
- self.assertTrue(is_safetensors_compatible(filenames))
+ self.assertFalse(is_safetensors_compatible(filenames))
+ self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_is_compatible_sharded(self):
filenames = [
@@ -189,13 +196,15 @@ def test_diffusers_is_compatible_variant_sharded(self):
"unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
"unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
]
- self.assertTrue(is_safetensors_compatible(filenames))
+ self.assertFalse(is_safetensors_compatible(filenames))
+ self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_is_compatible_only_variants(self):
filenames = [
"unet/diffusion_pytorch_model.fp16.safetensors",
]
- self.assertTrue(is_safetensors_compatible(filenames))
+ self.assertFalse(is_safetensors_compatible(filenames))
+ self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_is_compatible_no_components(self):
filenames = [
@@ -209,6 +218,20 @@ def test_diffusers_is_compatible_no_components_only_variants(self):
]
self.assertFalse(is_safetensors_compatible(filenames))
+ def test_is_compatible_mixed_variants(self):
+ filenames = [
+ "unet/diffusion_pytorch_model.fp16.safetensors",
+ "vae/diffusion_pytorch_model.safetensors",
+ ]
+ self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
+
+ def test_is_compatible_variant_and_non_safetensors(self):
+ filenames = [
+ "unet/diffusion_pytorch_model.fp16.safetensors",
+ "vae/diffusion_pytorch_model.bin",
+ ]
+ self.assertFalse(is_safetensors_compatible(filenames, variant="fp16"))
+
class VariantCompatibleSiblingsTest(unittest.TestCase):
def test_only_non_variants_downloaded(self):
@@ -828,9 +851,9 @@ def test_video_to_video(self):
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
-@require_torch_gpu
+@require_torch_accelerator
class PipelineDeviceAndDtypeStabilityTests(unittest.TestCase):
- expected_pipe_device = torch.device("cuda:0")
+ expected_pipe_device = torch.device(f"{torch_device}:0")
expected_pipe_dtype = torch.float64
def get_dummy_components_image_generation(self):
@@ -899,8 +922,8 @@ def test_deterministic_device(self):
pipe.to(device=torch_device, dtype=torch.float32)
pipe.unet.to(device="cpu")
- pipe.vae.to(device="cuda")
- pipe.text_encoder.to(device="cuda:0")
+ pipe.vae.to(device=torch_device)
+ pipe.text_encoder.to(device=f"{torch_device}:0")
pipe_device = pipe.device
diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py
index 48c89d399216..a17db3ff0c5a 100644
--- a/tests/pipelines/test_pipelines.py
+++ b/tests/pipelines/test_pipelines.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,23 +17,26 @@
import json
import os
import random
+import re
import shutil
import sys
import tempfile
import traceback
import unittest
import unittest.mock as mock
+import warnings
import numpy as np
import PIL.Image
+import pytest
import requests_mock
import safetensors.torch
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
+from huggingface_hub.utils import HfHubHTTPError
from parameterized import parameterized
from PIL import Image
-from requests.exceptions import HTTPError
from transformers import CLIPImageProcessor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import (
@@ -60,11 +63,10 @@
)
from diffusers.pipelines.pipeline_utils import _get_pipeline_class
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
-from diffusers.utils import (
- CONFIG_NAME,
- WEIGHTS_NAME,
-)
-from diffusers.utils.testing_utils import (
+from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, is_transformers_version
+from diffusers.utils.torch_utils import is_compiled_module
+
+from ..testing_utils import (
CaptureLogger,
backend_empty_cache,
enable_full_determinism,
@@ -78,6 +80,8 @@
require_flax,
require_hf_hub_version_greater,
require_onnxruntime,
+ require_peft_backend,
+ require_peft_version_greater,
require_torch_2,
require_torch_accelerator,
require_transformers_version_greater,
@@ -85,7 +89,6 @@
slow,
torch_device,
)
-from diffusers.utils.torch_utils import is_compiled_module
enable_full_determinism()
@@ -163,9 +166,9 @@ def test_one_request_upon_cached(self):
download_requests = [r.method for r in m.request_history]
assert download_requests.count("HEAD") == 15, "15 calls to files"
assert download_requests.count("GET") == 17, "15 calls to files + model_info + model_index.json"
- assert (
- len(download_requests) == 32
- ), "2 calls per file (15 files) + send_telemetry, model_info and model_index.json"
+ assert len(download_requests) == 32, (
+ "2 calls per file (15 files) + send_telemetry, model_info and model_index.json"
+ )
with requests_mock.mock(real_http=True) as m:
DiffusionPipeline.download(
@@ -175,9 +178,9 @@ def test_one_request_upon_cached(self):
cache_requests = [r.method for r in m.request_history]
assert cache_requests.count("HEAD") == 1, "model_index.json is only HEAD"
assert cache_requests.count("GET") == 1, "model info is only GET"
- assert (
- len(cache_requests) == 2
- ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
+ assert len(cache_requests) == 2, (
+ "We should call only `model_info` to check for _commit hash and `send_telemetry`"
+ )
def test_less_downloads_passed_object(self):
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -213,9 +216,9 @@ def test_less_downloads_passed_object_calls(self):
assert download_requests.count("HEAD") == 13, "13 calls to files"
# 17 - 2 because no call to config or model file for `safety_checker`
assert download_requests.count("GET") == 15, "13 calls to files + model_info + model_index.json"
- assert (
- len(download_requests) == 28
- ), "2 calls per file (13 files) + send_telemetry, model_info and model_index.json"
+ assert len(download_requests) == 28, (
+ "2 calls per file (13 files) + send_telemetry, model_info and model_index.json"
+ )
with requests_mock.mock(real_http=True) as m:
DiffusionPipeline.download(
@@ -225,9 +228,9 @@ def test_less_downloads_passed_object_calls(self):
cache_requests = [r.method for r in m.request_history]
assert cache_requests.count("HEAD") == 1, "model_index.json is only HEAD"
assert cache_requests.count("GET") == 1, "model info is only GET"
- assert (
- len(cache_requests) == 2
- ), "We should call only `model_info` to check for _commit hash and `send_telemetry`"
+ assert len(cache_requests) == 2, (
+ "We should call only `model_info` to check for _commit hash and `send_telemetry`"
+ )
def test_download_only_pytorch(self):
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -425,7 +428,7 @@ def test_cached_files_are_used_when_no_internet(self):
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = {}
- response_mock.raise_for_status.side_effect = HTTPError
+ response_mock.raise_for_status.side_effect = HfHubHTTPError("Server down", response=mock.Mock())
response_mock.json.return_value = {}
# Download this model to make sure it's in the cache.
@@ -452,7 +455,7 @@ def test_local_files_only_are_used_when_no_internet(self):
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = {}
- response_mock.raise_for_status.side_effect = HTTPError
+ response_mock.raise_for_status.side_effect = HfHubHTTPError("Server down", response=mock.Mock())
response_mock.json.return_value = {}
# first check that with local files only the pipeline can only be used if cached
@@ -579,25 +582,23 @@ def test_download_variants_with_sharded_checkpoints(self):
assert not any(f.endswith(unexpected_ext) for f in files)
assert all(variant in f for f in model_files if f.endswith(model_ext) and variant is not None)
+ @pytest.mark.xfail(condition=is_transformers_version(">", "4.56.2"), reason="Some import error", strict=False)
def test_download_legacy_variants_with_sharded_ckpts_raises_warning(self):
repo_id = "hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds"
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
deprecated_warning_msg = "Warning: The repository contains sharded checkpoints for variant"
- for is_local in [True, False]:
- with CaptureLogger(logger) as cap_logger:
- with tempfile.TemporaryDirectory() as tmpdirname:
- local_repo_id = repo_id
- if is_local:
- local_repo_id = snapshot_download(repo_id, cache_dir=tmpdirname)
+ with CaptureLogger(logger) as cap_logger:
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ local_repo_id = snapshot_download(repo_id, cache_dir=tmpdirname)
- _ = DiffusionPipeline.from_pretrained(
- local_repo_id,
- safety_checker=None,
- variant="fp16",
- use_safetensors=True,
- )
- assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs"
+ _ = DiffusionPipeline.from_pretrained(
+ local_repo_id,
+ safety_checker=None,
+ variant="fp16",
+ use_safetensors=True,
+ )
+ assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs"
def test_download_safetensors_only_variant_exists_for_model(self):
variant = None
@@ -612,7 +613,7 @@ def test_download_safetensors_only_variant_exists_for_model(self):
variant=variant,
use_safetensors=use_safetensors,
)
- assert "Error no file name" in str(error_context.exception)
+ assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)
# text encoder has fp16 variants so we can load it
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -628,6 +629,7 @@ def test_download_safetensors_only_variant_exists_for_model(self):
# https://huggingface.co/hf-internal-testing/stable-diffusion-broken-variants/tree/main/unet
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
+ @pytest.mark.xfail(condition=is_transformers_version(">", "4.56.2"), reason="Some import error", strict=False)
def test_download_bin_only_variant_exists_for_model(self):
variant = None
use_safetensors = False
@@ -671,8 +673,9 @@ def test_download_safetensors_variant_does_not_exist_for_model(self):
use_safetensors=use_safetensors,
)
- assert "Error no file name" in str(error_context.exception)
+ assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)
+ @pytest.mark.xfail(condition=is_transformers_version(">", "4.56.2"), reason="Some import error", strict=False)
def test_download_bin_variant_does_not_exist_for_model(self):
variant = "no_ema"
use_safetensors = False
@@ -688,6 +691,7 @@ def test_download_bin_variant_does_not_exist_for_model(self):
)
assert "Error no file name" in str(error_context.exception)
+ @pytest.mark.xfail(condition=is_transformers_version(">", "4.56.2"), reason="Some import error", strict=False)
def test_local_save_load_index(self):
prompt = "hello"
for variant in [None, "fp16"]:
@@ -1104,6 +1108,21 @@ def test_remote_auto_custom_pipe(self):
assert images.shape == (1, 64, 64, 3)
+ def test_remote_custom_pipe_with_dot_in_name(self):
+ # make sure that trust remote code has to be passed
+ with self.assertRaises(ValueError):
+ pipeline = DiffusionPipeline.from_pretrained("akasharidas/ddpm-cifar10-32-dot.in.name")
+
+ pipeline = DiffusionPipeline.from_pretrained("akasharidas/ddpm-cifar10-32-dot.in.name", trust_remote_code=True)
+
+ assert pipeline.__class__.__name__ == "CustomPipeline"
+
+ pipeline = pipeline.to(torch_device)
+ images, output_str = pipeline(num_inference_steps=2, output_type="np")
+
+ assert images[0].shape == (1, 32, 32, 3)
+ assert output_str == "This is a test"
+
def test_local_custom_pipeline_repo(self):
local_custom_pipeline_path = get_tests_dir("fixtures/custom_pipeline")
pipeline = DiffusionPipeline.from_pretrained(
@@ -1202,13 +1221,13 @@ def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def dummy_image(self):
batch_size = 1
@@ -1567,6 +1586,7 @@ def test_save_safe_serialization(self):
assert pipeline.scheduler is not None
assert pipeline.feature_extractor is not None
+ @pytest.mark.xfail(condition=is_transformers_version(">", "4.56.2"), reason="Some import error", strict=False)
def test_no_pytorch_download_when_doing_safetensors(self):
# by default we don't download
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -1586,6 +1606,7 @@ def test_no_pytorch_download_when_doing_safetensors(self):
# pytorch does not
assert not os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin"))
+ @pytest.mark.xfail(condition=is_transformers_version(">", "4.56.2"), reason="Some import error", strict=False)
def test_no_safetensors_download_when_doing_pytorch(self):
use_safetensors = False
@@ -1812,7 +1833,12 @@ def test_pipe_same_device_id_offload(self):
feature_extractor=self.dummy_extractor,
)
- sd.enable_model_cpu_offload(gpu_id=5)
+ # `enable_model_cpu_offload` detects device type when not passed
+ # `enable_model_cpu_offload` raises ValueError if detected device is `cpu`
+ # This test only checks whether `_offload_gpu_id` is set correctly
+ # So the device passed can be any supported `torch.device` type
+ # This allows us to keep the test under `PipelineFastTests`
+ sd.enable_model_cpu_offload(gpu_id=5, device="cuda")
assert sd._offload_gpu_id == 5
sd.maybe_free_model_hooks()
assert sd._offload_gpu_id == 5
@@ -1866,6 +1892,7 @@ def test_dduf_raises_error_with_connected_pipeline(self):
"DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", load_connected_pipeline=True
)
+ @pytest.mark.xfail(condition=is_transformers_version(">", "4.56.2"), reason="Some import error", strict=False)
def test_wrong_model(self):
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
with self.assertRaises(ValueError) as error_context:
@@ -1988,7 +2015,9 @@ def test_from_save_pretrained(self):
reason="Torch Dynamo isn't yet supported for Python 3.12.",
)
def test_from_save_pretrained_dynamo(self):
- run_test_in_subprocess(test_case=self, target_func=_test_from_save_pretrained_dynamo, inputs=None)
+ torch.compiler.rest()
+ with torch._inductor.utils.fresh_inductor_cache():
+ run_test_in_subprocess(test_case=self, target_func=_test_from_save_pretrained_dynamo, inputs=None)
def test_from_pretrained_hub(self):
model_path = "google/ddpm-cifar10-32"
@@ -2175,3 +2204,264 @@ def test_ddpm_ddim_equality_batched(self):
# the values aren't exactly equal, but the images look the same visually
assert np.abs(ddpm_images - ddim_images).max() < 1e-1
+
+
+@slow
+@require_torch_2
+@require_torch_accelerator
+@require_peft_backend
+@require_peft_version_greater("0.14.0")
+@is_torch_compile
+class TestLoraHotSwappingForPipeline(unittest.TestCase):
+ """Test that hotswapping does not result in recompilation in a pipeline.
+
+ We're not extensively testing the hotswapping functionality since it is implemented in PEFT and is extensively
+ tested there. The goal of this test is specifically to ensure that hotswapping with diffusers does not require
+ recompilation.
+
+ See
+ https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252
+ for the analogous PEFT test.
+
+ """
+
+ def tearDown(self):
+ # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
+ # there will be recompilation errors, as torch caches the model when run in the same process.
+ super().tearDown()
+ torch.compiler.reset()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules):
+ # from diffusers test_models_unet_2d_condition.py
+ from peft import LoraConfig
+
+ unet_lora_config = LoraConfig(
+ r=lora_rank,
+ lora_alpha=lora_alpha,
+ target_modules=target_modules,
+ init_lora_weights=False,
+ use_dora=False,
+ )
+ return unet_lora_config
+
+ def get_lora_state_dicts(self, modules_to_save, adapter_name):
+ from peft import get_peft_model_state_dict
+
+ state_dicts = {}
+ for module_name, module in modules_to_save.items():
+ if module is not None:
+ state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(
+ module, adapter_name=adapter_name
+ )
+ return state_dicts
+
+ def get_dummy_input(self):
+ pipeline_inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "num_inference_steps": 5,
+ "guidance_scale": 6.0,
+ "output_type": "np",
+ "return_dict": False,
+ }
+ return pipeline_inputs
+
+ def check_pipeline_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
+ """
+ Check that hotswapping works on a pipeline.
+
+ Steps:
+ - create 2 LoRA adapters and save them
+ - load the first adapter
+ - hotswap the second adapter
+ - check that the outputs are correct
+ - optionally compile the model
+
+ Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
+ fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
+ fine.
+ """
+ # create 2 adapters with different ranks and alphas
+ dummy_input = self.get_dummy_input()
+ pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
+ alpha0, alpha1 = rank0, rank1
+ max_rank = max([rank0, rank1])
+ if target_modules1 is None:
+ target_modules1 = target_modules0[:]
+ lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules0)
+ lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules1)
+
+ torch.manual_seed(0)
+ pipeline.unet.add_adapter(lora_config0, adapter_name="adapter0")
+ output0_before = pipeline(**dummy_input, generator=torch.manual_seed(0))[0]
+
+ torch.manual_seed(1)
+ pipeline.unet.add_adapter(lora_config1, adapter_name="adapter1")
+ pipeline.unet.set_adapter("adapter1")
+ output1_before = pipeline(**dummy_input, generator=torch.manual_seed(0))[0]
+
+ # sanity check
+ tol = 1e-3
+ assert not np.allclose(output0_before, output1_before, atol=tol, rtol=tol)
+ assert not (output0_before == 0).all()
+ assert not (output1_before == 0).all()
+
+ with tempfile.TemporaryDirectory() as tmp_dirname:
+ # save the adapter checkpoints
+ lora0_state_dicts = self.get_lora_state_dicts({"unet": pipeline.unet}, adapter_name="adapter0")
+ StableDiffusionPipeline.save_lora_weights(
+ save_directory=os.path.join(tmp_dirname, "adapter0"), safe_serialization=True, **lora0_state_dicts
+ )
+ lora1_state_dicts = self.get_lora_state_dicts({"unet": pipeline.unet}, adapter_name="adapter1")
+ StableDiffusionPipeline.save_lora_weights(
+ save_directory=os.path.join(tmp_dirname, "adapter1"), safe_serialization=True, **lora1_state_dicts
+ )
+ del pipeline
+
+ # load the first adapter
+ pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
+ if do_compile or (rank0 != rank1):
+ # no need to prepare if the model is not compiled or if the ranks are identical
+ pipeline.enable_lora_hotswap(target_rank=max_rank)
+
+ file_name0 = os.path.join(tmp_dirname, "adapter0", "pytorch_lora_weights.safetensors")
+ file_name1 = os.path.join(tmp_dirname, "adapter1", "pytorch_lora_weights.safetensors")
+
+ pipeline.load_lora_weights(file_name0)
+ if do_compile:
+ pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead")
+
+ output0_after = pipeline(**dummy_input, generator=torch.manual_seed(0))[0]
+
+ # sanity check: still same result
+ assert np.allclose(output0_before, output0_after, atol=tol, rtol=tol)
+
+ # hotswap the 2nd adapter
+ pipeline.load_lora_weights(file_name1, hotswap=True, adapter_name="default_0")
+ output1_after = pipeline(**dummy_input, generator=torch.manual_seed(0))[0]
+
+ # sanity check: since it's the same LoRA, the results should be identical
+ assert np.allclose(output1_before, output1_after, atol=tol, rtol=tol)
+
+ @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
+ def test_hotswapping_pipeline(self, rank0, rank1):
+ self.check_pipeline_hotswap(
+ do_compile=False, rank0=rank0, rank1=rank1, target_modules0=["to_q", "to_k", "to_v", "to_out.0"]
+ )
+
+ @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
+ def test_hotswapping_compiled_pipline_linear(self, rank0, rank1):
+ # It's important to add this context to raise an error on recompilation
+ target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
+ with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
+ self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
+
+ @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
+ def test_hotswapping_compiled_pipline_conv2d(self, rank0, rank1):
+ # It's important to add this context to raise an error on recompilation
+ target_modules = ["conv", "conv1", "conv2"]
+ with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
+ self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
+
+ @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
+ def test_hotswapping_compiled_pipline_both_linear_and_conv2d(self, rank0, rank1):
+ # It's important to add this context to raise an error on recompilation
+ target_modules = ["to_q", "conv"]
+ with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
+ self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
+
+ def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
+ # ensure that enable_lora_hotswap is called before loading the first adapter
+ lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
+ pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
+ pipeline.unet.add_adapter(lora_config)
+ msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
+ with self.assertRaisesRegex(RuntimeError, msg):
+ pipeline.enable_lora_hotswap(target_rank=32)
+
+ def test_enable_lora_hotswap_called_after_adapter_added_warns(self):
+ # ensure that enable_lora_hotswap is called before loading the first adapter
+ from diffusers.loaders.peft import logger
+
+ lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
+ pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
+ pipeline.unet.add_adapter(lora_config)
+ msg = (
+ "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
+ )
+ with self.assertLogs(logger=logger, level="WARNING") as cm:
+ pipeline.enable_lora_hotswap(target_rank=32, check_compiled="warn")
+ assert any(msg in log for log in cm.output)
+
+ def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
+ # check possibility to ignore the error/warning
+ lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
+ pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
+ pipeline.unet.add_adapter(lora_config)
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always") # Capture all warnings
+ pipeline.enable_lora_hotswap(target_rank=32, check_compiled="warn")
+ self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")
+
+ def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
+ # check that wrong argument value raises an error
+ lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
+ pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
+ pipeline.unet.add_adapter(lora_config)
+ msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
+ with self.assertRaisesRegex(ValueError, msg):
+ pipeline.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
+
+ def test_hotswap_second_adapter_targets_more_layers_raises(self):
+ # check the error and log
+ from diffusers.loaders.peft import logger
+
+ # at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
+ target_modules0 = ["to_q"]
+ target_modules1 = ["to_q", "to_k"]
+ with self.assertRaises(RuntimeError): # peft raises RuntimeError
+ with self.assertLogs(logger=logger, level="ERROR") as cm:
+ self.check_pipeline_hotswap(
+ do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1
+ )
+ assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output)
+
+ def test_hotswap_component_not_supported_raises(self):
+ # right now, not some components don't support hotswapping, e.g. the text_encoder
+ from peft import LoraConfig
+
+ pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
+ lora_config0 = LoraConfig(target_modules=["q_proj"])
+ lora_config1 = LoraConfig(target_modules=["q_proj"])
+
+ pipeline.text_encoder.add_adapter(lora_config0, adapter_name="adapter0")
+ pipeline.text_encoder.add_adapter(lora_config1, adapter_name="adapter1")
+
+ with tempfile.TemporaryDirectory() as tmp_dirname:
+ # save the adapter checkpoints
+ lora0_state_dicts = self.get_lora_state_dicts(
+ {"text_encoder": pipeline.text_encoder}, adapter_name="adapter0"
+ )
+ StableDiffusionPipeline.save_lora_weights(
+ save_directory=os.path.join(tmp_dirname, "adapter0"), safe_serialization=True, **lora0_state_dicts
+ )
+ lora1_state_dicts = self.get_lora_state_dicts(
+ {"text_encoder": pipeline.text_encoder}, adapter_name="adapter1"
+ )
+ StableDiffusionPipeline.save_lora_weights(
+ save_directory=os.path.join(tmp_dirname, "adapter1"), safe_serialization=True, **lora1_state_dicts
+ )
+ del pipeline
+
+ # load the first adapter
+ pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
+ file_name0 = os.path.join(tmp_dirname, "adapter0", "pytorch_lora_weights.safetensors")
+ file_name1 = os.path.join(tmp_dirname, "adapter1", "pytorch_lora_weights.safetensors")
+
+ pipeline.load_lora_weights(file_name0)
+ msg = re.escape(
+ "At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`"
+ )
+ with self.assertRaisesRegex(ValueError, msg):
+ pipeline.load_lora_weights(file_name1, hotswap=True, adapter_name="default_0")
diff --git a/tests/pipelines/test_pipelines_auto.py b/tests/pipelines/test_pipelines_auto.py
index 561a9011c6ae..f3c639c367f7 100644
--- a/tests/pipelines/test_pipelines_auto.py
+++ b/tests/pipelines/test_pipelines_auto.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -35,7 +35,8 @@
AUTO_INPAINT_PIPELINES_MAPPING,
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
)
-from diffusers.utils.testing_utils import slow
+
+from ..testing_utils import slow
PRETRAINED_MODEL_REPO_MAPPING = OrderedDict(
diff --git a/tests/pipelines/test_pipelines_combined.py b/tests/pipelines/test_pipelines_combined.py
index adedd54fea40..fffc053bae3f 100644
--- a/tests/pipelines/test_pipelines_combined.py
+++ b/tests/pipelines/test_pipelines_combined.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index d3e39e363f91..7db5f4da89ca 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -9,6 +9,7 @@
import numpy as np
import PIL.Image
+import pytest
import torch
import torch.nn as nn
from huggingface_hub import ModelCard, delete_repo
@@ -33,9 +34,12 @@
)
from diffusers.hooks import apply_group_offloading
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
+from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
+from diffusers.hooks.taylorseer_cache import TaylorSeerCacheConfig
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
+from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import AttnProcessor
from diffusers.models.controlnets.controlnet_xs import UNetControlNetXSModel
from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
@@ -46,18 +50,6 @@
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.source_code_parsing_utils import ReturnNameVisitor
-from diffusers.utils.testing_utils import (
- CaptureLogger,
- backend_empty_cache,
- require_accelerate_version_greater,
- require_accelerator,
- require_hf_hub_version_greater,
- require_torch,
- require_torch_gpu,
- require_transformers_version_greater,
- skip_mps,
- torch_device,
-)
from ..models.autoencoders.vae import (
get_asym_autoencoder_kl_config,
@@ -71,6 +63,19 @@
create_ip_adapter_state_dict,
)
from ..others.test_utils import TOKEN, USER, is_staging_test
+from ..testing_utils import (
+ CaptureLogger,
+ backend_empty_cache,
+ numpy_cosine_similarity_distance,
+ require_accelerate_version_greater,
+ require_accelerator,
+ require_hf_hub_version_greater,
+ require_torch,
+ require_torch_accelerator,
+ require_transformers_version_greater,
+ skip_mps,
+ torch_device,
+)
def to_np(tensor):
@@ -96,6 +101,20 @@ def check_qkv_fusion_processors_exist(model):
return all(p.startswith("Fused") for p in proc_names)
+def check_qkv_fused_layers_exist(model, layer_names):
+ is_fused_submodules = []
+ for submodule in model.modules():
+ if not isinstance(submodule, AttentionModuleMixin) or not submodule._supports_qkv_fusion:
+ continue
+ is_fused_attribute_set = submodule.fused_projections
+ is_fused_layer = True
+ for layer in layer_names:
+ is_fused_layer = is_fused_layer and getattr(submodule, layer, None) is not None
+ is_fused = is_fused_attribute_set and is_fused_layer
+ is_fused_submodules.append(is_fused)
+ return all(is_fused_submodules)
+
+
class SDFunctionTesterMixin:
"""
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
@@ -191,12 +210,12 @@ def test_freeu(self):
inputs["output_type"] = "np"
output_no_freeu = pipe(**inputs)[0]
- assert not np.allclose(
- output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]
- ), "Enabling of FreeU should lead to different results."
- assert np.allclose(
- output, output_no_freeu, atol=1e-2
- ), f"Disabling of FreeU should lead to results similar to the default pipeline results but Max Abs Error={np.abs(output_no_freeu - output).max()}."
+ assert not np.allclose(output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]), (
+ "Enabling of FreeU should lead to different results."
+ )
+ assert np.allclose(output, output_no_freeu, atol=1e-2), (
+ f"Disabling of FreeU should lead to results similar to the default pipeline results but Max Abs Error={np.abs(output_no_freeu - output).max()}."
+ )
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -217,12 +236,12 @@ def test_fused_qkv_projections(self):
and hasattr(component, "original_attn_processors")
and component.original_attn_processors is not None
):
- assert check_qkv_fusion_processors_exist(
- component
- ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
- assert check_qkv_fusion_matches_attn_procs_length(
- component, component.original_attn_processors
- ), "Something wrong with the attention processors concerning the fused QKV projections."
+ assert check_qkv_fusion_processors_exist(component), (
+ "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ )
+ assert check_qkv_fusion_matches_attn_procs_length(component, component.original_attn_processors), (
+ "Something wrong with the attention processors concerning the fused QKV projections."
+ )
inputs = self.get_dummy_inputs(device)
inputs["return_dict"] = False
@@ -235,15 +254,15 @@ def test_fused_qkv_projections(self):
image_disabled = pipe(**inputs)[0]
image_slice_disabled = image_disabled[0, -3:, -3:, -1]
- assert np.allclose(
- original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
- ), "Fusion of QKV projections shouldn't affect the outputs."
- assert np.allclose(
- image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
- assert np.allclose(
- original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
- ), "Original outputs should match when fused QKV projections are disabled."
+ assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
+ "Fusion of QKV projections shouldn't affect the outputs."
+ )
+ assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ )
+ assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
+ "Original outputs should match when fused QKV projections are disabled."
+ )
class IPAdapterTesterMixin:
@@ -521,7 +540,8 @@ def _get_dummy_image_embeds(self, image_embed_dim: int = 768):
def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
inputs["negative_prompt"] = ""
- inputs["true_cfg_scale"] = 4.0
+ if "true_cfg_scale" in inspect.signature(self.pipeline_class.__call__).parameters:
+ inputs["true_cfg_scale"] = 4.0
inputs["output_type"] = "np"
inputs["return_dict"] = False
return inputs
@@ -542,7 +562,11 @@ def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N
components = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device)
pipe.set_progress_bar_config(disable=None)
- image_embed_dim = pipe.transformer.config.pooled_projection_dim
+ image_embed_dim = (
+ pipe.transformer.config.pooled_projection_dim
+ if hasattr(pipe.transformer.config, "pooled_projection_dim")
+ else 768
+ )
# forward pass without ip adapter
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
@@ -909,9 +933,9 @@ def test_from_pipe_consistent_forward_pass(self, expected_max_diff=1e-3):
for component in pipe_original.components.values():
if hasattr(component, "attn_processors"):
- assert all(
- type(proc) == AttnProcessor for proc in component.attn_processors.values()
- ), "`from_pipe` changed the attention processor in original pipeline."
+ assert all(type(proc) == AttnProcessor for proc in component.attn_processors.values()), (
+ "`from_pipe` changed the attention processor in original pipeline."
+ )
@require_accelerator
@require_accelerate_version_greater("0.14.0")
@@ -1111,12 +1135,22 @@ def callback_cfg_params(self) -> frozenset:
def setUp(self):
# clean up the VRAM before each test
super().setUp()
+ torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
+ # Skip tests for pipelines that inherit from DeprecatedPipelineMixin
+ from diffusers.pipelines.pipeline_utils import DeprecatedPipelineMixin
+
+ if hasattr(self, "pipeline_class") and issubclass(self.pipeline_class, DeprecatedPipelineMixin):
+ import pytest
+
+ pytest.skip(reason=f"Deprecated Pipeline: {self.pipeline_class.__name__}")
+
def tearDown(self):
# clean up the VRAM after each test in case of CUDA runtime errors
super().tearDown()
+ torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
@@ -1362,7 +1396,6 @@ def test_float16_inference(self, expected_max_diff=5e-2):
for component in pipe_fp16.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
-
pipe_fp16.to(torch_device, torch.float16)
pipe_fp16.set_progress_bar_config(disable=None)
@@ -1370,25 +1403,38 @@ def test_float16_inference(self, expected_max_diff=5e-2):
# Reset generator in case it is used inside dummy inputs
if "generator" in inputs:
inputs["generator"] = self.get_generator(0)
-
output = pipe(**inputs)[0]
fp16_inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is used inside dummy inputs
if "generator" in fp16_inputs:
fp16_inputs["generator"] = self.get_generator(0)
-
output_fp16 = pipe_fp16(**fp16_inputs)[0]
- max_diff = np.abs(to_np(output) - to_np(output_fp16)).max()
- self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.")
+ if isinstance(output, torch.Tensor):
+ output = output.cpu()
+ output_fp16 = output_fp16.cpu()
+
+ max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
+ assert max_diff < expected_max_diff
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
@require_accelerator
def test_save_load_float16(self, expected_max_diff=1e-2):
components = self.get_dummy_components()
for name, module in components.items():
- if hasattr(module, "half"):
+ # Account for components with _keep_in_fp32_modules
+ if hasattr(module, "_keep_in_fp32_modules") and module._keep_in_fp32_modules is not None:
+ for name, param in module.named_parameters():
+ if any(
+ module_to_keep_in_fp32 in name.split(".")
+ for module_to_keep_in_fp32 in module._keep_in_fp32_modules
+ ):
+ param.data = param.data.to(torch_device).to(torch.float32)
+ else:
+ param.data = param.data.to(torch_device).to(torch.float16)
+
+ elif hasattr(module, "half"):
components[name] = module.to(torch_device).half()
pipe = self.pipeline_class(**components)
@@ -1427,6 +1473,8 @@ def test_save_load_float16(self, expected_max_diff=1e-2):
def test_save_load_optional_components(self, expected_max_difference=1e-4):
if not hasattr(self.pipeline_class, "_optional_components"):
return
+ if not self.pipeline_class._optional_components:
+ return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
@@ -1483,8 +1531,8 @@ def test_to_device(self):
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
self.assertTrue(all(device == torch_device for device in model_devices))
- output_cuda = pipe(**self.get_dummy_inputs(torch_device))[0]
- self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
+ output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
+ self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
def test_to_dtype(self):
components = self.get_dummy_components()
@@ -1675,11 +1723,11 @@ def test_cpu_offload_forward_pass_twice(self, expected_max_diff=2e-4):
pipe.set_progress_bar_config(disable=None)
- pipe.enable_model_cpu_offload(device=torch_device)
+ pipe.enable_model_cpu_offload()
inputs = self.get_dummy_inputs(generator_device)
output_with_offload = pipe(**inputs)[0]
- pipe.enable_model_cpu_offload(device=torch_device)
+ pipe.enable_model_cpu_offload()
inputs = self.get_dummy_inputs(generator_device)
output_with_offload_twice = pipe(**inputs)[0]
@@ -2094,11 +2142,11 @@ def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=
with torch.no_grad():
encoded_prompt_outputs = pipe_with_just_text_encoder.encode_prompt(**encode_prompt_inputs)
- # Programatically determine the reutrn names of `encode_prompt.`
- ast_vistor = ReturnNameVisitor()
- encode_prompt_tree = ast_vistor.get_ast_tree(cls=self.pipeline_class)
- ast_vistor.visit(encode_prompt_tree)
- prompt_embed_kwargs = ast_vistor.return_names
+ # Programmatically determine the return names of `encode_prompt.`
+ ast_visitor = ReturnNameVisitor()
+ encode_prompt_tree = ast_visitor.get_ast_tree(cls=self.pipeline_class)
+ ast_visitor.visit(encode_prompt_tree)
+ prompt_embed_kwargs = ast_visitor.return_names
prompt_embeds_kwargs = dict(zip(prompt_embed_kwargs, encoded_prompt_outputs))
# Pack the outputs of `encode_prompt`.
@@ -2210,7 +2258,7 @@ def test_layerwise_casting_inference(self):
inputs = self.get_dummy_inputs(torch_device)
_ = pipe(**inputs)[0]
- @require_torch_gpu
+ @require_torch_accelerator
def test_group_offloading_inference(self):
if not self.test_group_offloading:
return
@@ -2224,7 +2272,7 @@ def create_pipe():
def enable_group_offload_on_component(pipe, group_offloading_kwargs):
# We intentionally don't test VAE's here. This is because some tests enable tiling on the VAE. If
- # tiling is enabled and a forward pass is run, when cuda streams are used, the execution order of
+ # tiling is enabled and a forward pass is run, when accelerator streams are used, the execution order of
# the layers is not traced correctly. This causes errors. For apply group offloading to VAE, a
# warmup forward pass (even with dummy small inputs) is recommended.
for component_name in [
@@ -2255,9 +2303,10 @@ def enable_group_offload_on_component(pipe, group_offloading_kwargs):
if hasattr(module, "_diffusers_hook")
)
)
- for component_name in ["vae", "vqvae"]:
- if hasattr(pipe, component_name):
- getattr(pipe, component_name).to(torch_device)
+ for component_name in ["vae", "vqvae", "image_encoder"]:
+ component = getattr(pipe, component_name, None)
+ if isinstance(component, torch.nn.Module):
+ component.to(torch_device)
def run_forward(pipe):
torch.manual_seed(0)
@@ -2289,7 +2338,6 @@ def test_torch_dtype_dict(self):
self.skipTest("No dummy components defined.")
pipe = self.pipeline_class(**components)
-
specified_key = next(iter(components.keys()))
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
@@ -2306,6 +2354,96 @@ def test_torch_dtype_dict(self):
f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
)
+ @require_torch_accelerator
+ def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-4):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ torch.manual_seed(0)
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["generator"] = torch.manual_seed(0)
+ out = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir)
+ loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map=torch_device)
+ for component in loaded_pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ inputs["generator"] = torch.manual_seed(0)
+ loaded_out = loaded_pipe(**inputs)[0]
+ max_diff = np.abs(to_np(out) - to_np(loaded_out)).max()
+ self.assertLess(max_diff, expected_max_difference)
+
+ @require_torch_accelerator
+ def test_pipeline_level_group_offloading_sanity_checks(self):
+ components = self.get_dummy_components()
+ pipe: DiffusionPipeline = self.pipeline_class(**components)
+
+ for name, component in pipe.components.items():
+ if hasattr(component, "_supports_group_offloading"):
+ if not component._supports_group_offloading:
+ pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")
+
+ module_names = sorted(
+ [name for name, component in pipe.components.items() if isinstance(component, torch.nn.Module)]
+ )
+ exclude_module_name = module_names[0]
+ offload_device = "cpu"
+ pipe.enable_group_offload(
+ onload_device=torch_device,
+ offload_device=offload_device,
+ offload_type="leaf_level",
+ exclude_modules=exclude_module_name,
+ )
+ excluded_module = getattr(pipe, exclude_module_name)
+ self.assertTrue(torch.device(excluded_module.device).type == torch.device(torch_device).type)
+
+ for name, component in pipe.components.items():
+ if name not in [exclude_module_name] and isinstance(component, torch.nn.Module):
+ # `component.device` prints the `onload_device` type. We should probably override the
+ # `device` property in `ModelMixin`.
+ component_device = next(component.parameters())[0].device
+ self.assertTrue(torch.device(component_device).type == torch.device(offload_device).type)
+
+ @require_torch_accelerator
+ def test_pipeline_level_group_offloading_inference(self, expected_max_difference=1e-4):
+ components = self.get_dummy_components()
+ pipe: DiffusionPipeline = self.pipeline_class(**components)
+
+ for name, component in pipe.components.items():
+ if hasattr(component, "_supports_group_offloading"):
+ if not component._supports_group_offloading:
+ pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")
+
+ # Regular inference.
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ torch.manual_seed(0)
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["generator"] = torch.manual_seed(0)
+ out = pipe(**inputs)[0]
+
+ pipe.to("cpu")
+ del pipe
+
+ # Inference with offloading
+ pipe: DiffusionPipeline = self.pipeline_class(**components)
+ offload_device = "cpu"
+ pipe.enable_group_offload(
+ onload_device=torch_device,
+ offload_device=offload_device,
+ offload_type="leaf_level",
+ )
+ pipe.set_progress_bar_config(disable=None)
+ inputs["generator"] = torch.manual_seed(0)
+ out_offload = pipe(**inputs)[0]
+
+ max_diff = np.abs(to_np(out) - to_np(out_offload)).max()
+ self.assertLess(max_diff, expected_max_difference)
+
@is_staging_test
class PipelinePushToHubTester(unittest.TestCase):
@@ -2569,12 +2707,12 @@ def test_pyramid_attention_broadcast_inference(self, expected_atol: float = 0.2)
image_slice_pab_disabled = output.flatten()
image_slice_pab_disabled = np.concatenate((image_slice_pab_disabled[:8], image_slice_pab_disabled[-8:]))
- assert np.allclose(
- original_image_slice, image_slice_pab_enabled, atol=expected_atol
- ), "PAB outputs should not differ much in specified timestep range."
- assert np.allclose(
- original_image_slice, image_slice_pab_disabled, atol=1e-4
- ), "Outputs from normal inference and after disabling cache should not differ."
+ assert np.allclose(original_image_slice, image_slice_pab_enabled, atol=expected_atol), (
+ "PAB outputs should not differ much in specified timestep range."
+ )
+ assert np.allclose(original_image_slice, image_slice_pab_disabled, atol=1e-4), (
+ "Outputs from normal inference and after disabling cache should not differ."
+ )
class FasterCacheTesterMixin:
@@ -2631,7 +2769,7 @@ def run_forward(pipe):
self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep
pipe = create_pipe()
pipe.transformer.enable_cache(self.faster_cache_config)
- output = run_forward(pipe).flatten().flatten()
+ output = run_forward(pipe).flatten()
image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:]))
# Run inference with FasterCache disabled
@@ -2639,12 +2777,12 @@ def run_forward(pipe):
output = run_forward(pipe).flatten()
image_slice_faster_cache_disabled = np.concatenate((output[:8], output[-8:]))
- assert np.allclose(
- original_image_slice, image_slice_faster_cache_enabled, atol=expected_atol
- ), "FasterCache outputs should not differ much in specified timestep range."
- assert np.allclose(
- original_image_slice, image_slice_faster_cache_disabled, atol=1e-4
- ), "Outputs from normal inference and after disabling cache should not differ."
+ assert np.allclose(original_image_slice, image_slice_faster_cache_enabled, atol=expected_atol), (
+ "FasterCache outputs should not differ much in specified timestep range."
+ )
+ assert np.allclose(original_image_slice, image_slice_faster_cache_disabled, atol=1e-4), (
+ "Outputs from normal inference and after disabling cache should not differ."
+ )
def test_faster_cache_state(self):
from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
@@ -2738,6 +2876,106 @@ def faster_cache_state_check_callback(pipe, i, t, kwargs):
self.assertTrue(state.cache is None, "Cache should be reset to None.")
+# TODO(aryan, dhruv): the cache tester mixins should probably be rewritten so that more models can be tested out
+# of the box once there is better cache support/implementation
+class FirstBlockCacheTesterMixin:
+ # threshold is intentionally set higher than usual values since we're testing with random unconverged models
+ # that will not satisfy the expected properties of the denoiser for caching to be effective
+ first_block_cache_config = FirstBlockCacheConfig(threshold=0.8)
+
+ def test_first_block_cache_inference(self, expected_atol: float = 0.1):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+
+ def create_pipe():
+ torch.manual_seed(0)
+ num_layers = 2
+ components = self.get_dummy_components(num_layers=num_layers)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+ return pipe
+
+ def run_forward(pipe):
+ torch.manual_seed(0)
+ inputs = self.get_dummy_inputs(device)
+ inputs["num_inference_steps"] = 4
+ return pipe(**inputs)[0]
+
+ # Run inference without FirstBlockCache
+ pipe = create_pipe()
+ output = run_forward(pipe).flatten()
+ original_image_slice = np.concatenate((output[:8], output[-8:]))
+
+ # Run inference with FirstBlockCache enabled
+ pipe = create_pipe()
+ pipe.transformer.enable_cache(self.first_block_cache_config)
+ output = run_forward(pipe).flatten()
+ image_slice_fbc_enabled = np.concatenate((output[:8], output[-8:]))
+
+ # Run inference with FirstBlockCache disabled
+ pipe.transformer.disable_cache()
+ output = run_forward(pipe).flatten()
+ image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:]))
+
+ assert np.allclose(original_image_slice, image_slice_fbc_enabled, atol=expected_atol), (
+ "FirstBlockCache outputs should not differ much."
+ )
+ assert np.allclose(original_image_slice, image_slice_fbc_disabled, atol=1e-4), (
+ "Outputs from normal inference and after disabling cache should not differ."
+ )
+
+
+class TaylorSeerCacheTesterMixin:
+ taylorseer_cache_config = TaylorSeerCacheConfig(
+ cache_interval=5,
+ disable_cache_before_step=10,
+ max_order=1,
+ taylor_factors_dtype=torch.bfloat16,
+ use_lite_mode=True,
+ )
+
+ def test_taylorseer_cache_inference(self, expected_atol: float = 0.1):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+
+ def create_pipe():
+ torch.manual_seed(0)
+ num_layers = 2
+ components = self.get_dummy_components(num_layers=num_layers)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+ return pipe
+
+ def run_forward(pipe):
+ torch.manual_seed(0)
+ inputs = self.get_dummy_inputs(device)
+ inputs["num_inference_steps"] = 50
+ return pipe(**inputs)[0]
+
+ # Run inference without TaylorSeerCache
+ pipe = create_pipe()
+ output = run_forward(pipe).flatten()
+ original_image_slice = np.concatenate((output[:8], output[-8:]))
+
+ # Run inference with TaylorSeerCache enabled
+ pipe = create_pipe()
+ pipe.transformer.enable_cache(self.taylorseer_cache_config)
+ output = run_forward(pipe).flatten()
+ image_slice_fbc_enabled = np.concatenate((output[:8], output[-8:]))
+
+ # Run inference with TaylorSeerCache disabled
+ pipe.transformer.disable_cache()
+ output = run_forward(pipe).flatten()
+ image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:]))
+
+ assert np.allclose(original_image_slice, image_slice_fbc_enabled, atol=expected_atol), (
+ "TaylorSeerCache outputs should not differ much."
+ )
+ assert np.allclose(original_image_slice, image_slice_fbc_disabled, atol=1e-4), (
+ "Outputs from normal inference and after disabling cache should not differ."
+ )
+
+
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
# reference image.
diff --git a/tests/pipelines/test_pipelines_flax.py b/tests/pipelines/test_pipelines_flax.py
deleted file mode 100644
index efd3da4c6127..000000000000
--- a/tests/pipelines/test_pipelines_flax.py
+++ /dev/null
@@ -1,260 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import tempfile
-import unittest
-
-import numpy as np
-
-from diffusers.utils import is_flax_available
-from diffusers.utils.testing_utils import require_flax, slow
-
-
-if is_flax_available():
- import jax
- import jax.numpy as jnp
- from flax.jax_utils import replicate
- from flax.training.common_utils import shard
-
- from diffusers import FlaxDDIMScheduler, FlaxDiffusionPipeline, FlaxStableDiffusionPipeline
-
-
-@require_flax
-class DownloadTests(unittest.TestCase):
- def test_download_only_pytorch(self):
- with tempfile.TemporaryDirectory() as tmpdirname:
- # pipeline has Flax weights
- _ = FlaxDiffusionPipeline.from_pretrained(
- "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
- )
-
- all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))]
- files = [item for sublist in all_root_files for item in sublist]
-
- # None of the downloaded files should be a PyTorch file even if we have some here:
- # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_pytorch_model.bin
- assert not any(f.endswith(".bin") for f in files)
-
-
-@slow
-@require_flax
-class FlaxPipelineTests(unittest.TestCase):
- def test_dummy_all_tpus(self):
- pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
- "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
- )
-
- prompt = (
- "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
- " field, close up, split lighting, cinematic"
- )
-
- prng_seed = jax.random.PRNGKey(0)
- num_inference_steps = 4
-
- num_samples = jax.device_count()
- prompt = num_samples * [prompt]
- prompt_ids = pipeline.prepare_inputs(prompt)
-
- # shard inputs and rng
- params = replicate(params)
- prng_seed = jax.random.split(prng_seed, num_samples)
- prompt_ids = shard(prompt_ids)
-
- images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
-
- assert images.shape == (num_samples, 1, 64, 64, 3)
- if jax.device_count() == 8:
- assert np.abs(np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 4.1514745) < 1e-3
- assert np.abs(np.abs(images, dtype=np.float32).sum() - 49947.875) < 5e-1
-
- images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
- assert len(images_pil) == num_samples
-
- def test_stable_diffusion_v1_4(self):
- pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4", revision="flax", safety_checker=None
- )
-
- prompt = (
- "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
- " field, close up, split lighting, cinematic"
- )
-
- prng_seed = jax.random.PRNGKey(0)
- num_inference_steps = 50
-
- num_samples = jax.device_count()
- prompt = num_samples * [prompt]
- prompt_ids = pipeline.prepare_inputs(prompt)
-
- # shard inputs and rng
- params = replicate(params)
- prng_seed = jax.random.split(prng_seed, num_samples)
- prompt_ids = shard(prompt_ids)
-
- images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
-
- assert images.shape == (num_samples, 1, 512, 512, 3)
- if jax.device_count() == 8:
- assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.05652401)) < 1e-2
- assert np.abs((np.abs(images, dtype=np.float32).sum() - 2383808.2)) < 5e-1
-
- def test_stable_diffusion_v1_4_bfloat_16(self):
- pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4", variant="bf16", dtype=jnp.bfloat16, safety_checker=None
- )
-
- prompt = (
- "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
- " field, close up, split lighting, cinematic"
- )
-
- prng_seed = jax.random.PRNGKey(0)
- num_inference_steps = 50
-
- num_samples = jax.device_count()
- prompt = num_samples * [prompt]
- prompt_ids = pipeline.prepare_inputs(prompt)
-
- # shard inputs and rng
- params = replicate(params)
- prng_seed = jax.random.split(prng_seed, num_samples)
- prompt_ids = shard(prompt_ids)
-
- images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
-
- assert images.shape == (num_samples, 1, 512, 512, 3)
- if jax.device_count() == 8:
- assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 5e-2
- assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1
-
- def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
- pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4", variant="bf16", dtype=jnp.bfloat16
- )
-
- prompt = (
- "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
- " field, close up, split lighting, cinematic"
- )
-
- prng_seed = jax.random.PRNGKey(0)
- num_inference_steps = 50
-
- num_samples = jax.device_count()
- prompt = num_samples * [prompt]
- prompt_ids = pipeline.prepare_inputs(prompt)
-
- # shard inputs and rng
- params = replicate(params)
- prng_seed = jax.random.split(prng_seed, num_samples)
- prompt_ids = shard(prompt_ids)
-
- images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
-
- assert images.shape == (num_samples, 1, 512, 512, 3)
- if jax.device_count() == 8:
- assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 5e-2
- assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1
-
- def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
- scheduler = FlaxDDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- set_alpha_to_one=False,
- steps_offset=1,
- )
-
- pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4",
- variant="bf16",
- dtype=jnp.bfloat16,
- scheduler=scheduler,
- safety_checker=None,
- )
- scheduler_state = scheduler.create_state()
-
- params["scheduler"] = scheduler_state
-
- prompt = (
- "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
- " field, close up, split lighting, cinematic"
- )
-
- prng_seed = jax.random.PRNGKey(0)
- num_inference_steps = 50
-
- num_samples = jax.device_count()
- prompt = num_samples * [prompt]
- prompt_ids = pipeline.prepare_inputs(prompt)
-
- # shard inputs and rng
- params = replicate(params)
- prng_seed = jax.random.split(prng_seed, num_samples)
- prompt_ids = shard(prompt_ids)
-
- images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
-
- assert images.shape == (num_samples, 1, 512, 512, 3)
- if jax.device_count() == 8:
- assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.045043945)) < 5e-2
- assert np.abs((np.abs(images, dtype=np.float32).sum() - 2347693.5)) < 5e-1
-
- def test_jax_memory_efficient_attention(self):
- prompt = (
- "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
- " field, close up, split lighting, cinematic"
- )
-
- num_samples = jax.device_count()
- prompt = num_samples * [prompt]
- prng_seed = jax.random.split(jax.random.PRNGKey(0), num_samples)
-
- pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4",
- variant="bf16",
- dtype=jnp.bfloat16,
- safety_checker=None,
- )
-
- params = replicate(params)
- prompt_ids = pipeline.prepare_inputs(prompt)
- prompt_ids = shard(prompt_ids)
- images = pipeline(prompt_ids, params, prng_seed, jit=True).images
- assert images.shape == (num_samples, 1, 512, 512, 3)
- slice = images[2, 0, 256, 10:17, 1]
-
- # With memory efficient attention
- pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4",
- variant="bf16",
- dtype=jnp.bfloat16,
- safety_checker=None,
- use_memory_efficient_attention=True,
- )
-
- params = replicate(params)
- prompt_ids = pipeline.prepare_inputs(prompt)
- prompt_ids = shard(prompt_ids)
- images_eff = pipeline(prompt_ids, params, prng_seed, jit=True).images
- assert images_eff.shape == (num_samples, 1, 512, 512, 3)
- slice_eff = images[2, 0, 256, 10:17, 1]
-
- # I checked the results visually and they are very similar. However, I saw that the max diff is `1` and the `sum`
- # over the 8 images is exactly `256`, which is very suspicious. Testing a random slice for now.
- assert abs(slice_eff - slice).max() < 1e-2
diff --git a/tests/pipelines/test_pipelines_onnx_common.py b/tests/pipelines/test_pipelines_onnx_common.py
index 575ecd007531..fa077efb8ab0 100644
--- a/tests/pipelines/test_pipelines_onnx_common.py
+++ b/tests/pipelines/test_pipelines_onnx_common.py
@@ -1,4 +1,4 @@
-from diffusers.utils.testing_utils import require_onnxruntime
+from ..testing_utils import require_onnxruntime
@require_onnxruntime
diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video.py
deleted file mode 100644
index 5d0f8299f68e..000000000000
--- a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py
+++ /dev/null
@@ -1,231 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import AutoencoderKL, DDIMScheduler, TextToVideoSDPipeline, UNet3DConditionModel
-from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- load_numpy,
- numpy_cosine_similarity_distance,
- require_torch_accelerator,
- skip_mps,
- slow,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, SDFunctionTesterMixin
-
-
-enable_full_determinism()
-
-
-@skip_mps
-class TextToVideoSDPipelineFastTests(PipelineTesterMixin, SDFunctionTesterMixin, unittest.TestCase):
- pipeline_class = TextToVideoSDPipeline
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- # No `output_type`.
- required_optional_params = frozenset(
- [
- "num_inference_steps",
- "generator",
- "latents",
- "return_dict",
- "callback",
- "callback_steps",
- ]
- )
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet3DConditionModel(
- block_out_channels=(8, 8),
- layers_per_block=1,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
- up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
- cross_attention_dim=4,
- attention_head_dim=4,
- norm_num_groups=2,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=(8,),
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=32,
- norm_num_groups=2,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=4,
- intermediate_size=16,
- layer_norm_eps=1e-05,
- num_attention_heads=2,
- num_hidden_layers=2,
- pad_token_id=1,
- vocab_size=1000,
- hidden_act="gelu",
- projection_dim=32,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "pt",
- }
- return inputs
-
- def test_dict_tuple_outputs_equivalent(self):
- return super().test_dict_tuple_outputs_equivalent()
-
- def test_text_to_video_default_case(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = TextToVideoSDPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- inputs["output_type"] = "np"
- frames = sd_pipe(**inputs).frames
-
- image_slice = frames[0][0][-3:, -3:, -1]
- assert frames[0][0].shape == (32, 32, 3)
- expected_slice = np.array([0.8093, 0.2751, 0.6976, 0.5927, 0.4616, 0.4336, 0.5094, 0.5683, 0.4796])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- @unittest.skipIf(torch_device != "cuda", reason="Feature isn't heavily used. Test in CUDA environment only.")
- def test_attention_slicing_forward_pass(self):
- self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False, expected_max_diff=3e-3)
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False, expected_max_diff=1e-2)
-
- # (todo): sayakpaul
- @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
- def test_inference_batch_consistent(self):
- pass
-
- # (todo): sayakpaul
- @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
- def test_inference_batch_single_identical(self):
- pass
-
- @unittest.skip(reason="`num_images_per_prompt` argument is not supported for this pipeline.")
- def test_num_images_per_prompt(self):
- pass
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "num_images_per_prompt": 1,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
-
-
-@slow
-@skip_mps
-@require_torch_accelerator
-class TextToVideoSDPipelineSlowTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_two_step_model(self):
- expected_video = load_numpy(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text-to-video/video_2step.npy"
- )
-
- pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
- pipe = pipe.to(torch_device)
-
- prompt = "Spiderman is surfing"
- generator = torch.Generator(device="cpu").manual_seed(0)
-
- video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").frames
- assert numpy_cosine_similarity_distance(expected_video.flatten(), video_frames.flatten()) < 1e-4
-
- def test_two_step_model_with_freeu(self):
- expected_video = []
-
- pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
- pipe = pipe.to(torch_device)
-
- prompt = "Spiderman is surfing"
- generator = torch.Generator(device="cpu").manual_seed(0)
-
- pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
- video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").frames
- video = video_frames[0, 0, -3:, -3:, -1].flatten()
-
- expected_video = [0.3643, 0.3455, 0.3831, 0.3923, 0.2978, 0.3247, 0.3278, 0.3201, 0.3475]
-
- assert np.abs(expected_video - video).mean() < 5e-2
diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero.py
deleted file mode 100644
index f1bf6ee52206..000000000000
--- a/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import torch
-
-from diffusers import DDIMScheduler, TextToVideoZeroPipeline
-from diffusers.utils.testing_utils import load_pt, nightly, require_torch_gpu
-
-from ..test_pipelines_common import assert_mean_pixel_difference
-
-
-@nightly
-@require_torch_gpu
-class TextToVideoZeroPipelineSlowTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- def test_full_model(self):
- model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
- pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
- pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
- generator = torch.Generator(device="cuda").manual_seed(0)
-
- prompt = "A bear is playing a guitar on Times Square"
- result = pipe(prompt=prompt, generator=generator).images
-
- expected_result = load_pt(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text-to-video/A bear is playing a guitar on Times Square.pt"
- )
-
- assert_mean_pixel_difference(result, expected_result)
diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py
deleted file mode 100644
index db24767b60fc..000000000000
--- a/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py
+++ /dev/null
@@ -1,403 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import inspect
-import tempfile
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
-
-from diffusers import AutoencoderKL, DDIMScheduler, TextToVideoZeroSDXLPipeline, UNet2DConditionModel
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- nightly,
- require_accelerate_version_greater,
- require_accelerator,
- require_torch_gpu,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineFromPipeTesterMixin, PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-def to_np(tensor):
- if isinstance(tensor, torch.Tensor):
- tensor = tensor.detach().cpu().numpy()
-
- return tensor
-
-
-class TextToVideoZeroSDXLPipelineFastTests(PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase):
- pipeline_class = TextToVideoZeroSDXLPipeline
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- generator_device = "cpu"
-
- def get_dummy_components(self, seed=0):
- torch.manual_seed(seed)
- unet = UNet2DConditionModel(
- block_out_channels=(2, 4),
- layers_per_block=2,
- sample_size=2,
- norm_num_groups=2,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- # SD2-specific config below
- attention_head_dim=(2, 4),
- use_linear_projection=True,
- addition_embed_type="text_time",
- addition_time_embed_dim=8,
- transformer_layers_per_block=(1, 2),
- projection_class_embeddings_input_dim=80, # 6 * 8 + 32
- cross_attention_dim=64,
- )
- scheduler = DDIMScheduler(
- num_train_timesteps=1000,
- beta_start=0.0001,
- beta_end=0.02,
- beta_schedule="linear",
- trained_betas=None,
- clip_sample=True,
- set_alpha_to_one=True,
- steps_offset=0,
- prediction_type="epsilon",
- thresholding=False,
- dynamic_thresholding_ratio=0.995,
- clip_sample_range=1.0,
- sample_max_value=1.0,
- timestep_spacing="leading",
- rescale_betas_zero_snr=False,
- )
- torch.manual_seed(seed)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=128,
- )
- torch.manual_seed(seed)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- # SD2-specific config below
- hidden_act="gelu",
- projection_dim=32,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
- tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "text_encoder_2": text_encoder_2,
- "tokenizer_2": tokenizer_2,
- "image_encoder": None,
- "feature_extractor": None,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A panda dancing in Antarctica",
- "generator": generator,
- "num_inference_steps": 5,
- "t0": 1,
- "t1": 3,
- "height": 64,
- "width": 64,
- "video_length": 3,
- "output_type": "np",
- }
- return inputs
-
- def get_generator(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- return generator
-
- def test_text_to_video_zero_sdxl(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
-
- inputs = self.get_dummy_inputs(self.generator_device)
- result = pipe(**inputs).images
-
- first_frame_slice = result[0, -3:, -3:, -1]
- last_frame_slice = result[-1, -3:, -3:, 0]
-
- expected_slice1 = np.array(
- [0.6008109, 0.73051643, 0.51778656, 0.55817354, 0.45222935, 0.45998418, 0.57017255, 0.54874814, 0.47078788]
- )
- expected_slice2 = np.array(
- [0.6011751, 0.47420046, 0.41660714, 0.6472957, 0.41261768, 0.5438129, 0.7401535, 0.6756011, 0.53652245]
- )
-
- assert np.abs(first_frame_slice.flatten() - expected_slice1).max() < 1e-2
- assert np.abs(last_frame_slice.flatten() - expected_slice2).max() < 1e-2
-
- @unittest.skip(
- reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
- )
- def test_attention_slicing_forward_pass(self):
- pass
-
- def test_cfg(self):
- sig = inspect.signature(self.pipeline_class.__call__)
- if "guidance_scale" not in sig.parameters:
- return
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(self.generator_device)
-
- inputs["guidance_scale"] = 1.0
- out_no_cfg = pipe(**inputs)[0]
-
- inputs["guidance_scale"] = 7.5
- out_cfg = pipe(**inputs)[0]
-
- assert out_cfg.shape == out_no_cfg.shape
-
- def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- output = pipe(**self.get_dummy_inputs(self.generator_device))[0]
- output_tuple = pipe(**self.get_dummy_inputs(self.generator_device), return_dict=False)[0]
-
- max_diff = np.abs(to_np(output) - to_np(output_tuple)).max()
- self.assertLess(max_diff, expected_max_difference)
-
- @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
- @require_accelerator
- def test_float16_inference(self, expected_max_diff=5e-2):
- components = self.get_dummy_components()
- for name, module in components.items():
- if hasattr(module, "half"):
- components[name] = module.to(torch_device).half()
- pipe = self.pipeline_class(**components)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- components = self.get_dummy_components()
- pipe_fp16 = self.pipeline_class(**components)
- pipe_fp16.to(torch_device, torch.float16)
- pipe_fp16.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(self.generator_device)
- # # Reset generator in case it is used inside dummy inputs
- if "generator" in inputs:
- inputs["generator"] = self.get_generator(self.generator_device)
-
- output = pipe(**inputs)[0]
-
- fp16_inputs = self.get_dummy_inputs(self.generator_device)
- # Reset generator in case it is used inside dummy inputs
- if "generator" in fp16_inputs:
- fp16_inputs["generator"] = self.get_generator(self.generator_device)
-
- output_fp16 = pipe_fp16(**fp16_inputs)[0]
-
- max_diff = np.abs(to_np(output) - to_np(output_fp16)).max()
- self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.")
-
- @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
- def test_inference_batch_consistent(self):
- pass
-
- @unittest.skip(
- reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
- )
- def test_inference_batch_single_identical(self):
- pass
-
- @require_accelerator
- @require_accelerate_version_greater("0.17.0")
- def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(self.generator_device)
- output_without_offload = pipe(**inputs)[0]
-
- pipe.enable_model_cpu_offload(device=torch_device)
- inputs = self.get_dummy_inputs(self.generator_device)
- output_with_offload = pipe(**inputs)[0]
-
- max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
- self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")
-
- @unittest.skip(reason="`num_images_per_prompt` argument is not supported for this pipeline.")
- def test_pipeline_call_signature(self):
- pass
-
- @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
- @require_accelerator
- def test_save_load_float16(self, expected_max_diff=1e-2):
- components = self.get_dummy_components()
- for name, module in components.items():
- if hasattr(module, "half"):
- components[name] = module.to(torch_device).half()
-
- pipe = self.pipeline_class(**components)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(self.generator_device)
- output = pipe(**inputs)[0]
-
- with tempfile.TemporaryDirectory() as tmpdir:
- pipe.save_pretrained(tmpdir)
- pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16)
- pipe_loaded.to(torch_device)
- pipe_loaded.set_progress_bar_config(disable=None)
-
- for name, component in pipe_loaded.components.items():
- if hasattr(component, "dtype"):
- self.assertTrue(
- component.dtype == torch.float16,
- f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
- )
-
- inputs = self.get_dummy_inputs(self.generator_device)
- output_loaded = pipe_loaded(**inputs)[0]
- max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
- self.assertLess(
- max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
- )
-
- @unittest.skip(
- reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
- )
- def test_save_load_local(self):
- pass
-
- @unittest.skip(
- reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
- )
- def test_save_load_optional_components(self):
- pass
-
- @unittest.skip(
- reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
- )
- def test_sequential_cpu_offload_forward_pass(self):
- pass
-
- @require_accelerator
- def test_to_device(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
-
- pipe.to("cpu")
- model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
- self.assertTrue(all(device == "cpu" for device in model_devices))
-
- output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0] # generator set to cpu
- self.assertTrue(np.isnan(output_cpu).sum() == 0)
-
- pipe.to(torch_device)
- model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
- self.assertTrue(all(device == torch_device for device in model_devices))
-
- output_device = pipe(**self.get_dummy_inputs("cpu"))[0] # generator set to cpu
- self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
-
- @unittest.skip(
- reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- pass
-
-
-@nightly
-@require_torch_gpu
-class TextToVideoZeroSDXLPipelineSlowTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- def test_full_model(self):
- model_id = "stabilityai/stable-diffusion-xl-base-1.0"
- pipe = TextToVideoZeroSDXLPipeline.from_pretrained(
- model_id, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
- )
- pipe.enable_model_cpu_offload()
- pipe.enable_vae_slicing()
-
- pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
- generator = torch.Generator(device="cpu").manual_seed(0)
-
- prompt = "A panda dancing in Antarctica"
- result = pipe(prompt=prompt, generator=generator).images
-
- first_frame_slice = result[0, -3:, -3:, -1]
- last_frame_slice = result[-1, -3:, -3:, 0]
-
- expected_slice1 = np.array([0.57, 0.57, 0.57, 0.57, 0.57, 0.56, 0.55, 0.56, 0.56])
- expected_slice2 = np.array([0.54, 0.53, 0.53, 0.53, 0.53, 0.52, 0.53, 0.53, 0.53])
-
- assert np.abs(first_frame_slice.flatten() - expected_slice1).max() < 1e-2
- assert np.abs(last_frame_slice.flatten() - expected_slice2).max() < 1e-2
diff --git a/tests/pipelines/text_to_video_synthesis/test_video_to_video.py b/tests/pipelines/text_to_video_synthesis/test_video_to_video.py
deleted file mode 100644
index f44a8aa33c5a..000000000000
--- a/tests/pipelines/text_to_video_synthesis/test_video_to_video.py
+++ /dev/null
@@ -1,229 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import random
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- UNet3DConditionModel,
- VideoToVideoSDPipeline,
-)
-from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- floats_tensor,
- is_flaky,
- nightly,
- numpy_cosine_similarity_distance,
- skip_mps,
- torch_device,
-)
-
-from ..pipeline_params import (
- TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
- TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
-)
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-@skip_mps
-class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = VideoToVideoSDPipeline
- params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS.union({"video"}) - {"image", "width", "height"}
- batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"video"}) - {"image"}
- required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
- test_attention_slicing = False
-
- # No `output_type`.
- required_optional_params = frozenset(
- [
- "num_inference_steps",
- "generator",
- "latents",
- "return_dict",
- "callback",
- "callback_steps",
- ]
- )
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet3DConditionModel(
- block_out_channels=(4, 8),
- layers_per_block=1,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
- up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
- cross_attention_dim=32,
- attention_head_dim=4,
- norm_num_groups=2,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=True,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[
- 8,
- ],
- in_channels=3,
- out_channels=3,
- down_block_types=[
- "DownEncoderBlock2D",
- ],
- up_block_types=["UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=32,
- norm_num_groups=2,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- hidden_act="gelu",
- projection_dim=512,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- # 3 frames
- video = floats_tensor((1, 3, 3, 32, 32), rng=random.Random(seed)).to(device)
-
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "video": video,
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "pt",
- }
- return inputs
-
- def test_text_to_video_default_case(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = VideoToVideoSDPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- inputs["output_type"] = "np"
- frames = sd_pipe(**inputs).frames
- image_slice = frames[0][0][-3:, -3:, -1]
-
- assert frames[0][0].shape == (32, 32, 3)
- expected_slice = np.array([0.6391, 0.5350, 0.5202, 0.5521, 0.5453, 0.5393, 0.6652, 0.5270, 0.5185])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- @is_flaky()
- def test_save_load_optional_components(self):
- super().test_save_load_optional_components(expected_max_difference=0.001)
-
- @is_flaky()
- def test_dict_tuple_outputs_equivalent(self):
- super().test_dict_tuple_outputs_equivalent()
-
- @is_flaky()
- def test_save_load_local(self):
- super().test_save_load_local()
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False, expected_max_diff=5e-3)
-
- # (todo): sayakpaul
- @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
- def test_inference_batch_consistent(self):
- pass
-
- # (todo): sayakpaul
- @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
- def test_inference_batch_single_identical(self):
- pass
-
- @unittest.skip(reason="`num_images_per_prompt` argument is not supported for this pipeline.")
- def test_num_images_per_prompt(self):
- pass
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "num_images_per_prompt": 1,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
-
-
-@nightly
-@skip_mps
-class VideoToVideoSDPipelineSlowTests(unittest.TestCase):
- def test_two_step_model(self):
- pipe = VideoToVideoSDPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16)
- pipe.enable_model_cpu_offload()
-
- # 10 frames
- generator = torch.Generator(device="cpu").manual_seed(0)
- video = torch.randn((1, 10, 3, 320, 576), generator=generator)
-
- prompt = "Spiderman is surfing"
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- video_frames = pipe(prompt, video=video, generator=generator, num_inference_steps=3, output_type="np").frames
-
- expected_array = np.array(
- [0.17114258, 0.13720703, 0.08886719, 0.14819336, 0.1730957, 0.24584961, 0.22021484, 0.35180664, 0.2607422]
- )
- output_array = video_frames[0, 0, :3, :3, 0].flatten()
- assert numpy_cosine_similarity_distance(expected_array, output_array) < 1e-3
diff --git a/tests/pipelines/unclip/test_unclip.py b/tests/pipelines/unclip/test_unclip.py
deleted file mode 100644
index 26a1bead0138..000000000000
--- a/tests/pipelines/unclip/test_unclip.py
+++ /dev/null
@@ -1,519 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
-
-from diffusers import PriorTransformer, UnCLIPPipeline, UnCLIPScheduler, UNet2DConditionModel, UNet2DModel
-from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- load_numpy,
- nightly,
- require_torch_gpu,
- skip_mps,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
-
-
-enable_full_determinism()
-
-
-class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = UnCLIPPipeline
- params = TEXT_TO_IMAGE_PARAMS - {
- "negative_prompt",
- "height",
- "width",
- "negative_prompt_embeds",
- "guidance_scale",
- "prompt_embeds",
- "cross_attention_kwargs",
- }
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- required_optional_params = [
- "generator",
- "return_dict",
- "prior_num_inference_steps",
- "decoder_num_inference_steps",
- "super_res_num_inference_steps",
- ]
- test_xformers_attention = False
-
- @property
- def text_embedder_hidden_size(self):
- return 32
-
- @property
- def time_input_dim(self):
- return 32
-
- @property
- def block_out_channels_0(self):
- return self.time_input_dim
-
- @property
- def time_embed_dim(self):
- return self.time_input_dim * 4
-
- @property
- def cross_attention_dim(self):
- return 100
-
- @property
- def dummy_tokenizer(self):
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- return tokenizer
-
- @property
- def dummy_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=self.text_embedder_hidden_size,
- projection_dim=self.text_embedder_hidden_size,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModelWithProjection(config)
-
- @property
- def dummy_prior(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "num_attention_heads": 2,
- "attention_head_dim": 12,
- "embedding_dim": self.text_embedder_hidden_size,
- "num_layers": 1,
- }
-
- model = PriorTransformer(**model_kwargs)
- return model
-
- @property
- def dummy_text_proj(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "clip_embeddings_dim": self.text_embedder_hidden_size,
- "time_embed_dim": self.time_embed_dim,
- "cross_attention_dim": self.cross_attention_dim,
- }
-
- model = UnCLIPTextProjModel(**model_kwargs)
- return model
-
- @property
- def dummy_decoder(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "sample_size": 32,
- # RGB in channels
- "in_channels": 3,
- # Out channels is double in channels because predicts mean and variance
- "out_channels": 6,
- "down_block_types": ("ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D"),
- "up_block_types": ("SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"),
- "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
- "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2),
- "layers_per_block": 1,
- "cross_attention_dim": self.cross_attention_dim,
- "attention_head_dim": 4,
- "resnet_time_scale_shift": "scale_shift",
- "class_embed_type": "identity",
- }
-
- model = UNet2DConditionModel(**model_kwargs)
- return model
-
- @property
- def dummy_super_res_kwargs(self):
- return {
- "sample_size": 64,
- "layers_per_block": 1,
- "down_block_types": ("ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D"),
- "up_block_types": ("ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D"),
- "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2),
- "in_channels": 6,
- "out_channels": 3,
- }
-
- @property
- def dummy_super_res_first(self):
- torch.manual_seed(0)
-
- model = UNet2DModel(**self.dummy_super_res_kwargs)
- return model
-
- @property
- def dummy_super_res_last(self):
- # seeded differently to get different unet than `self.dummy_super_res_first`
- torch.manual_seed(1)
-
- model = UNet2DModel(**self.dummy_super_res_kwargs)
- return model
-
- def get_dummy_components(self):
- prior = self.dummy_prior
- decoder = self.dummy_decoder
- text_proj = self.dummy_text_proj
- text_encoder = self.dummy_text_encoder
- tokenizer = self.dummy_tokenizer
- super_res_first = self.dummy_super_res_first
- super_res_last = self.dummy_super_res_last
-
- prior_scheduler = UnCLIPScheduler(
- variance_type="fixed_small_log",
- prediction_type="sample",
- num_train_timesteps=1000,
- clip_sample_range=5.0,
- )
-
- decoder_scheduler = UnCLIPScheduler(
- variance_type="learned_range",
- prediction_type="epsilon",
- num_train_timesteps=1000,
- )
-
- super_res_scheduler = UnCLIPScheduler(
- variance_type="fixed_small_log",
- prediction_type="epsilon",
- num_train_timesteps=1000,
- )
-
- components = {
- "prior": prior,
- "decoder": decoder,
- "text_proj": text_proj,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "super_res_first": super_res_first,
- "super_res_last": super_res_last,
- "prior_scheduler": prior_scheduler,
- "decoder_scheduler": decoder_scheduler,
- "super_res_scheduler": super_res_scheduler,
- }
-
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "horse",
- "generator": generator,
- "prior_num_inference_steps": 2,
- "decoder_num_inference_steps": 2,
- "super_res_num_inference_steps": 2,
- "output_type": "np",
- }
- return inputs
-
- def test_unclip(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- output = pipe(**self.get_dummy_inputs(device))
- image = output.images
-
- image_from_tuple = pipe(
- **self.get_dummy_inputs(device),
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array(
- [
- 0.9997,
- 0.9988,
- 0.0028,
- 0.9997,
- 0.9984,
- 0.9965,
- 0.0029,
- 0.9986,
- 0.0025,
- ]
- )
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_unclip_passed_text_embed(self):
- device = torch.device("cpu")
-
- class DummyScheduler:
- init_noise_sigma = 1
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- prior = components["prior"]
- decoder = components["decoder"]
- super_res_first = components["super_res_first"]
- tokenizer = components["tokenizer"]
- text_encoder = components["text_encoder"]
-
- generator = torch.Generator(device=device).manual_seed(0)
- dtype = prior.dtype
- batch_size = 1
-
- shape = (batch_size, prior.config.embedding_dim)
- prior_latents = pipe.prepare_latents(
- shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
- )
- shape = (batch_size, decoder.config.in_channels, decoder.config.sample_size, decoder.config.sample_size)
- generator = torch.Generator(device=device).manual_seed(0)
- decoder_latents = pipe.prepare_latents(
- shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
- )
-
- shape = (
- batch_size,
- super_res_first.config.in_channels // 2,
- super_res_first.config.sample_size,
- super_res_first.config.sample_size,
- )
- super_res_latents = pipe.prepare_latents(
- shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
- )
-
- pipe.set_progress_bar_config(disable=None)
-
- prompt = "this is a prompt example"
-
- generator = torch.Generator(device=device).manual_seed(0)
- output = pipe(
- [prompt],
- generator=generator,
- prior_num_inference_steps=2,
- decoder_num_inference_steps=2,
- super_res_num_inference_steps=2,
- prior_latents=prior_latents,
- decoder_latents=decoder_latents,
- super_res_latents=super_res_latents,
- output_type="np",
- )
- image = output.images
-
- text_inputs = tokenizer(
- prompt,
- padding="max_length",
- max_length=tokenizer.model_max_length,
- return_tensors="pt",
- )
- text_model_output = text_encoder(text_inputs.input_ids)
- text_attention_mask = text_inputs.attention_mask
-
- generator = torch.Generator(device=device).manual_seed(0)
- image_from_text = pipe(
- generator=generator,
- prior_num_inference_steps=2,
- decoder_num_inference_steps=2,
- super_res_num_inference_steps=2,
- prior_latents=prior_latents,
- decoder_latents=decoder_latents,
- super_res_latents=super_res_latents,
- text_model_output=text_model_output,
- text_attention_mask=text_attention_mask,
- output_type="np",
- )[0]
-
- # make sure passing text embeddings manually is identical
- assert np.abs(image - image_from_text).max() < 1e-4
-
- # Overriding PipelineTesterMixin::test_attention_slicing_forward_pass
- # because UnCLIP GPU undeterminism requires a looser check.
- @skip_mps
- def test_attention_slicing_forward_pass(self):
- test_max_difference = torch_device == "cpu"
-
- self._test_attention_slicing_forward_pass(test_max_difference=test_max_difference, expected_max_diff=0.01)
-
- # Overriding PipelineTesterMixin::test_inference_batch_single_identical
- # because UnCLIP undeterminism requires a looser check.
- @skip_mps
- def test_inference_batch_single_identical(self):
- additional_params_copy_to_batched_inputs = [
- "prior_num_inference_steps",
- "decoder_num_inference_steps",
- "super_res_num_inference_steps",
- ]
-
- self._test_inference_batch_single_identical(
- additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs, expected_max_diff=5e-3
- )
-
- def test_inference_batch_consistent(self):
- additional_params_copy_to_batched_inputs = [
- "prior_num_inference_steps",
- "decoder_num_inference_steps",
- "super_res_num_inference_steps",
- ]
-
- if torch_device == "mps":
- # TODO: MPS errors with larger batch sizes
- batch_sizes = [2, 3]
- self._test_inference_batch_consistent(
- batch_sizes=batch_sizes,
- additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs,
- )
- else:
- self._test_inference_batch_consistent(
- additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs
- )
-
- @skip_mps
- def test_dict_tuple_outputs_equivalent(self):
- return super().test_dict_tuple_outputs_equivalent()
-
- @skip_mps
- def test_save_load_local(self):
- return super().test_save_load_local(expected_max_difference=5e-3)
-
- @skip_mps
- def test_save_load_optional_components(self):
- return super().test_save_load_optional_components()
-
- @unittest.skip("UnCLIP produces very large differences in fp16 vs fp32. Test is not useful.")
- def test_float16_inference(self):
- super().test_float16_inference(expected_max_diff=1.0)
-
-
-@nightly
-class UnCLIPPipelineCPUIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- def test_unclip_karlo_cpu_fp32(self):
- expected_image = load_numpy(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/unclip/karlo_v1_alpha_horse_cpu.npy"
- )
-
- pipeline = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")
- pipeline.set_progress_bar_config(disable=None)
-
- generator = torch.manual_seed(0)
- output = pipeline(
- "horse",
- num_images_per_prompt=1,
- generator=generator,
- output_type="np",
- )
-
- image = output.images[0]
-
- assert image.shape == (256, 256, 3)
- assert np.abs(expected_image - image).max() < 1e-1
-
-
-@nightly
-@require_torch_gpu
-class UnCLIPPipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- def test_unclip_karlo(self):
- expected_image = load_numpy(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/unclip/karlo_v1_alpha_horse_fp16.npy"
- )
-
- pipeline = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
- pipeline = pipeline.to(torch_device)
- pipeline.set_progress_bar_config(disable=None)
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- output = pipeline(
- "horse",
- generator=generator,
- output_type="np",
- )
-
- image = output.images[0]
-
- assert image.shape == (256, 256, 3)
-
- assert_mean_pixel_difference(image, expected_image)
-
- def test_unclip_pipeline_with_sequential_cpu_offloading(self):
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
-
- pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
- pipe.enable_sequential_cpu_offload()
-
- _ = pipe(
- "horse",
- num_images_per_prompt=1,
- prior_num_inference_steps=2,
- decoder_num_inference_steps=2,
- super_res_num_inference_steps=2,
- output_type="np",
- )
-
- mem_bytes = torch.cuda.max_memory_allocated()
- # make sure that less than 7 GB is allocated
- assert mem_bytes < 7 * 10**9
diff --git a/tests/pipelines/unclip/test_unclip_image_variation.py b/tests/pipelines/unclip/test_unclip_image_variation.py
deleted file mode 100644
index e402629fe1b9..000000000000
--- a/tests/pipelines/unclip/test_unclip_image_variation.py
+++ /dev/null
@@ -1,539 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import random
-import unittest
-
-import numpy as np
-import torch
-from transformers import (
- CLIPImageProcessor,
- CLIPTextConfig,
- CLIPTextModelWithProjection,
- CLIPTokenizer,
- CLIPVisionConfig,
- CLIPVisionModelWithProjection,
-)
-
-from diffusers import (
- DiffusionPipeline,
- UnCLIPImageVariationPipeline,
- UnCLIPScheduler,
- UNet2DConditionModel,
- UNet2DModel,
-)
-from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- floats_tensor,
- load_image,
- load_numpy,
- nightly,
- require_torch_gpu,
- skip_mps,
- torch_device,
-)
-
-from ..pipeline_params import IMAGE_VARIATION_BATCH_PARAMS, IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
-
-
-enable_full_determinism()
-
-
-class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = UnCLIPImageVariationPipeline
- params = IMAGE_VARIATION_PARAMS - {"height", "width", "guidance_scale"}
- batch_params = IMAGE_VARIATION_BATCH_PARAMS
-
- required_optional_params = [
- "generator",
- "return_dict",
- "decoder_num_inference_steps",
- "super_res_num_inference_steps",
- ]
- test_xformers_attention = False
- supports_dduf = False
-
- @property
- def text_embedder_hidden_size(self):
- return 32
-
- @property
- def time_input_dim(self):
- return 32
-
- @property
- def block_out_channels_0(self):
- return self.time_input_dim
-
- @property
- def time_embed_dim(self):
- return self.time_input_dim * 4
-
- @property
- def cross_attention_dim(self):
- return 100
-
- @property
- def dummy_tokenizer(self):
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- return tokenizer
-
- @property
- def dummy_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=self.text_embedder_hidden_size,
- projection_dim=self.text_embedder_hidden_size,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModelWithProjection(config)
-
- @property
- def dummy_image_encoder(self):
- torch.manual_seed(0)
- config = CLIPVisionConfig(
- hidden_size=self.text_embedder_hidden_size,
- projection_dim=self.text_embedder_hidden_size,
- num_hidden_layers=5,
- num_attention_heads=4,
- image_size=32,
- intermediate_size=37,
- patch_size=1,
- )
- return CLIPVisionModelWithProjection(config)
-
- @property
- def dummy_text_proj(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "clip_embeddings_dim": self.text_embedder_hidden_size,
- "time_embed_dim": self.time_embed_dim,
- "cross_attention_dim": self.cross_attention_dim,
- }
-
- model = UnCLIPTextProjModel(**model_kwargs)
- return model
-
- @property
- def dummy_decoder(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "sample_size": 32,
- # RGB in channels
- "in_channels": 3,
- # Out channels is double in channels because predicts mean and variance
- "out_channels": 6,
- "down_block_types": ("ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D"),
- "up_block_types": ("SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"),
- "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
- "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2),
- "layers_per_block": 1,
- "cross_attention_dim": self.cross_attention_dim,
- "attention_head_dim": 4,
- "resnet_time_scale_shift": "scale_shift",
- "class_embed_type": "identity",
- }
-
- model = UNet2DConditionModel(**model_kwargs)
- return model
-
- @property
- def dummy_super_res_kwargs(self):
- return {
- "sample_size": 64,
- "layers_per_block": 1,
- "down_block_types": ("ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D"),
- "up_block_types": ("ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D"),
- "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2),
- "in_channels": 6,
- "out_channels": 3,
- }
-
- @property
- def dummy_super_res_first(self):
- torch.manual_seed(0)
-
- model = UNet2DModel(**self.dummy_super_res_kwargs)
- return model
-
- @property
- def dummy_super_res_last(self):
- # seeded differently to get different unet than `self.dummy_super_res_first`
- torch.manual_seed(1)
-
- model = UNet2DModel(**self.dummy_super_res_kwargs)
- return model
-
- def get_dummy_components(self):
- decoder = self.dummy_decoder
- text_proj = self.dummy_text_proj
- text_encoder = self.dummy_text_encoder
- tokenizer = self.dummy_tokenizer
- super_res_first = self.dummy_super_res_first
- super_res_last = self.dummy_super_res_last
-
- decoder_scheduler = UnCLIPScheduler(
- variance_type="learned_range",
- prediction_type="epsilon",
- num_train_timesteps=1000,
- )
-
- super_res_scheduler = UnCLIPScheduler(
- variance_type="fixed_small_log",
- prediction_type="epsilon",
- num_train_timesteps=1000,
- )
-
- feature_extractor = CLIPImageProcessor(crop_size=32, size=32)
-
- image_encoder = self.dummy_image_encoder
-
- return {
- "decoder": decoder,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "text_proj": text_proj,
- "feature_extractor": feature_extractor,
- "image_encoder": image_encoder,
- "super_res_first": super_res_first,
- "super_res_last": super_res_last,
- "decoder_scheduler": decoder_scheduler,
- "super_res_scheduler": super_res_scheduler,
- }
-
- def get_dummy_inputs(self, device, seed=0, pil_image=True):
- input_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
-
- if pil_image:
- input_image = input_image * 0.5 + 0.5
- input_image = input_image.clamp(0, 1)
- input_image = input_image.cpu().permute(0, 2, 3, 1).float().numpy()
- input_image = DiffusionPipeline.numpy_to_pil(input_image)[0]
-
- return {
- "image": input_image,
- "generator": generator,
- "decoder_num_inference_steps": 2,
- "super_res_num_inference_steps": 2,
- "output_type": "np",
- }
-
- def test_unclip_image_variation_input_tensor(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)
-
- output = pipe(**pipeline_inputs)
- image = output.images
-
- tuple_pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)
-
- image_from_tuple = pipe(
- **tuple_pipeline_inputs,
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array(
- [
- 0.9997,
- 0.0002,
- 0.9997,
- 0.9997,
- 0.9969,
- 0.0023,
- 0.9997,
- 0.9969,
- 0.9970,
- ]
- )
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_unclip_image_variation_input_image(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
-
- output = pipe(**pipeline_inputs)
- image = output.images
-
- tuple_pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
-
- image_from_tuple = pipe(
- **tuple_pipeline_inputs,
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.9997, 0.0003, 0.9997, 0.9997, 0.9970, 0.0024, 0.9997, 0.9971, 0.9971])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_unclip_image_variation_input_list_images(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
- pipeline_inputs["image"] = [
- pipeline_inputs["image"],
- pipeline_inputs["image"],
- ]
-
- output = pipe(**pipeline_inputs)
- image = output.images
-
- tuple_pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
- tuple_pipeline_inputs["image"] = [
- tuple_pipeline_inputs["image"],
- tuple_pipeline_inputs["image"],
- ]
-
- image_from_tuple = pipe(
- **tuple_pipeline_inputs,
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (2, 64, 64, 3)
-
- expected_slice = np.array(
- [
- 0.9997,
- 0.9989,
- 0.0008,
- 0.0021,
- 0.9960,
- 0.0018,
- 0.0014,
- 0.0002,
- 0.9933,
- ]
- )
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_unclip_passed_image_embed(self):
- device = torch.device("cpu")
-
- class DummyScheduler:
- init_noise_sigma = 1
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.Generator(device=device).manual_seed(0)
- dtype = pipe.decoder.dtype
- batch_size = 1
-
- shape = (
- batch_size,
- pipe.decoder.config.in_channels,
- pipe.decoder.config.sample_size,
- pipe.decoder.config.sample_size,
- )
- decoder_latents = pipe.prepare_latents(
- shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
- )
-
- shape = (
- batch_size,
- pipe.super_res_first.config.in_channels // 2,
- pipe.super_res_first.config.sample_size,
- pipe.super_res_first.config.sample_size,
- )
- generator = torch.Generator(device=device).manual_seed(0)
- super_res_latents = pipe.prepare_latents(
- shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
- )
-
- pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)
-
- img_out_1 = pipe(
- **pipeline_inputs, decoder_latents=decoder_latents, super_res_latents=super_res_latents
- ).images
-
- pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)
- # Don't pass image, instead pass embedding
- image = pipeline_inputs.pop("image")
- image_embeddings = pipe.image_encoder(image).image_embeds
-
- img_out_2 = pipe(
- **pipeline_inputs,
- decoder_latents=decoder_latents,
- super_res_latents=super_res_latents,
- image_embeddings=image_embeddings,
- ).images
-
- # make sure passing text embeddings manually is identical
- assert np.abs(img_out_1 - img_out_2).max() < 1e-4
-
- # Overriding PipelineTesterMixin::test_attention_slicing_forward_pass
- # because UnCLIP GPU undeterminism requires a looser check.
- @skip_mps
- def test_attention_slicing_forward_pass(self):
- test_max_difference = torch_device == "cpu"
-
- # Check is relaxed because there is not a torch 2.0 sliced attention added kv processor
- expected_max_diff = 1e-2
-
- self._test_attention_slicing_forward_pass(
- test_max_difference=test_max_difference, expected_max_diff=expected_max_diff
- )
-
- # Overriding PipelineTesterMixin::test_inference_batch_single_identical
- # because UnCLIP undeterminism requires a looser check.
- @unittest.skip("UnCLIP produces very large differences. Test is not useful.")
- @skip_mps
- def test_inference_batch_single_identical(self):
- additional_params_copy_to_batched_inputs = [
- "decoder_num_inference_steps",
- "super_res_num_inference_steps",
- ]
- self._test_inference_batch_single_identical(
- additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs, expected_max_diff=5e-3
- )
-
- def test_inference_batch_consistent(self):
- additional_params_copy_to_batched_inputs = [
- "decoder_num_inference_steps",
- "super_res_num_inference_steps",
- ]
-
- if torch_device == "mps":
- # TODO: MPS errors with larger batch sizes
- batch_sizes = [2, 3]
- self._test_inference_batch_consistent(
- batch_sizes=batch_sizes,
- additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs,
- )
- else:
- self._test_inference_batch_consistent(
- additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs
- )
-
- @skip_mps
- def test_dict_tuple_outputs_equivalent(self):
- return super().test_dict_tuple_outputs_equivalent()
-
- @unittest.skip("UnCLIP produces very large difference. Test is not useful.")
- @skip_mps
- def test_save_load_local(self):
- return super().test_save_load_local(expected_max_difference=4e-3)
-
- @skip_mps
- def test_save_load_optional_components(self):
- return super().test_save_load_optional_components()
-
- @unittest.skip("UnCLIP produces very large difference in fp16 vs fp32. Test is not useful.")
- def test_float16_inference(self):
- super().test_float16_inference(expected_max_diff=1.0)
-
-
-@nightly
-@require_torch_gpu
-class UnCLIPImageVariationPipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- def test_unclip_image_variation_karlo(self):
- input_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unclip/cat.png"
- )
- expected_image = load_numpy(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/unclip/karlo_v1_alpha_cat_variation_fp16.npy"
- )
-
- pipeline = UnCLIPImageVariationPipeline.from_pretrained(
- "kakaobrain/karlo-v1-alpha-image-variations", torch_dtype=torch.float16
- )
- pipeline = pipeline.to(torch_device)
- pipeline.set_progress_bar_config(disable=None)
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- output = pipeline(
- input_image,
- generator=generator,
- output_type="np",
- )
-
- image = output.images[0]
-
- assert image.shape == (256, 256, 3)
-
- assert_mean_pixel_difference(image, expected_image, 15)
diff --git a/tests/pipelines/unidiffuser/test_unidiffuser.py b/tests/pipelines/unidiffuser/test_unidiffuser.py
deleted file mode 100644
index 292978eb6eee..000000000000
--- a/tests/pipelines/unidiffuser/test_unidiffuser.py
+++ /dev/null
@@ -1,812 +0,0 @@
-import gc
-import random
-import traceback
-import unittest
-
-import numpy as np
-import torch
-from PIL import Image
-from transformers import (
- CLIPImageProcessor,
- CLIPTextModel,
- CLIPTokenizer,
- CLIPVisionModelWithProjection,
- GPT2Tokenizer,
-)
-
-from diffusers import (
- AutoencoderKL,
- DPMSolverMultistepScheduler,
- UniDiffuserModel,
- UniDiffuserPipeline,
- UniDiffuserTextDecoder,
-)
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- floats_tensor,
- load_image,
- nightly,
- require_torch_2,
- require_torch_accelerator,
- require_torch_gpu,
- run_test_in_subprocess,
- torch_device,
-)
-from diffusers.utils.torch_utils import randn_tensor
-
-from ..pipeline_params import (
- IMAGE_TO_IMAGE_IMAGE_PARAMS,
- TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
- TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
-)
-from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-# Will be run via run_test_in_subprocess
-def _test_unidiffuser_compile(in_queue, out_queue, timeout):
- error = None
- try:
- inputs = in_queue.get(timeout=timeout)
- torch_device = inputs.pop("torch_device")
- seed = inputs.pop("seed")
- inputs["generator"] = torch.Generator(device=torch_device).manual_seed(seed)
-
- pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")
- # pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
- pipe = pipe.to(torch_device)
-
- pipe.unet.to(memory_format=torch.channels_last)
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-
- pipe.set_progress_bar_config(disable=None)
-
- image = pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1].flatten()
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.2402, 0.2375, 0.2285, 0.2378, 0.2407, 0.2263, 0.2354, 0.2307, 0.2520])
- assert np.abs(image_slice - expected_slice).max() < 1e-1
- except Exception:
- error = f"{traceback.format_exc()}"
-
- results = {"error": error}
- out_queue.put(results, timeout=timeout)
- out_queue.join()
-
-
-class UniDiffuserPipelineFastTests(
- PipelineTesterMixin, PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
-):
- pipeline_class = UniDiffuserPipeline
- params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
- batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
- image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
- # vae_latents, not latents, is the argument that corresponds to VAE latent inputs
- image_latents_params = frozenset(["vae_latents"])
-
- supports_dduf = False
-
- def get_dummy_components(self):
- unet = UniDiffuserModel.from_pretrained(
- "hf-internal-testing/unidiffuser-diffusers-test",
- subfolder="unet",
- )
-
- scheduler = DPMSolverMultistepScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- solver_order=3,
- )
-
- vae = AutoencoderKL.from_pretrained(
- "hf-internal-testing/unidiffuser-diffusers-test",
- subfolder="vae",
- )
-
- text_encoder = CLIPTextModel.from_pretrained(
- "hf-internal-testing/unidiffuser-diffusers-test",
- subfolder="text_encoder",
- )
- clip_tokenizer = CLIPTokenizer.from_pretrained(
- "hf-internal-testing/unidiffuser-diffusers-test",
- subfolder="clip_tokenizer",
- )
-
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(
- "hf-internal-testing/unidiffuser-diffusers-test",
- subfolder="image_encoder",
- )
- # From the Stable Diffusion Image Variation pipeline tests
- clip_image_processor = CLIPImageProcessor(crop_size=32, size=32)
- # image_processor = CLIPImageProcessor.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- text_tokenizer = GPT2Tokenizer.from_pretrained(
- "hf-internal-testing/unidiffuser-diffusers-test",
- subfolder="text_tokenizer",
- )
- text_decoder = UniDiffuserTextDecoder.from_pretrained(
- "hf-internal-testing/unidiffuser-diffusers-test",
- subfolder="text_decoder",
- )
-
- components = {
- "vae": vae,
- "text_encoder": text_encoder,
- "image_encoder": image_encoder,
- "clip_image_processor": clip_image_processor,
- "clip_tokenizer": clip_tokenizer,
- "text_decoder": text_decoder,
- "text_tokenizer": text_tokenizer,
- "unet": unet,
- "scheduler": scheduler,
- }
-
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
- image = image.cpu().permute(0, 2, 3, 1)[0]
- image = Image.fromarray(np.uint8(image)).convert("RGB")
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "an elephant under the sea",
- "image": image,
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "np",
- }
- return inputs
-
- def get_fixed_latents(self, device, seed=0):
- if isinstance(device, str):
- device = torch.device(device)
- generator = torch.Generator(device=device).manual_seed(seed)
- # Hardcode the shapes for now.
- prompt_latents = randn_tensor((1, 77, 32), generator=generator, device=device, dtype=torch.float32)
- vae_latents = randn_tensor((1, 4, 16, 16), generator=generator, device=device, dtype=torch.float32)
- clip_latents = randn_tensor((1, 1, 32), generator=generator, device=device, dtype=torch.float32)
-
- latents = {
- "prompt_latents": prompt_latents,
- "vae_latents": vae_latents,
- "clip_latents": clip_latents,
- }
- return latents
-
- def get_dummy_inputs_with_latents(self, device, seed=0):
- # image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
- # image = image.cpu().permute(0, 2, 3, 1)[0]
- # image = Image.fromarray(np.uint8(image)).convert("RGB")
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg",
- )
- image = image.resize((32, 32))
- latents = self.get_fixed_latents(device, seed=seed)
-
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
-
- inputs = {
- "prompt": "an elephant under the sea",
- "image": image,
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "np",
- "prompt_latents": latents.get("prompt_latents"),
- "vae_latents": latents.get("vae_latents"),
- "clip_latents": latents.get("clip_latents"),
- }
- return inputs
-
- def test_dict_tuple_outputs_equivalent(self):
- expected_slice = None
- if torch_device == "cpu":
- expected_slice = np.array([0.7489, 0.3722, 0.4475, 0.5630, 0.5923, 0.4992, 0.3936, 0.5844, 0.4975])
- super().test_dict_tuple_outputs_equivalent(expected_slice=expected_slice)
-
- def test_unidiffuser_default_joint_v0(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'joint'
- unidiffuser_pipe.set_joint_mode()
- assert unidiffuser_pipe.mode == "joint"
-
- # inputs = self.get_dummy_inputs(device)
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete prompt and image for joint inference.
- del inputs["prompt"]
- del inputs["image"]
- sample = unidiffuser_pipe(**inputs)
- image = sample.images
- text = sample.text
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_img_slice = np.array([0.5760, 0.6270, 0.6571, 0.4965, 0.4638, 0.5663, 0.5254, 0.5068, 0.5716])
- assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
-
- expected_text_prefix = " no no no "
- assert text[0][:10] == expected_text_prefix
-
- def test_unidiffuser_default_joint_no_cfg_v0(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'joint'
- unidiffuser_pipe.set_joint_mode()
- assert unidiffuser_pipe.mode == "joint"
-
- # inputs = self.get_dummy_inputs(device)
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete prompt and image for joint inference.
- del inputs["prompt"]
- del inputs["image"]
- # Set guidance scale to 1.0 to turn off CFG
- inputs["guidance_scale"] = 1.0
- sample = unidiffuser_pipe(**inputs)
- image = sample.images
- text = sample.text
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_img_slice = np.array([0.5760, 0.6270, 0.6571, 0.4965, 0.4638, 0.5663, 0.5254, 0.5068, 0.5716])
- assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
-
- expected_text_prefix = " no no no "
- assert text[0][:10] == expected_text_prefix
-
- def test_unidiffuser_default_text2img_v0(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'text2img'
- unidiffuser_pipe.set_text_to_image_mode()
- assert unidiffuser_pipe.mode == "text2img"
-
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete image for text-conditioned image generation
- del inputs["image"]
- image = unidiffuser_pipe(**inputs).images
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.5758, 0.6269, 0.6570, 0.4967, 0.4639, 0.5664, 0.5257, 0.5067, 0.5715])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
-
- def test_unidiffuser_default_image_0(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'img'
- unidiffuser_pipe.set_image_mode()
- assert unidiffuser_pipe.mode == "img"
-
- inputs = self.get_dummy_inputs(device)
- # Delete prompt and image for unconditional ("marginal") text generation.
- del inputs["prompt"]
- del inputs["image"]
- image = unidiffuser_pipe(**inputs).images
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.5760, 0.6270, 0.6571, 0.4966, 0.4638, 0.5663, 0.5254, 0.5068, 0.5715])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
-
- def test_unidiffuser_default_text_v0(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'img'
- unidiffuser_pipe.set_text_mode()
- assert unidiffuser_pipe.mode == "text"
-
- inputs = self.get_dummy_inputs(device)
- # Delete prompt and image for unconditional ("marginal") text generation.
- del inputs["prompt"]
- del inputs["image"]
- text = unidiffuser_pipe(**inputs).text
-
- expected_text_prefix = " no no no "
- assert text[0][:10] == expected_text_prefix
-
- def test_unidiffuser_default_img2text_v0(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'img2text'
- unidiffuser_pipe.set_image_to_text_mode()
- assert unidiffuser_pipe.mode == "img2text"
-
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete text for image-conditioned text generation
- del inputs["prompt"]
- text = unidiffuser_pipe(**inputs).text
-
- expected_text_prefix = " no no no "
- assert text[0][:10] == expected_text_prefix
-
- def test_unidiffuser_default_joint_v1(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("hf-internal-testing/unidiffuser-test-v1")
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'joint'
- unidiffuser_pipe.set_joint_mode()
- assert unidiffuser_pipe.mode == "joint"
-
- # inputs = self.get_dummy_inputs(device)
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete prompt and image for joint inference.
- del inputs["prompt"]
- del inputs["image"]
- inputs["data_type"] = 1
- sample = unidiffuser_pipe(**inputs)
- image = sample.images
- text = sample.text
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_img_slice = np.array([0.5760, 0.6270, 0.6571, 0.4965, 0.4638, 0.5663, 0.5254, 0.5068, 0.5716])
- assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
-
- expected_text_prefix = " no no no "
- assert text[0][:10] == expected_text_prefix
-
- def test_unidiffuser_default_text2img_v1(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("hf-internal-testing/unidiffuser-test-v1")
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'text2img'
- unidiffuser_pipe.set_text_to_image_mode()
- assert unidiffuser_pipe.mode == "text2img"
-
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete image for text-conditioned image generation
- del inputs["image"]
- image = unidiffuser_pipe(**inputs).images
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.5758, 0.6269, 0.6570, 0.4967, 0.4639, 0.5664, 0.5257, 0.5067, 0.5715])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
-
- def test_unidiffuser_default_img2text_v1(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("hf-internal-testing/unidiffuser-test-v1")
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'img2text'
- unidiffuser_pipe.set_image_to_text_mode()
- assert unidiffuser_pipe.mode == "img2text"
-
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete text for image-conditioned text generation
- del inputs["prompt"]
- text = unidiffuser_pipe(**inputs).text
-
- expected_text_prefix = " no no no "
- assert text[0][:10] == expected_text_prefix
-
- def test_unidiffuser_text2img_multiple_images(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'text2img'
- unidiffuser_pipe.set_text_to_image_mode()
- assert unidiffuser_pipe.mode == "text2img"
-
- inputs = self.get_dummy_inputs(device)
- # Delete image for text-conditioned image generation
- del inputs["image"]
- inputs["num_images_per_prompt"] = 2
- inputs["num_prompts_per_image"] = 3
- image = unidiffuser_pipe(**inputs).images
- assert image.shape == (2, 32, 32, 3)
-
- def test_unidiffuser_img2text_multiple_prompts(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'img2text'
- unidiffuser_pipe.set_image_to_text_mode()
- assert unidiffuser_pipe.mode == "img2text"
-
- inputs = self.get_dummy_inputs(device)
- # Delete text for image-conditioned text generation
- del inputs["prompt"]
- inputs["num_images_per_prompt"] = 2
- inputs["num_prompts_per_image"] = 3
- text = unidiffuser_pipe(**inputs).text
-
- assert len(text) == 3
-
- def test_unidiffuser_text2img_multiple_images_with_latents(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'text2img'
- unidiffuser_pipe.set_text_to_image_mode()
- assert unidiffuser_pipe.mode == "text2img"
-
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete image for text-conditioned image generation
- del inputs["image"]
- inputs["num_images_per_prompt"] = 2
- inputs["num_prompts_per_image"] = 3
- image = unidiffuser_pipe(**inputs).images
- assert image.shape == (2, 32, 32, 3)
-
- def test_unidiffuser_img2text_multiple_prompts_with_latents(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'img2text'
- unidiffuser_pipe.set_image_to_text_mode()
- assert unidiffuser_pipe.mode == "img2text"
-
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete text for image-conditioned text generation
- del inputs["prompt"]
- inputs["num_images_per_prompt"] = 2
- inputs["num_prompts_per_image"] = 3
- text = unidiffuser_pipe(**inputs).text
-
- assert len(text) == 3
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(expected_max_diff=2e-4)
-
- @require_torch_accelerator
- def test_unidiffuser_default_joint_v1_fp16(self):
- unidiffuser_pipe = UniDiffuserPipeline.from_pretrained(
- "hf-internal-testing/unidiffuser-test-v1", torch_dtype=torch.float16
- )
- unidiffuser_pipe = unidiffuser_pipe.to(torch_device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'joint'
- unidiffuser_pipe.set_joint_mode()
- assert unidiffuser_pipe.mode == "joint"
-
- inputs = self.get_dummy_inputs_with_latents(torch_device)
- # Delete prompt and image for joint inference.
- del inputs["prompt"]
- del inputs["image"]
- inputs["data_type"] = 1
- sample = unidiffuser_pipe(**inputs)
- image = sample.images
- text = sample.text
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_img_slice = np.array([0.5049, 0.5498, 0.5854, 0.3052, 0.4460, 0.6489, 0.5122, 0.4810, 0.6138])
- assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
-
- expected_text_prefix = '" This This'
- assert text[0][: len(expected_text_prefix)] == expected_text_prefix
-
- @require_torch_accelerator
- def test_unidiffuser_default_text2img_v1_fp16(self):
- unidiffuser_pipe = UniDiffuserPipeline.from_pretrained(
- "hf-internal-testing/unidiffuser-test-v1", torch_dtype=torch.float16
- )
- unidiffuser_pipe = unidiffuser_pipe.to(torch_device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'text2img'
- unidiffuser_pipe.set_text_to_image_mode()
- assert unidiffuser_pipe.mode == "text2img"
-
- inputs = self.get_dummy_inputs_with_latents(torch_device)
- # Delete prompt and image for joint inference.
- del inputs["image"]
- inputs["data_type"] = 1
- sample = unidiffuser_pipe(**inputs)
- image = sample.images
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_img_slice = np.array([0.5054, 0.5498, 0.5854, 0.3052, 0.4458, 0.6489, 0.5122, 0.4810, 0.6138])
- assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
-
- @require_torch_accelerator
- def test_unidiffuser_default_img2text_v1_fp16(self):
- unidiffuser_pipe = UniDiffuserPipeline.from_pretrained(
- "hf-internal-testing/unidiffuser-test-v1", torch_dtype=torch.float16
- )
- unidiffuser_pipe = unidiffuser_pipe.to(torch_device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'img2text'
- unidiffuser_pipe.set_image_to_text_mode()
- assert unidiffuser_pipe.mode == "img2text"
-
- inputs = self.get_dummy_inputs_with_latents(torch_device)
- # Delete prompt and image for joint inference.
- del inputs["prompt"]
- inputs["data_type"] = 1
- text = unidiffuser_pipe(**inputs).text
-
- expected_text_prefix = '" This This'
- assert text[0][: len(expected_text_prefix)] == expected_text_prefix
-
- @unittest.skip(
- "Test not supported becauseit has a bunch of direct configs at init and also, this pipeline isn't used that much now."
- )
- def test_encode_prompt_works_in_isolation():
- pass
-
-
-@nightly
-@require_torch_gpu
-class UniDiffuserPipelineSlowTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- def get_inputs(self, device, seed=0, generate_latents=False):
- generator = torch.manual_seed(seed)
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg"
- )
- inputs = {
- "prompt": "an elephant under the sea",
- "image": image,
- "generator": generator,
- "num_inference_steps": 3,
- "guidance_scale": 8.0,
- "output_type": "np",
- }
- if generate_latents:
- latents = self.get_fixed_latents(device, seed=seed)
- for latent_name, latent_tensor in latents.items():
- inputs[latent_name] = latent_tensor
- return inputs
-
- def get_fixed_latents(self, device, seed=0):
- if isinstance(device, str):
- device = torch.device(device)
- latent_device = torch.device("cpu")
- generator = torch.Generator(device=latent_device).manual_seed(seed)
- # Hardcode the shapes for now.
- prompt_latents = randn_tensor((1, 77, 768), generator=generator, device=device, dtype=torch.float32)
- vae_latents = randn_tensor((1, 4, 64, 64), generator=generator, device=device, dtype=torch.float32)
- clip_latents = randn_tensor((1, 1, 512), generator=generator, device=device, dtype=torch.float32)
-
- # Move latents onto desired device.
- prompt_latents = prompt_latents.to(device)
- vae_latents = vae_latents.to(device)
- clip_latents = clip_latents.to(device)
-
- latents = {
- "prompt_latents": prompt_latents,
- "vae_latents": vae_latents,
- "clip_latents": clip_latents,
- }
- return latents
-
- def test_unidiffuser_default_joint_v1(self):
- pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- # inputs = self.get_dummy_inputs(device)
- inputs = self.get_inputs(device=torch_device, generate_latents=True)
- # Delete prompt and image for joint inference.
- del inputs["prompt"]
- del inputs["image"]
- sample = pipe(**inputs)
- image = sample.images
- text = sample.text
- assert image.shape == (1, 512, 512, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_img_slice = np.array([0.2402, 0.2375, 0.2285, 0.2378, 0.2407, 0.2263, 0.2354, 0.2307, 0.2520])
- assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-1
-
- expected_text_prefix = "a living room"
- assert text[0][: len(expected_text_prefix)] == expected_text_prefix
-
- def test_unidiffuser_default_text2img_v1(self):
- pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- inputs = self.get_inputs(device=torch_device, generate_latents=True)
- del inputs["image"]
- sample = pipe(**inputs)
- image = sample.images
- assert image.shape == (1, 512, 512, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.0242, 0.0103, 0.0022, 0.0129, 0.0000, 0.0090, 0.0376, 0.0508, 0.0005])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
-
- def test_unidiffuser_default_img2text_v1(self):
- pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- inputs = self.get_inputs(device=torch_device, generate_latents=True)
- del inputs["prompt"]
- sample = pipe(**inputs)
- text = sample.text
-
- expected_text_prefix = "An astronaut"
- assert text[0][: len(expected_text_prefix)] == expected_text_prefix
-
- @unittest.skip(reason="Skip torch.compile test to speed up the slow test suite.")
- @require_torch_2
- def test_unidiffuser_compile(self, seed=0):
- inputs = self.get_inputs(torch_device, seed=seed, generate_latents=True)
- # Delete prompt and image for joint inference.
- del inputs["prompt"]
- del inputs["image"]
- # Can't pickle a Generator object
- del inputs["generator"]
- inputs["torch_device"] = torch_device
- inputs["seed"] = seed
- run_test_in_subprocess(test_case=self, target_func=_test_unidiffuser_compile, inputs=inputs)
-
-
-@nightly
-@require_torch_gpu
-class UniDiffuserPipelineNightlyTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- torch.cuda.empty_cache()
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- torch.cuda.empty_cache()
-
- def get_inputs(self, device, seed=0, generate_latents=False):
- generator = torch.manual_seed(seed)
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg"
- )
- inputs = {
- "prompt": "an elephant under the sea",
- "image": image,
- "generator": generator,
- "num_inference_steps": 3,
- "guidance_scale": 8.0,
- "output_type": "np",
- }
- if generate_latents:
- latents = self.get_fixed_latents(device, seed=seed)
- for latent_name, latent_tensor in latents.items():
- inputs[latent_name] = latent_tensor
- return inputs
-
- def get_fixed_latents(self, device, seed=0):
- if isinstance(device, str):
- device = torch.device(device)
- latent_device = torch.device("cpu")
- generator = torch.Generator(device=latent_device).manual_seed(seed)
- # Hardcode the shapes for now.
- prompt_latents = randn_tensor((1, 77, 768), generator=generator, device=device, dtype=torch.float32)
- vae_latents = randn_tensor((1, 4, 64, 64), generator=generator, device=device, dtype=torch.float32)
- clip_latents = randn_tensor((1, 1, 512), generator=generator, device=device, dtype=torch.float32)
-
- # Move latents onto desired device.
- prompt_latents = prompt_latents.to(device)
- vae_latents = vae_latents.to(device)
- clip_latents = clip_latents.to(device)
-
- latents = {
- "prompt_latents": prompt_latents,
- "vae_latents": vae_latents,
- "clip_latents": clip_latents,
- }
- return latents
-
- def test_unidiffuser_default_joint_v1_fp16(self):
- pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- # inputs = self.get_dummy_inputs(device)
- inputs = self.get_inputs(device=torch_device, generate_latents=True)
- # Delete prompt and image for joint inference.
- del inputs["prompt"]
- del inputs["image"]
- sample = pipe(**inputs)
- image = sample.images
- text = sample.text
- assert image.shape == (1, 512, 512, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_img_slice = np.array([0.2402, 0.2375, 0.2285, 0.2378, 0.2407, 0.2263, 0.2354, 0.2307, 0.2520])
- assert np.abs(image_slice.flatten() - expected_img_slice).max() < 2e-1
-
- expected_text_prefix = "a living room"
- assert text[0][: len(expected_text_prefix)] == expected_text_prefix
-
- def test_unidiffuser_default_text2img_v1_fp16(self):
- pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- inputs = self.get_inputs(device=torch_device, generate_latents=True)
- del inputs["image"]
- sample = pipe(**inputs)
- image = sample.images
- assert image.shape == (1, 512, 512, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.0242, 0.0103, 0.0022, 0.0129, 0.0000, 0.0090, 0.0376, 0.0508, 0.0005])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
-
- def test_unidiffuser_default_img2text_v1_fp16(self):
- pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- inputs = self.get_inputs(device=torch_device, generate_latents=True)
- del inputs["prompt"]
- sample = pipe(**inputs)
- text = sample.text
-
- expected_text_prefix = "An astronaut"
- assert text[0][: len(expected_text_prefix)] == expected_text_prefix
diff --git a/tests/pipelines/visualcloze/__init__.py b/tests/pipelines/visualcloze/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/visualcloze/test_pipeline_visualcloze_combined.py b/tests/pipelines/visualcloze/test_pipeline_visualcloze_combined.py
new file mode 100644
index 000000000000..00ae0441fe99
--- /dev/null
+++ b/tests/pipelines/visualcloze/test_pipeline_visualcloze_combined.py
@@ -0,0 +1,344 @@
+import random
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
+
+import diffusers
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxTransformer2DModel, VisualClozePipeline
+from diffusers.utils import logging
+
+from ...testing_utils import (
+ CaptureLogger,
+ enable_full_determinism,
+ floats_tensor,
+ require_accelerator,
+ torch_device,
+)
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class VisualClozePipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+ pipeline_class = VisualClozePipeline
+ params = frozenset(
+ [
+ "task_prompt",
+ "content_prompt",
+ "upsampling_height",
+ "upsampling_width",
+ "guidance_scale",
+ "prompt_embeds",
+ "pooled_prompt_embeds",
+ "upsampling_strength",
+ ]
+ )
+ batch_params = frozenset(["task_prompt", "content_prompt", "image"])
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = FluxTransformer2DModel(
+ patch_size=1,
+ in_channels=12,
+ out_channels=4,
+ num_layers=1,
+ num_single_layers=1,
+ attention_head_dim=6,
+ num_attention_heads=2,
+ joint_attention_dim=32,
+ pooled_projection_dim=32,
+ axes_dims_rope=[2, 2, 2],
+ )
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = CLIPTextModel(clip_text_encoder_config)
+
+ torch.manual_seed(0)
+ text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=1,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ shift_factor=0.0609,
+ scaling_factor=1.5035,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "transformer": transformer,
+ "vae": vae,
+ "resolution": 32,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ # Create example images to simulate the input format required by VisualCloze
+ context_image = [
+ Image.fromarray(floats_tensor((32, 32, 3), rng=random.Random(seed), scale=255).numpy().astype(np.uint8))
+ for _ in range(2)
+ ]
+ query_image = [
+ Image.fromarray(
+ floats_tensor((32, 32, 3), rng=random.Random(seed + 1), scale=255).numpy().astype(np.uint8)
+ ),
+ None,
+ ]
+
+ # Create an image list that conforms to the VisualCloze input format
+ image = [
+ context_image, # In-Context example
+ query_image, # Query image
+ ]
+
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "task_prompt": "Each row outlines a logical process, starting from [IMAGE1] gray-based depth map with detailed object contours, to achieve [IMAGE2] an image with flawless clarity.",
+ "content_prompt": "A beautiful landscape with mountains and a lake",
+ "image": image,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "upsampling_height": 32,
+ "upsampling_width": 32,
+ "max_sequence_length": 77,
+ "output_type": "np",
+ "upsampling_strength": 0.4,
+ }
+ return inputs
+
+ def test_visualcloze_different_prompts(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_same_prompt = pipe(**inputs).images[0]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["task_prompt"] = "A different task to perform."
+ output_different_prompts = pipe(**inputs).images[0]
+
+ max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+
+ # Outputs should be different
+ assert max_diff > 1e-6
+
+ def test_visualcloze_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (72, 57)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.generation_pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.generation_pipe.vae_scale_factor * 2)
+
+ inputs.update({"upsampling_height": height, "upsampling_width": width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(expected_max_diff=1e-3)
+
+ def test_upsampling_strength(self, expected_min_diff=1e-1):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test different upsampling strengths
+ inputs["upsampling_strength"] = 0.2
+ output_no_upsampling = pipe(**inputs).images[0]
+
+ inputs["upsampling_strength"] = 0.8
+ output_full_upsampling = pipe(**inputs).images[0]
+
+ # Different upsampling strengths should produce different outputs
+ max_diff = np.abs(output_no_upsampling - output_full_upsampling).max()
+ assert max_diff > expected_min_diff
+
+ def test_different_task_prompts(self, expected_min_diff=1e-1):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ output_original = pipe(**inputs).images[0]
+
+ inputs["task_prompt"] = "A different task description for image generation"
+ output_different_task = pipe(**inputs).images[0]
+
+ # Different task prompts should produce different outputs
+ max_diff = np.abs(output_original - output_different_task).max()
+ assert max_diff > expected_min_diff
+
+ @unittest.skip(
+ "Test not applicable because the pipeline being tested is a wrapper pipeline. CFG tests should be done on the inner pipelines."
+ )
+ def test_callback_cfg(self):
+ pass
+
+ def test_save_load_local(self, expected_max_difference=5e-4):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output = pipe(**inputs)[0]
+
+ logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
+ logger.setLevel(diffusers.logging.INFO)
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, safe_serialization=False)
+
+ with CaptureLogger(logger) as cap_logger:
+ # NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware
+ # This attribute is not serialized in the config of the pipeline
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, resolution=32)
+
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+
+ for name in pipe_loaded.components.keys():
+ if name not in pipe_loaded._optional_components:
+ assert name in str(cap_logger)
+
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+ self.assertLess(max_diff, expected_max_difference)
+
+ def test_save_load_optional_components(self, expected_max_difference=1e-4):
+ if not hasattr(self.pipeline_class, "_optional_components"):
+ return
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ # set all optional components to None
+ for optional_component in pipe._optional_components:
+ setattr(pipe, optional_component, None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, safe_serialization=False)
+ # NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware
+ # This attribute is not serialized in the config of the pipeline
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, resolution=32)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ for optional_component in pipe._optional_components:
+ self.assertTrue(
+ getattr(pipe_loaded, optional_component) is None,
+ f"`{optional_component}` did not stay set to None after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+ self.assertLess(max_diff, expected_max_difference)
+
+ @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
+ @require_accelerator
+ def test_save_load_float16(self, expected_max_diff=1e-2):
+ components = self.get_dummy_components()
+ for name, module in components.items():
+ if hasattr(module, "half"):
+ components[name] = module.to(torch_device).half()
+
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir)
+ # NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware
+ # This attribute is not serialized in the config of the pipeline
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16, resolution=32)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ for name, component in pipe_loaded.components.items():
+ if hasattr(component, "dtype"):
+ self.assertTrue(
+ component.dtype == torch.float16,
+ f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_loaded = pipe_loaded(**inputs)[0]
+ max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+ self.assertLess(
+ max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
+ )
diff --git a/tests/pipelines/visualcloze/test_pipeline_visualcloze_generation.py b/tests/pipelines/visualcloze/test_pipeline_visualcloze_generation.py
new file mode 100644
index 000000000000..ab6b3ca5c587
--- /dev/null
+++ b/tests/pipelines/visualcloze/test_pipeline_visualcloze_generation.py
@@ -0,0 +1,312 @@
+import random
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ FluxTransformer2DModel,
+ VisualClozeGenerationPipeline,
+)
+from diffusers.utils import logging
+
+from ...testing_utils import (
+ CaptureLogger,
+ enable_full_determinism,
+ floats_tensor,
+ require_accelerator,
+ torch_device,
+)
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class VisualClozeGenerationPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+ pipeline_class = VisualClozeGenerationPipeline
+ params = frozenset(
+ [
+ "task_prompt",
+ "content_prompt",
+ "guidance_scale",
+ "prompt_embeds",
+ "pooled_prompt_embeds",
+ ]
+ )
+ batch_params = frozenset(["task_prompt", "content_prompt", "image"])
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = FluxTransformer2DModel(
+ patch_size=1,
+ in_channels=12,
+ out_channels=4,
+ num_layers=1,
+ num_single_layers=1,
+ attention_head_dim=6,
+ num_attention_heads=2,
+ joint_attention_dim=32,
+ pooled_projection_dim=32,
+ axes_dims_rope=[2, 2, 2],
+ )
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = CLIPTextModel(clip_text_encoder_config)
+
+ torch.manual_seed(0)
+ text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=1,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ shift_factor=0.0609,
+ scaling_factor=1.5035,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "transformer": transformer,
+ "vae": vae,
+ "resolution": 32,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ # Create example images to simulate the input format required by VisualCloze
+ context_image = [
+ Image.fromarray(floats_tensor((32, 32, 3), rng=random.Random(seed), scale=255).numpy().astype(np.uint8))
+ for _ in range(2)
+ ]
+ query_image = [
+ Image.fromarray(
+ floats_tensor((32, 32, 3), rng=random.Random(seed + 1), scale=255).numpy().astype(np.uint8)
+ ),
+ None,
+ ]
+
+ # Create an image list that conforms to the VisualCloze input format
+ image = [
+ context_image, # In-Context example
+ query_image, # Query image
+ ]
+
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "task_prompt": "Each row outlines a logical process, starting from [IMAGE1] gray-based depth map with detailed object contours, to achieve [IMAGE2] an image with flawless clarity.",
+ "content_prompt": "A beautiful landscape with mountains and a lake",
+ "image": image,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "max_sequence_length": 77,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_visualcloze_different_prompts(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_same_prompt = pipe(**inputs).images[0]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["task_prompt"] = "A different task to perform."
+ output_different_prompts = pipe(**inputs).images[0]
+
+ max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+
+ # Outputs should be different
+ assert max_diff > 1e-6
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(expected_max_diff=1e-3)
+
+ def test_different_task_prompts(self, expected_min_diff=1e-1):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ output_original = pipe(**inputs).images[0]
+
+ inputs["task_prompt"] = "A different task description for image generation"
+ output_different_task = pipe(**inputs).images[0]
+
+ # Different task prompts should produce different outputs
+ max_diff = np.abs(output_original - output_different_task).max()
+ assert max_diff > expected_min_diff
+
+ def test_save_load_local(self, expected_max_difference=5e-4):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output = pipe(**inputs)[0]
+
+ logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
+ logger.setLevel(diffusers.logging.INFO)
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, safe_serialization=False)
+
+ with CaptureLogger(logger) as cap_logger:
+ # NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware
+ # This attribute is not serialized in the config of the pipeline
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, resolution=32)
+
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+
+ for name in pipe_loaded.components.keys():
+ if name not in pipe_loaded._optional_components:
+ assert name in str(cap_logger)
+
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+ self.assertLess(max_diff, expected_max_difference)
+
+ def test_save_load_optional_components(self, expected_max_difference=1e-4):
+ if not hasattr(self.pipeline_class, "_optional_components"):
+ return
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ # set all optional components to None
+ for optional_component in pipe._optional_components:
+ setattr(pipe, optional_component, None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, safe_serialization=False)
+ # NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware
+ # This attribute is not serialized in the config of the pipeline
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, resolution=32)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ for optional_component in pipe._optional_components:
+ self.assertTrue(
+ getattr(pipe_loaded, optional_component) is None,
+ f"`{optional_component}` did not stay set to None after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+ self.assertLess(max_diff, expected_max_difference)
+
+ @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
+ @require_accelerator
+ def test_save_load_float16(self, expected_max_diff=1e-2):
+ components = self.get_dummy_components()
+ for name, module in components.items():
+ if hasattr(module, "half"):
+ components[name] = module.to(torch_device).half()
+
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir)
+ # NOTE: Resolution must be set to 32 for loading otherwise will lead to OOM on CI hardware
+ # This attribute is not serialized in the config of the pipeline
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16, resolution=32)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ for name, component in pipe_loaded.components.items():
+ if hasattr(component, "dtype"):
+ self.assertTrue(
+ component.dtype == torch.float16,
+ f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_loaded = pipe_loaded(**inputs)[0]
+ max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+ self.assertLess(
+ max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
+ )
+
+ @unittest.skip("Skipped due to missing layout_prompt. Needs further investigation.")
+ def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=0.0001, rtol=0.0001):
+ pass
diff --git a/tests/pipelines/wan/test_wan.py b/tests/pipelines/wan/test_wan.py
index a162e6841d2d..106a7b294646 100644
--- a/tests/pipelines/wan/test_wan.py
+++ b/tests/pipelines/wan/test_wan.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
+# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
# limitations under the License.
import gc
+import tempfile
import unittest
import numpy as np
@@ -20,16 +21,16 @@
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanPipeline, WanTransformer3DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
+ torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import (
- PipelineTesterMixin,
-)
+from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
@@ -92,6 +93,7 @@ def get_dummy_components(self):
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
+ "transformer_2": None,
}
return components
@@ -125,16 +127,59 @@ def test_inference(self):
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
-
self.assertEqual(generated_video.shape, (9, 3, 16, 16))
- expected_video = torch.randn(9, 3, 16, 16)
- max_diff = np.abs(generated_video - expected_video).max()
- self.assertLessEqual(max_diff, 1e10)
+
+ # fmt: off
+ expected_slice = torch.tensor([0.4525, 0.452, 0.4485, 0.4534, 0.4524, 0.4529, 0.454, 0.453, 0.5127, 0.5326, 0.5204, 0.5253, 0.5439, 0.5424, 0.5133, 0.5078])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
@unittest.skip("Test not supported")
def test_attention_slicing_forward_pass(self):
pass
+ # _optional_components include transformer, transformer_2, but only transformer_2 is optional for this wan2.1 t2v pipeline
+ def test_save_load_optional_components(self, expected_max_difference=1e-4):
+ optional_component = "transformer_2"
+
+ components = self.get_dummy_components()
+ components[optional_component] = None
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, safe_serialization=False)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ self.assertTrue(
+ getattr(pipe_loaded, optional_component) is None,
+ f"`{optional_component}` did not stay set to None after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+ self.assertLess(max_diff, expected_max_difference)
+
@slow
@require_torch_accelerator
@@ -144,12 +189,12 @@ class WanPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
@unittest.skip("TODO: test needs to be implemented")
def test_Wanx(self):
diff --git a/tests/pipelines/wan/test_wan_22.py b/tests/pipelines/wan/test_wan_22.py
new file mode 100644
index 000000000000..56ef5ceb97ed
--- /dev/null
+++ b/tests/pipelines/wan/test_wan_22.py
@@ -0,0 +1,367 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanPipeline, WanTransformer3DModel
+
+from ...testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class Wan22PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = WanPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = WanTransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ torch.manual_seed(0)
+ transformer_2 = WanTransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "transformer_2": transformer_2,
+ "boundary_ratio": 0.875,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "negative",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 16,
+ "width": 16,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(
+ **components,
+ )
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+
+ # fmt: off
+ expected_slice = torch.tensor([0.4525, 0.452, 0.4485, 0.4534, 0.4524, 0.4529, 0.454, 0.453, 0.5127, 0.5326, 0.5204, 0.5253, 0.5439, 0.5424, 0.5133, 0.5078])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ def test_save_load_optional_components(self, expected_max_difference=1e-4):
+ optional_component = "transformer"
+
+ components = self.get_dummy_components()
+ components[optional_component] = None
+ components["boundary_ratio"] = 1.0 # for wan 2.2 14B, transformer is not used when boundary_ratio is 1.0
+
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, safe_serialization=False)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ self.assertTrue(
+ getattr(pipe_loaded, "transformer") is None,
+ "`transformer` did not stay set to None after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+ self.assertLess(max_diff, expected_max_difference)
+
+
+class Wan225BPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = WanPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=48,
+ in_channels=12,
+ out_channels=12,
+ is_residual=True,
+ patch_size=2,
+ latents_mean=[0.0] * 48,
+ latents_std=[1.0] * 48,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ scale_factor_spatial=16,
+ scale_factor_temporal=4,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = WanTransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=48,
+ out_channels=48,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "transformer_2": None,
+ "boundary_ratio": None,
+ "expand_timesteps": True,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 32,
+ "width": 32,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(
+ **components,
+ )
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+ self.assertEqual(generated_video.shape, (9, 3, 32, 32))
+
+ # fmt: off
+ expected_slice = torch.tensor([[[0.4814, 0.4298, 0.5094, 0.4289, 0.5061, 0.4301, 0.5043, 0.4284, 0.5375,
+ 0.5965, 0.5527, 0.6014, 0.5228, 0.6076, 0.6644, 0.5651]]])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(
+ torch.allclose(generated_slice, expected_slice, atol=1e-3),
+ f"generated_slice: {generated_slice}, expected_slice: {expected_slice}",
+ )
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ def test_components_function(self):
+ init_components = self.get_dummy_components()
+ init_components.pop("boundary_ratio")
+ init_components.pop("expand_timesteps")
+ pipe = self.pipeline_class(**init_components)
+
+ self.assertTrue(hasattr(pipe, "components"))
+ self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
+
+ def test_save_load_optional_components(self, expected_max_difference=1e-4):
+ optional_component = "transformer_2"
+
+ components = self.get_dummy_components()
+ components[optional_component] = None
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, safe_serialization=False)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ self.assertTrue(
+ getattr(pipe_loaded, optional_component) is None,
+ f"`{optional_component}` did not stay set to None after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+ self.assertLess(max_diff, expected_max_difference)
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(expected_max_diff=2e-3)
diff --git a/tests/pipelines/wan/test_wan_22_image_to_video.py b/tests/pipelines/wan/test_wan_22_image_to_video.py
new file mode 100644
index 000000000000..6294d62044f3
--- /dev/null
+++ b/tests/pipelines/wan/test_wan_22_image_to_video.py
@@ -0,0 +1,392 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanImageToVideoPipeline, WanTransformer3DModel
+
+from ...testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class Wan22ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = WanImageToVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = WanTransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=36,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ torch.manual_seed(0)
+ transformer_2 = WanTransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=36,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "transformer_2": transformer_2,
+ "image_encoder": None,
+ "image_processor": None,
+ "boundary_ratio": 0.875,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ image_height = 16
+ image_width = 16
+ image = Image.new("RGB", (image_width, image_height))
+ inputs = {
+ "image": image,
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "height": image_height,
+ "width": image_width,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(
+ **components,
+ )
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+
+ # fmt: off
+ expected_slice = torch.tensor([0.4527, 0.4526, 0.4498, 0.4539, 0.4521, 0.4524, 0.4533, 0.4535, 0.5154,
+ 0.5353, 0.5200, 0.5174, 0.5434, 0.5301, 0.5199, 0.5216])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(
+ torch.allclose(generated_slice, expected_slice, atol=1e-3),
+ f"generated_slice: {generated_slice}, expected_slice: {expected_slice}",
+ )
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ def test_save_load_optional_components(self, expected_max_difference=1e-4):
+ optional_component = ["transformer", "image_encoder", "image_processor"]
+
+ components = self.get_dummy_components()
+ for component in optional_component:
+ components[component] = None
+ components["boundary_ratio"] = 1.0 # for wan 2.2 14B, transformer is not used when boundary_ratio is 1.0
+
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, safe_serialization=False)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ for component in optional_component:
+ self.assertTrue(
+ getattr(pipe_loaded, component) is None,
+ f"`{component}` did not stay set to None after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+ self.assertLess(max_diff, expected_max_difference)
+
+
+class Wan225BImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = WanImageToVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=48,
+ in_channels=12,
+ out_channels=12,
+ is_residual=True,
+ patch_size=2,
+ latents_mean=[0.0] * 48,
+ latents_std=[1.0] * 48,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ scale_factor_spatial=16,
+ scale_factor_temporal=4,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = WanTransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=48,
+ out_channels=48,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "transformer_2": None,
+ "image_encoder": None,
+ "image_processor": None,
+ "boundary_ratio": None,
+ "expand_timesteps": True,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ image_height = 32
+ image_width = 32
+ image = Image.new("RGB", (image_width, image_height))
+ inputs = {
+ "image": image,
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "height": image_height,
+ "width": image_width,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(
+ **components,
+ )
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+ self.assertEqual(generated_video.shape, (9, 3, 32, 32))
+
+ # fmt: off
+ expected_slice = torch.tensor([[0.4833, 0.4305, 0.5100, 0.4299, 0.5056, 0.4298, 0.5052, 0.4332, 0.5550,
+ 0.6092, 0.5536, 0.5928, 0.5199, 0.5864, 0.6705, 0.5493]])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(
+ torch.allclose(generated_slice, expected_slice, atol=1e-3),
+ f"generated_slice: {generated_slice}, expected_slice: {expected_slice}",
+ )
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ def test_components_function(self):
+ init_components = self.get_dummy_components()
+ init_components.pop("boundary_ratio")
+ init_components.pop("expand_timesteps")
+ pipe = self.pipeline_class(**init_components)
+
+ self.assertTrue(hasattr(pipe, "components"))
+ self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
+
+ def test_save_load_optional_components(self, expected_max_difference=1e-4):
+ optional_component = ["transformer_2", "image_encoder", "image_processor"]
+
+ components = self.get_dummy_components()
+ for component in optional_component:
+ components[component] = None
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, safe_serialization=False)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ for component in optional_component:
+ self.assertTrue(
+ getattr(pipe_loaded, component) is None,
+ f"`{component}` did not stay set to None after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+ self.assertLess(max_diff, expected_max_difference)
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(expected_max_diff=2e-3)
+
+ @unittest.skip("Test not supported")
+ def test_callback_inputs(self):
+ pass
diff --git a/tests/pipelines/wan/test_wan_animate.py b/tests/pipelines/wan/test_wan_animate.py
new file mode 100644
index 000000000000..d6d1b09f3620
--- /dev/null
+++ b/tests/pipelines/wan/test_wan_animate.py
@@ -0,0 +1,239 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import (
+ AutoTokenizer,
+ CLIPImageProcessor,
+ CLIPVisionConfig,
+ CLIPVisionModelWithProjection,
+ T5EncoderModel,
+)
+
+from diffusers import (
+ AutoencoderKLWan,
+ FlowMatchEulerDiscreteScheduler,
+ WanAnimatePipeline,
+ WanAnimateTransformer3DModel,
+)
+
+from ...testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class WanAnimatePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = WanAnimatePipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ channel_sizes = {"4": 16, "8": 16, "16": 16}
+ transformer = WanAnimateTransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=36,
+ latent_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ image_dim=4,
+ rope_max_seq_len=32,
+ motion_encoder_channel_sizes=channel_sizes,
+ motion_encoder_size=16,
+ motion_style_dim=8,
+ motion_dim=4,
+ motion_encoder_dim=16,
+ face_encoder_hidden_dim=16,
+ face_encoder_num_heads=2,
+ inject_face_latents_blocks=2,
+ )
+
+ torch.manual_seed(0)
+ image_encoder_config = CLIPVisionConfig(
+ hidden_size=4,
+ projection_dim=4,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ image_size=4,
+ intermediate_size=16,
+ patch_size=1,
+ )
+ image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
+
+ torch.manual_seed(0)
+ image_processor = CLIPImageProcessor(crop_size=4, size=4)
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "image_encoder": image_encoder,
+ "image_processor": image_processor,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ num_frames = 17
+ height = 16
+ width = 16
+ face_height = 16
+ face_width = 16
+
+ image = Image.new("RGB", (height, width))
+ pose_video = [Image.new("RGB", (height, width))] * num_frames
+ face_video = [Image.new("RGB", (face_height, face_width))] * num_frames
+
+ inputs = {
+ "image": image,
+ "pose_video": pose_video,
+ "face_video": face_video,
+ "prompt": "dance monkey",
+ "negative_prompt": "negative",
+ "height": height,
+ "width": width,
+ "segment_frame_length": 77, # TODO: can we set this to num_frames?
+ "num_inference_steps": 2,
+ "mode": "animate",
+ "prev_segment_conditioning_frames": 1,
+ "generator": generator,
+ "guidance_scale": 1.0,
+ "output_type": "pt",
+ "max_sequence_length": 16,
+ }
+ return inputs
+
+ def test_inference(self):
+ """Test basic inference in animation mode."""
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames[0]
+ self.assertEqual(video.shape, (17, 3, 16, 16))
+
+ expected_video = torch.randn(17, 3, 16, 16)
+ max_diff = np.abs(video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_inference_replacement(self):
+ """Test the pipeline in replacement mode with background and mask videos."""
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ inputs["mode"] = "replace"
+ num_frames = 17
+ height = 16
+ width = 16
+ inputs["background_video"] = [Image.new("RGB", (height, width))] * num_frames
+ inputs["mask_video"] = [Image.new("L", (height, width))] * num_frames
+
+ video = pipe(**inputs).frames[0]
+ self.assertEqual(video.shape, (17, 3, 16, 16))
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip(
+ "Setting the Wan Animate latents to zero at the last denoising step does not guarantee that the output will be"
+ " zero. I believe this is because the latents are further processed in the outer loop where we loop over"
+ " inference segments."
+ )
+ def test_callback_inputs(self):
+ pass
+
+
+@slow
+@require_torch_accelerator
+class WanAnimatePipelineIntegrationTests(unittest.TestCase):
+ prompt = "A painting of a squirrel eating a burger."
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ @unittest.skip("TODO: test needs to be implemented")
+ def test_wan_animate(self):
+ pass
diff --git a/tests/pipelines/wan/test_wan_image_to_video.py b/tests/pipelines/wan/test_wan_image_to_video.py
index 53fa37dfae99..07a9142f2553 100644
--- a/tests/pipelines/wan/test_wan_image_to_video.py
+++ b/tests/pipelines/wan/test_wan_image_to_video.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
+# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import tempfile
import unittest
import numpy as np
@@ -26,8 +27,8 @@
)
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanImageToVideoPipeline, WanTransformer3DModel
-from diffusers.utils.testing_utils import enable_full_determinism
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@@ -110,6 +111,7 @@ def get_dummy_components(self):
"tokenizer": tokenizer,
"image_encoder": image_encoder,
"image_processor": image_processor,
+ "transformer_2": None,
}
return components
@@ -147,11 +149,189 @@ def test_inference(self):
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+
+ # fmt: off
+ expected_slice = torch.tensor([0.4525, 0.4525, 0.4497, 0.4536, 0.452, 0.4529, 0.454, 0.4535, 0.5072, 0.5527, 0.5165, 0.5244, 0.5481, 0.5282, 0.5208, 0.5214])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
+ def test_inference_batch_single_identical(self):
+ pass
+
+ # _optional_components include transformer, transformer_2 and image_encoder, image_processor, but only transformer_2 is optional for wan2.1 i2v pipeline
+ def test_save_load_optional_components(self, expected_max_difference=1e-4):
+ optional_component = "transformer_2"
+
+ components = self.get_dummy_components()
+ components[optional_component] = None
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, safe_serialization=False)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ self.assertTrue(
+ getattr(pipe_loaded, optional_component) is None,
+ f"`{optional_component}` did not stay set to None after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+ self.assertLess(max_diff, expected_max_difference)
+
+
+class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = WanImageToVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ # TODO: impl FlowDPMSolverMultistepScheduler
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = WanTransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=36,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ image_dim=4,
+ pos_embed_seq_len=2 * (4 * 4 + 1),
+ )
+
+ torch.manual_seed(0)
+ image_encoder_config = CLIPVisionConfig(
+ hidden_size=4,
+ projection_dim=4,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ image_size=4,
+ intermediate_size=16,
+ patch_size=1,
+ )
+ image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
+
+ torch.manual_seed(0)
+ image_processor = CLIPImageProcessor(crop_size=4, size=4)
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "image_encoder": image_encoder,
+ "image_processor": image_processor,
+ "transformer_2": None,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ image_height = 16
+ image_width = 16
+ image = Image.new("RGB", (image_width, image_height))
+ last_image = Image.new("RGB", (image_width, image_height))
+ inputs = {
+ "image": image,
+ "last_image": last_image,
+ "prompt": "dance monkey",
+ "negative_prompt": "negative",
+ "height": image_height,
+ "width": image_width,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
self.assertEqual(generated_video.shape, (9, 3, 16, 16))
- expected_video = torch.randn(9, 3, 16, 16)
- max_diff = np.abs(generated_video - expected_video).max()
- self.assertLessEqual(max_diff, 1e10)
+
+ # fmt: off
+ expected_slice = torch.tensor([0.4531, 0.4527, 0.4498, 0.4542, 0.4526, 0.4527, 0.4534, 0.4534, 0.5061, 0.5185, 0.5283, 0.5181, 0.5309, 0.5365, 0.5113, 0.5244])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
@unittest.skip("Test not supported")
def test_attention_slicing_forward_pass(self):
@@ -160,3 +340,42 @@ def test_attention_slicing_forward_pass(self):
@unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
def test_inference_batch_single_identical(self):
pass
+
+ # _optional_components include transformer, transformer_2 and image_encoder, image_processor, but only transformer_2 is optional for wan2.1 FLFT2V pipeline
+ def test_save_load_optional_components(self, expected_max_difference=1e-4):
+ optional_component = "transformer_2"
+
+ components = self.get_dummy_components()
+ components[optional_component] = None
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, safe_serialization=False)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ self.assertTrue(
+ getattr(pipe_loaded, optional_component) is None,
+ f"`{optional_component}` did not stay set to None after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+ self.assertLess(max_diff, expected_max_difference)
diff --git a/tests/pipelines/wan/test_wan_vace.py b/tests/pipelines/wan/test_wan_vace.py
new file mode 100644
index 000000000000..fe078c0deb8a
--- /dev/null
+++ b/tests/pipelines/wan/test_wan_vace.py
@@ -0,0 +1,299 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKLWan,
+ FlowMatchEulerDiscreteScheduler,
+ UniPCMultistepScheduler,
+ WanVACEPipeline,
+ WanVACETransformer3DModel,
+)
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class WanVACEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = WanVACEPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = WanVACETransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=3,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ vace_layers=[0, 2],
+ vace_in_channels=96,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "transformer_2": None,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ num_frames = 17
+ height = 16
+ width = 16
+
+ video = [Image.new("RGB", (height, width))] * num_frames
+ mask = [Image.new("L", (height, width), 0)] * num_frames
+
+ inputs = {
+ "video": video,
+ "mask": mask,
+ "prompt": "dance monkey",
+ "negative_prompt": "negative",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 16,
+ "width": 16,
+ "num_frames": num_frames,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames[0]
+ self.assertEqual(video.shape, (17, 3, 16, 16))
+
+ # fmt: off
+ expected_slice = [0.4523, 0.45198, 0.44872, 0.45326, 0.45211, 0.45258, 0.45344, 0.453, 0.52431, 0.52572, 0.50701, 0.5118, 0.53717, 0.53093, 0.50557, 0.51402]
+ # fmt: on
+
+ video_slice = video.flatten()
+ video_slice = torch.cat([video_slice[:8], video_slice[-8:]])
+ video_slice = [round(x, 5) for x in video_slice.tolist()]
+ self.assertTrue(np.allclose(video_slice, expected_slice, atol=1e-3))
+
+ def test_inference_with_single_reference_image(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ inputs["reference_images"] = Image.new("RGB", (16, 16))
+ video = pipe(**inputs).frames[0]
+ self.assertEqual(video.shape, (17, 3, 16, 16))
+
+ # fmt: off
+ expected_slice = [0.45247, 0.45214, 0.44874, 0.45314, 0.45171, 0.45299, 0.45428, 0.45317, 0.51378, 0.52658, 0.53361, 0.52303, 0.46204, 0.50435, 0.52555, 0.51342]
+ # fmt: on
+
+ video_slice = video.flatten()
+ video_slice = torch.cat([video_slice[:8], video_slice[-8:]])
+ video_slice = [round(x, 5) for x in video_slice.tolist()]
+ self.assertTrue(np.allclose(video_slice, expected_slice, atol=1e-3))
+
+ def test_inference_with_multiple_reference_image(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ inputs["reference_images"] = [[Image.new("RGB", (16, 16))] * 2]
+ video = pipe(**inputs).frames[0]
+ self.assertEqual(video.shape, (17, 3, 16, 16))
+
+ # fmt: off
+ expected_slice = [0.45321, 0.45221, 0.44818, 0.45375, 0.45268, 0.4519, 0.45271, 0.45253, 0.51244, 0.52223, 0.51253, 0.51321, 0.50743, 0.51177, 0.51626, 0.50983]
+ # fmt: on
+
+ video_slice = video.flatten()
+ video_slice = torch.cat([video_slice[:8], video_slice[-8:]])
+ video_slice = [round(x, 5) for x in video_slice.tolist()]
+ self.assertTrue(np.allclose(video_slice, expected_slice, atol=1e-3))
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip("Errors out because passing multiple prompts at once is not yet supported by this pipeline.")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
+ @unittest.skip("Batching is not yet supported with this pipeline")
+ def test_inference_batch_consistent(self):
+ pass
+
+ @unittest.skip("Batching is not yet supported with this pipeline")
+ def test_inference_batch_single_identical(self):
+ return super().test_inference_batch_single_identical()
+
+ @unittest.skip(
+ "AutoencoderKLWan encoded latents are always in FP32. This test is not designed to handle mixed dtype inputs"
+ )
+ def test_float16_inference(self):
+ pass
+
+ @unittest.skip(
+ "AutoencoderKLWan encoded latents are always in FP32. This test is not designed to handle mixed dtype inputs"
+ )
+ def test_save_load_float16(self):
+ pass
+
+ def test_inference_with_only_transformer(self):
+ components = self.get_dummy_components()
+ components["transformer_2"] = None
+ components["boundary_ratio"] = 0.0
+ pipe = self.pipeline_class(**components)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ video = pipe(**inputs).frames[0]
+ assert video.shape == (17, 3, 16, 16)
+
+ def test_inference_with_only_transformer_2(self):
+ components = self.get_dummy_components()
+ components["transformer_2"] = components["transformer"]
+ components["transformer"] = None
+
+ # FlowMatchEulerDiscreteScheduler doesn't support running low noise only scheduler
+ # because starting timestep t == 1000 == boundary_timestep
+ components["scheduler"] = UniPCMultistepScheduler(
+ prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
+ )
+
+ components["boundary_ratio"] = 1.0
+ pipe = self.pipeline_class(**components)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ video = pipe(**inputs).frames[0]
+ assert video.shape == (17, 3, 16, 16)
+
+ def test_save_load_optional_components(self, expected_max_difference=1e-4):
+ optional_component = ["transformer"]
+
+ components = self.get_dummy_components()
+ components["transformer_2"] = components["transformer"]
+ # FlowMatchEulerDiscreteScheduler doesn't support running low noise only scheduler
+ # because starting timestep t == 1000 == boundary_timestep
+ components["scheduler"] = UniPCMultistepScheduler(
+ prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
+ )
+ for component in optional_component:
+ components[component] = None
+
+ components["boundary_ratio"] = 1.0
+
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, safe_serialization=False)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ for component in optional_component:
+ assert getattr(pipe_loaded, component) is None, f"`{component}` did not stay set to None after loading."
+
+ inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+ assert max_diff < expected_max_difference, "Outputs exceed expecpted maximum difference"
diff --git a/tests/pipelines/wan/test_wan_video_to_video.py b/tests/pipelines/wan/test_wan_video_to_video.py
index 11c748424a30..27ada121ca48 100644
--- a/tests/pipelines/wan/test_wan_video_to_video.py
+++ b/tests/pipelines/wan/test_wan_video_to_video.py
@@ -14,16 +14,15 @@
import unittest
-import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanTransformer3DModel, WanVideoToVideoPipeline
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
@@ -123,11 +122,15 @@ def test_inference(self):
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
-
self.assertEqual(generated_video.shape, (17, 3, 16, 16))
- expected_video = torch.randn(17, 3, 16, 16)
- max_diff = np.abs(generated_video - expected_video).max()
- self.assertLessEqual(max_diff, 1e10)
+
+ # fmt: off
+ expected_slice = torch.tensor([0.4522, 0.4534, 0.4532, 0.4553, 0.4526, 0.4538, 0.4533, 0.4547, 0.513, 0.5176, 0.5286, 0.4958, 0.4955, 0.5381, 0.5154, 0.5195])
+ # fmt:on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
@unittest.skip("Test not supported")
def test_attention_slicing_forward_pass(self):
diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py
deleted file mode 100644
index 084d62a8c613..000000000000
--- a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py
+++ /dev/null
@@ -1,241 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import DDPMWuerstchenScheduler, WuerstchenCombinedPipeline
-from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_accelerator, torch_device
-
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = WuerstchenCombinedPipeline
- params = ["prompt"]
- batch_params = ["prompt", "negative_prompt"]
- required_optional_params = [
- "generator",
- "height",
- "width",
- "latents",
- "prior_guidance_scale",
- "decoder_guidance_scale",
- "negative_prompt",
- "num_inference_steps",
- "return_dict",
- "prior_num_inference_steps",
- "output_type",
- ]
- test_xformers_attention = True
-
- @property
- def text_embedder_hidden_size(self):
- return 32
-
- @property
- def dummy_prior(self):
- torch.manual_seed(0)
-
- model_kwargs = {"c_in": 2, "c": 8, "depth": 2, "c_cond": 32, "c_r": 8, "nhead": 2}
- model = WuerstchenPrior(**model_kwargs)
- return model.eval()
-
- @property
- def dummy_tokenizer(self):
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- return tokenizer
-
- @property
- def dummy_prior_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=self.text_embedder_hidden_size,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModel(config).eval()
-
- @property
- def dummy_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- projection_dim=self.text_embedder_hidden_size,
- hidden_size=self.text_embedder_hidden_size,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModel(config).eval()
-
- @property
- def dummy_vqgan(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "bottleneck_blocks": 1,
- "num_vq_embeddings": 2,
- }
- model = PaellaVQModel(**model_kwargs)
- return model.eval()
-
- @property
- def dummy_decoder(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "c_cond": self.text_embedder_hidden_size,
- "c_hidden": [320],
- "nhead": [-1],
- "blocks": [4],
- "level_config": ["CT"],
- "clip_embd": self.text_embedder_hidden_size,
- "inject_effnet": [False],
- }
-
- model = WuerstchenDiffNeXt(**model_kwargs)
- return model.eval()
-
- def get_dummy_components(self):
- prior = self.dummy_prior
- prior_text_encoder = self.dummy_prior_text_encoder
-
- scheduler = DDPMWuerstchenScheduler()
- tokenizer = self.dummy_tokenizer
-
- text_encoder = self.dummy_text_encoder
- decoder = self.dummy_decoder
- vqgan = self.dummy_vqgan
-
- components = {
- "tokenizer": tokenizer,
- "text_encoder": text_encoder,
- "decoder": decoder,
- "vqgan": vqgan,
- "scheduler": scheduler,
- "prior_prior": prior,
- "prior_text_encoder": prior_text_encoder,
- "prior_tokenizer": tokenizer,
- "prior_scheduler": scheduler,
- }
-
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "horse",
- "generator": generator,
- "prior_guidance_scale": 4.0,
- "decoder_guidance_scale": 4.0,
- "num_inference_steps": 2,
- "prior_num_inference_steps": 2,
- "output_type": "np",
- "height": 128,
- "width": 128,
- }
- return inputs
-
- def test_wuerstchen(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- output = pipe(**self.get_dummy_inputs(device))
- image = output.images
-
- image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[-3:, -3:, -1]
-
- assert image.shape == (1, 128, 128, 3)
-
- expected_slice = np.array([0.7616304, 0.0, 1.0, 0.0, 1.0, 0.0, 0.05925313, 0.0, 0.951898])
-
- assert (
- np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- assert (
- np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
- ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
-
- @require_torch_accelerator
- def test_offloads(self):
- pipes = []
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components).to(torch_device)
- pipes.append(sd_pipe)
-
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_sequential_cpu_offload(device=torch_device)
- pipes.append(sd_pipe)
-
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_model_cpu_offload(device=torch_device)
- pipes.append(sd_pipe)
-
- image_slices = []
- for pipe in pipes:
- inputs = self.get_dummy_inputs(torch_device)
- image = pipe(**inputs).images
-
- image_slices.append(image[0, -3:, -3:, -1].flatten())
-
- assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
- assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(expected_max_diff=1e-2)
-
- @unittest.skip(reason="flakey and float16 requires CUDA")
- def test_float16_inference(self):
- super().test_float16_inference()
-
- @unittest.skip(reason="Test not supported.")
- def test_callback_inputs(self):
- pass
-
- @unittest.skip(reason="Test not supported.")
- def test_callback_cfg(self):
- pass
diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py
deleted file mode 100644
index 97d1a1cc3830..000000000000
--- a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py
+++ /dev/null
@@ -1,192 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import DDPMWuerstchenScheduler, WuerstchenDecoderPipeline
-from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt
-from diffusers.utils.testing_utils import enable_full_determinism, skip_mps, torch_device
-
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class WuerstchenDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = WuerstchenDecoderPipeline
- params = ["prompt"]
- batch_params = ["image_embeddings", "prompt", "negative_prompt"]
- required_optional_params = [
- "num_images_per_prompt",
- "num_inference_steps",
- "latents",
- "negative_prompt",
- "guidance_scale",
- "output_type",
- "return_dict",
- ]
- test_xformers_attention = False
- callback_cfg_params = ["image_embeddings", "text_encoder_hidden_states"]
-
- @property
- def text_embedder_hidden_size(self):
- return 32
-
- @property
- def time_input_dim(self):
- return 32
-
- @property
- def block_out_channels_0(self):
- return self.time_input_dim
-
- @property
- def time_embed_dim(self):
- return self.time_input_dim * 4
-
- @property
- def dummy_tokenizer(self):
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- return tokenizer
-
- @property
- def dummy_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- projection_dim=self.text_embedder_hidden_size,
- hidden_size=self.text_embedder_hidden_size,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModel(config).eval()
-
- @property
- def dummy_vqgan(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "bottleneck_blocks": 1,
- "num_vq_embeddings": 2,
- }
- model = PaellaVQModel(**model_kwargs)
- return model.eval()
-
- @property
- def dummy_decoder(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "c_cond": self.text_embedder_hidden_size,
- "c_hidden": [320],
- "nhead": [-1],
- "blocks": [4],
- "level_config": ["CT"],
- "clip_embd": self.text_embedder_hidden_size,
- "inject_effnet": [False],
- }
-
- model = WuerstchenDiffNeXt(**model_kwargs)
- return model.eval()
-
- def get_dummy_components(self):
- decoder = self.dummy_decoder
- text_encoder = self.dummy_text_encoder
- tokenizer = self.dummy_tokenizer
- vqgan = self.dummy_vqgan
-
- scheduler = DDPMWuerstchenScheduler()
-
- components = {
- "decoder": decoder,
- "vqgan": vqgan,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "scheduler": scheduler,
- "latent_dim_scale": 4.0,
- }
-
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "image_embeddings": torch.ones((1, 4, 4, 4), device=device),
- "prompt": "horse",
- "generator": generator,
- "guidance_scale": 1.0,
- "num_inference_steps": 2,
- "output_type": "np",
- }
- return inputs
-
- def test_wuerstchen_decoder(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- output = pipe(**self.get_dummy_inputs(device))
- image = output.images
-
- image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.0000, 0.0000, 0.0089, 1.0000, 1.0000, 0.3927, 1.0000, 1.0000, 1.0000])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- @skip_mps
- def test_inference_batch_single_identical(self):
- self._test_inference_batch_single_identical(expected_max_diff=1e-5)
-
- @skip_mps
- def test_attention_slicing_forward_pass(self):
- test_max_difference = torch_device == "cpu"
- test_mean_pixel_difference = False
-
- self._test_attention_slicing_forward_pass(
- test_max_difference=test_max_difference,
- test_mean_pixel_difference=test_mean_pixel_difference,
- )
-
- @unittest.skip(reason="bf16 not supported and requires CUDA")
- def test_float16_inference(self):
- super().test_float16_inference()
-
- @unittest.skip("Test not supoorted.")
- def test_encode_prompt_works_in_isolation(self):
- super().test_encode_prompt_works_in_isolation()
diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py
deleted file mode 100644
index 4bc086e7f65b..000000000000
--- a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py
+++ /dev/null
@@ -1,273 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import DDPMWuerstchenScheduler, WuerstchenPriorPipeline
-from diffusers.pipelines.wuerstchen import WuerstchenPrior
-from diffusers.utils.import_utils import is_peft_available
-from diffusers.utils.testing_utils import enable_full_determinism, require_peft_backend, skip_mps, torch_device
-
-
-if is_peft_available():
- from peft import LoraConfig
- from peft.tuners.tuners_utils import BaseTunerLayer
-
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = WuerstchenPriorPipeline
- params = ["prompt"]
- batch_params = ["prompt", "negative_prompt"]
- required_optional_params = [
- "num_images_per_prompt",
- "generator",
- "num_inference_steps",
- "latents",
- "negative_prompt",
- "guidance_scale",
- "output_type",
- "return_dict",
- ]
- test_xformers_attention = False
- callback_cfg_params = ["text_encoder_hidden_states"]
-
- @property
- def text_embedder_hidden_size(self):
- return 32
-
- @property
- def time_input_dim(self):
- return 32
-
- @property
- def block_out_channels_0(self):
- return self.time_input_dim
-
- @property
- def time_embed_dim(self):
- return self.time_input_dim * 4
-
- @property
- def dummy_tokenizer(self):
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- return tokenizer
-
- @property
- def dummy_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=self.text_embedder_hidden_size,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModel(config).eval()
-
- @property
- def dummy_prior(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "c_in": 2,
- "c": 8,
- "depth": 2,
- "c_cond": 32,
- "c_r": 8,
- "nhead": 2,
- }
-
- model = WuerstchenPrior(**model_kwargs)
- return model.eval()
-
- def get_dummy_components(self):
- prior = self.dummy_prior
- text_encoder = self.dummy_text_encoder
- tokenizer = self.dummy_tokenizer
-
- scheduler = DDPMWuerstchenScheduler()
-
- components = {
- "prior": prior,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "scheduler": scheduler,
- }
-
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "horse",
- "generator": generator,
- "guidance_scale": 4.0,
- "num_inference_steps": 2,
- "output_type": "np",
- }
- return inputs
-
- def test_wuerstchen_prior(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- output = pipe(**self.get_dummy_inputs(device))
- image = output.image_embeddings
-
- image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0]
-
- image_slice = image[0, 0, 0, -10:]
- image_from_tuple_slice = image_from_tuple[0, 0, 0, -10:]
- assert image.shape == (1, 2, 24, 24)
-
- expected_slice = np.array(
- [
- -7172.837,
- -3438.855,
- -1093.312,
- 388.8835,
- -7471.467,
- -7998.1206,
- -5328.259,
- 218.00089,
- -2731.5745,
- -8056.734,
- ]
- )
- assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 5e-2
-
- @skip_mps
- def test_inference_batch_single_identical(self):
- self._test_inference_batch_single_identical(
- expected_max_diff=3e-1,
- )
-
- @skip_mps
- def test_attention_slicing_forward_pass(self):
- test_max_difference = torch_device == "cpu"
- test_mean_pixel_difference = False
-
- self._test_attention_slicing_forward_pass(
- test_max_difference=test_max_difference,
- test_mean_pixel_difference=test_mean_pixel_difference,
- )
-
- @unittest.skip(reason="flaky for now")
- def test_float16_inference(self):
- super().test_float16_inference()
-
- # override because we need to make sure latent_mean and latent_std to be 0
- def test_callback_inputs(self):
- components = self.get_dummy_components()
- components["latent_mean"] = 0
- components["latent_std"] = 0
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- self.assertTrue(
- hasattr(pipe, "_callback_tensor_inputs"),
- f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
- )
-
- def callback_inputs_test(pipe, i, t, callback_kwargs):
- missing_callback_inputs = set()
- for v in pipe._callback_tensor_inputs:
- if v not in callback_kwargs:
- missing_callback_inputs.add(v)
- self.assertTrue(
- len(missing_callback_inputs) == 0, f"Missing callback tensor inputs: {missing_callback_inputs}"
- )
- last_i = pipe.num_timesteps - 1
- if i == last_i:
- callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
- return callback_kwargs
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["callback_on_step_end"] = callback_inputs_test
- inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
- inputs["output_type"] = "latent"
-
- output = pipe(**inputs)[0]
- assert output.abs().sum() == 0
-
- def check_if_lora_correctly_set(self, model) -> bool:
- """
- Checks if the LoRA layers are correctly set with peft
- """
- for module in model.modules():
- if isinstance(module, BaseTunerLayer):
- return True
- return False
-
- def get_lora_components(self):
- prior = self.dummy_prior
-
- prior_lora_config = LoraConfig(
- r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
- )
-
- return prior, prior_lora_config
-
- @require_peft_backend
- def test_inference_with_prior_lora(self):
- _, prior_lora_config = self.get_lora_components()
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- output_no_lora = pipe(**self.get_dummy_inputs(device))
- image_embed = output_no_lora.image_embeddings
- self.assertTrue(image_embed.shape == (1, 2, 24, 24))
-
- pipe.prior.add_adapter(prior_lora_config)
- self.assertTrue(self.check_if_lora_correctly_set(pipe.prior), "Lora not correctly set in prior")
-
- output_lora = pipe(**self.get_dummy_inputs(device))
- lora_image_embed = output_lora.image_embeddings
-
- self.assertTrue(image_embed.shape == lora_image_embed.shape)
-
- @unittest.skip("Test not supported as dtype cannot be inferred without the text encoder otherwise.")
- def test_encode_prompt_works_in_isolation(self):
- pass
diff --git a/tests/pipelines/z_image/__init__.py b/tests/pipelines/z_image/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/z_image/test_z_image.py b/tests/pipelines/z_image/test_z_image.py
new file mode 100644
index 000000000000..79a5fa0de5f0
--- /dev/null
+++ b/tests/pipelines/z_image/test_z_image.py
@@ -0,0 +1,307 @@
+# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import os
+import unittest
+
+import numpy as np
+import torch
+from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model
+
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImagePipeline, ZImageTransformer2DModel
+
+from ...testing_utils import torch_device
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations
+# Cannot use enable_full_determinism() which sets it to True
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+torch.use_deterministic_algorithms(False)
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+if hasattr(torch.backends, "cuda"):
+ torch.backends.cuda.matmul.allow_tf32 = False
+
+# Note: Some tests (test_float16_inference, test_save_load_float16) may fail in full suite
+# due to RopeEmbedder cache state pollution between tests. They pass when run individually.
+# This is a known test isolation issue, not a functional bug.
+
+
+class ZImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = ZImagePipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ supports_dduf = False
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def setUp(self):
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+ torch.manual_seed(0)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(0)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+ torch.manual_seed(0)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(0)
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = ZImageTransformer2DModel(
+ all_patch_size=(2,),
+ all_f_patch_size=(1,),
+ in_channels=16,
+ dim=32,
+ n_layers=2,
+ n_refiner_layers=1,
+ n_heads=2,
+ n_kv_heads=2,
+ norm_eps=1e-5,
+ qk_norm=True,
+ cap_feat_dim=16,
+ rope_theta=256.0,
+ t_scale=1000.0,
+ axes_dims=[8, 4, 4],
+ axes_lens=[256, 32, 32],
+ )
+ # `x_pad_token` and `cap_pad_token` are initialized with `torch.empty`.
+ # This can cause NaN data values in our testing environment. Fixating them
+ # helps prevent that issue.
+ with torch.no_grad():
+ transformer.x_pad_token.copy_(torch.ones_like(transformer.x_pad_token.data))
+ transformer.cap_pad_token.copy_(torch.ones_like(transformer.cap_pad_token.data))
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ block_out_channels=[32, 64],
+ layers_per_block=1,
+ latent_channels=16,
+ norm_num_groups=32,
+ sample_size=32,
+ scaling_factor=0.3611,
+ shift_factor=0.1159,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ torch.manual_seed(0)
+ config = Qwen3Config(
+ hidden_size=16,
+ intermediate_size=16,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ num_key_value_heads=2,
+ vocab_size=151936,
+ max_position_embeddings=512,
+ )
+ text_encoder = Qwen3Model(config)
+ tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "bad quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 3.0,
+ "cfg_normalization": False,
+ "cfg_truncation": 1.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+
+ # fmt: off
+ expected_slice = torch.tensor([0.4622, 0.4532, 0.4714, 0.5087, 0.5371, 0.5405, 0.4492, 0.4479, 0.2984, 0.2783, 0.5409, 0.6577, 0.3952, 0.5524, 0.5262, 0.453])
+ # fmt: on
+
+ generated_slice = generated_image.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=5e-2))
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
+
+ def test_num_images_per_prompt(self):
+ import inspect
+
+ sig = inspect.signature(self.pipeline_class.__call__)
+
+ if "num_images_per_prompt" not in sig.parameters:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ batch_sizes = [1, 2]
+ num_images_per_prompts = [1, 2]
+
+ for batch_size in batch_sizes:
+ for num_images_per_prompt in num_images_per_prompts:
+ inputs = self.get_dummy_inputs(torch_device)
+
+ for key in inputs.keys():
+ if key in self.batch_params:
+ inputs[key] = batch_size * [inputs[key]]
+
+ images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0]
+
+ assert images.shape[0] == batch_size * num_images_per_prompt
+
+ del pipe
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling (standard AutoencoderKL doesn't accept parameters)
+ pipe.vae.enable_tiling()
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ def test_pipeline_with_accelerator_device_map(self, expected_max_difference=5e-4):
+ # Z-Image RoPE embeddings (complex64) have slightly higher numerical tolerance
+ super().test_pipeline_with_accelerator_device_map(expected_max_difference=expected_max_difference)
+
+ def test_group_offloading_inference(self):
+ # Block-level offloading conflicts with RoPE cache. Pipeline-level offloading (tested separately) works fine.
+ self.skipTest("Using test_pipeline_level_group_offloading_inference instead")
+
+ def test_save_load_float16(self, expected_max_diff=1e-2):
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+ torch.manual_seed(0)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(0)
+ super().test_save_load_float16(expected_max_diff=expected_max_diff)
diff --git a/tests/pipelines/z_image/test_z_image_img2img.py b/tests/pipelines/z_image/test_z_image_img2img.py
new file mode 100644
index 000000000000..91b3025b17e8
--- /dev/null
+++ b/tests/pipelines/z_image/test_z_image_img2img.py
@@ -0,0 +1,358 @@
+# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import os
+import unittest
+
+import numpy as np
+import torch
+from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model
+
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ ZImageImg2ImgPipeline,
+ ZImageTransformer2DModel,
+)
+from diffusers.utils.testing_utils import floats_tensor
+
+from ...testing_utils import torch_device
+from ..pipeline_params import (
+ IMAGE_TO_IMAGE_IMAGE_PARAMS,
+ TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
+ TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
+)
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations
+# Cannot use enable_full_determinism() which sets it to True
+# Note: Z-Image does not support FP16 inference due to complex64 RoPE embeddings
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+torch.use_deterministic_algorithms(False)
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+if hasattr(torch.backends, "cuda"):
+ torch.backends.cuda.matmul.allow_tf32 = False
+
+
+class ZImageImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = ZImageImg2ImgPipeline
+ params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
+ image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "strength",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ supports_dduf = False
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def setUp(self):
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+ torch.manual_seed(0)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(0)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+ torch.manual_seed(0)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(0)
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = ZImageTransformer2DModel(
+ all_patch_size=(2,),
+ all_f_patch_size=(1,),
+ in_channels=16,
+ dim=32,
+ n_layers=2,
+ n_refiner_layers=1,
+ n_heads=2,
+ n_kv_heads=2,
+ norm_eps=1e-5,
+ qk_norm=True,
+ cap_feat_dim=16,
+ rope_theta=256.0,
+ t_scale=1000.0,
+ axes_dims=[8, 4, 4],
+ axes_lens=[256, 32, 32],
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ block_out_channels=[32, 64],
+ layers_per_block=1,
+ latent_channels=16,
+ norm_num_groups=32,
+ sample_size=32,
+ scaling_factor=0.3611,
+ shift_factor=0.1159,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ torch.manual_seed(0)
+ config = Qwen3Config(
+ hidden_size=16,
+ intermediate_size=16,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ num_key_value_heads=2,
+ vocab_size=151936,
+ max_position_embeddings=512,
+ )
+ text_encoder = Qwen3Model(config)
+ tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ import random
+
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "bad quality",
+ "image": image,
+ "strength": 0.6,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 3.0,
+ "cfg_normalization": False,
+ "cfg_truncation": 1.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 16,
+ "output_type": "np",
+ }
+
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (32, 32, 3))
+
+ def test_inference_batch_single_identical(self):
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+ torch.manual_seed(0)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(0)
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
+
+ def test_num_images_per_prompt(self):
+ import inspect
+
+ sig = inspect.signature(self.pipeline_class.__call__)
+
+ if "num_images_per_prompt" not in sig.parameters:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ batch_sizes = [1, 2]
+ num_images_per_prompts = [1, 2]
+
+ for batch_size in batch_sizes:
+ for num_images_per_prompt in num_images_per_prompts:
+ inputs = self.get_dummy_inputs(torch_device)
+
+ for key in inputs.keys():
+ if key in self.batch_params:
+ inputs[key] = batch_size * [inputs[key]]
+
+ images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0]
+
+ assert images.shape[0] == batch_size * num_images_per_prompt
+
+ del pipe
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.3):
+ import random
+
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ # Generate a larger image for the input
+ inputs["image"] = floats_tensor((1, 3, 128, 128), rng=random.Random(0)).to("cpu")
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling (standard AutoencoderKL doesn't accept parameters)
+ pipe.vae.enable_tiling()
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ inputs["image"] = floats_tensor((1, 3, 128, 128), rng=random.Random(0)).to("cpu")
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ def test_pipeline_with_accelerator_device_map(self, expected_max_difference=5e-4):
+ # Z-Image RoPE embeddings (complex64) have slightly higher numerical tolerance
+ super().test_pipeline_with_accelerator_device_map(expected_max_difference=expected_max_difference)
+
+ def test_group_offloading_inference(self):
+ # Block-level offloading conflicts with RoPE cache. Pipeline-level offloading (tested separately) works fine.
+ self.skipTest("Using test_pipeline_level_group_offloading_inference instead")
+
+ def test_save_load_float16(self, expected_max_diff=1e-2):
+ # Z-Image does not support FP16 due to complex64 RoPE embeddings
+ self.skipTest("Z-Image does not support FP16 inference")
+
+ def test_float16_inference(self, expected_max_diff=5e-2):
+ # Z-Image does not support FP16 due to complex64 RoPE embeddings
+ self.skipTest("Z-Image does not support FP16 inference")
+
+ def test_strength_parameter(self):
+ """Test that strength parameter affects the output correctly."""
+ device = "cpu"
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ # Test with different strength values
+ inputs_low_strength = self.get_dummy_inputs(device)
+ inputs_low_strength["strength"] = 0.2
+
+ inputs_high_strength = self.get_dummy_inputs(device)
+ inputs_high_strength["strength"] = 0.8
+
+ # Both should complete without errors
+ output_low = pipe(**inputs_low_strength).images[0]
+ output_high = pipe(**inputs_high_strength).images[0]
+
+ # Outputs should be different (different amount of transformation)
+ self.assertFalse(np.allclose(output_low, output_high, atol=1e-3))
+
+ def test_invalid_strength(self):
+ """Test that invalid strength values raise appropriate errors."""
+ device = "cpu"
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+
+ inputs = self.get_dummy_inputs(device)
+
+ # Test strength < 0
+ inputs["strength"] = -0.1
+ with self.assertRaises(ValueError):
+ pipe(**inputs)
+
+ # Test strength > 1
+ inputs["strength"] = 1.5
+ with self.assertRaises(ValueError):
+ pipe(**inputs)
diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index 29a3e212c48d..c1da8f1ece78 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Team Inc.
+# Copyright 2025 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,10 +21,19 @@
import pytest
import safetensors.torch
from huggingface_hub import hf_hub_download
-
-from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
+from PIL import Image
+
+from diffusers import (
+ BitsAndBytesConfig,
+ DiffusionPipeline,
+ FluxControlPipeline,
+ FluxTransformer2DModel,
+ SD3Transformer2DModel,
+)
+from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import is_accelerate_version, logging
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
CaptureLogger,
backend_empty_cache,
is_bitsandbytes_available,
@@ -37,10 +46,12 @@
require_peft_backend,
require_torch,
require_torch_accelerator,
+ require_torch_version_greater,
require_transformers_version_greater,
slow,
torch_device,
)
+from ..test_torch_compile_utils import QuantCompileTests
def get_some_linear_layer(model):
@@ -63,6 +74,8 @@ def get_some_linear_layer(model):
if is_bitsandbytes_available():
import bitsandbytes as bnb
+ from diffusers.quantizers.bitsandbytes.utils import replace_with_bnb_linear
+
@require_bitsandbytes_version_greater("0.43.2")
@require_accelerate
@@ -83,6 +96,17 @@ class Base4bitTests(unittest.TestCase):
num_inference_steps = 10
seed = 0
+ @classmethod
+ def setUpClass(cls):
+ cls.is_deterministic_enabled = torch.are_deterministic_algorithms_enabled()
+ if not cls.is_deterministic_enabled:
+ torch.use_deterministic_algorithms(True)
+
+ @classmethod
+ def tearDownClass(cls):
+ if not cls.is_deterministic_enabled:
+ torch.use_deterministic_algorithms(False)
+
def get_dummy_inputs(self):
prompt_embeds = load_pt(
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
@@ -196,7 +220,7 @@ def test_model_memory_usage(self):
def test_original_dtype(self):
r"""
- A simple test to check if the model succesfully stores the original dtype
+ A simple test to check if the model successfully stores the original dtype
"""
self.assertTrue("_pre_quantization_dtype" in self.model_4bit.config)
self.assertFalse("_pre_quantization_dtype" in self.model_fp16.config)
@@ -364,11 +388,23 @@ def test_bnb_4bit_errors_loading_incorrect_state_dict(self):
assert key_to_target in str(err_context.exception)
+ def test_bnb_4bit_logs_warning_for_no_quantization(self):
+ model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
+ quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+ logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
+ logger.setLevel(30)
+ with CaptureLogger(logger) as cap_logger:
+ _ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
+ assert (
+ "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+ in cap_logger.out
+ )
+
class BnB4BitTrainingTests(Base4bitTests):
def setUp(self):
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
@@ -505,7 +541,7 @@ def test_moving_to_cpu_throws_warning(self):
reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
strict=True,
)
- def test_pipeline_device_placement_works_with_nf4(self):
+ def test_pipeline_cuda_placement_works_with_nf4(self):
transformer_nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
@@ -539,7 +575,7 @@ def test_pipeline_device_placement_works_with_nf4(self):
).to(torch_device)
# Check if inference works.
- _ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2)
+ _ = pipeline_4bit(self.prompt, max_sequence_length=20, num_inference_steps=2)
del pipeline_4bit
@@ -636,7 +672,7 @@ def get_dummy_tensor_inputs(device=None, seed: int = 0):
class SlowBnb4BitFluxTests(Base4bitTests):
def setUp(self) -> None:
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
model_id = "hf-internal-testing/flux.1-dev-nf4-pkg"
t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
@@ -653,7 +689,7 @@ def tearDown(self):
del self.pipeline_4bit
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def test_quality(self):
# keep the resolution and max tokens to a lower number for faster execution.
@@ -696,6 +732,42 @@ def test_lora_loading(self):
self.assertTrue(max_diff < 1e-3)
+@require_transformers_version_greater("4.44.0")
+@require_peft_backend
+class SlowBnb4BitFluxControlWithLoraTests(Base4bitTests):
+ def setUp(self) -> None:
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ self.pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16)
+ self.pipeline_4bit.enable_model_cpu_offload()
+
+ def tearDown(self):
+ del self.pipeline_4bit
+
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def test_lora_loading(self):
+ self.pipeline_4bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
+
+ output = self.pipeline_4bit(
+ prompt=self.prompt,
+ control_image=Image.new(mode="RGB", size=(256, 256)),
+ height=256,
+ width=256,
+ max_sequence_length=64,
+ output_type="np",
+ num_inference_steps=8,
+ generator=torch.Generator().manual_seed(42),
+ ).images
+ out_slice = output[0, -3:, -3:, -1].flatten()
+ expected_slice = np.array([0.1636, 0.1675, 0.1982, 0.1743, 0.1809, 0.1936, 0.1743, 0.2095, 0.2139])
+
+ max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+ self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")
+
+
@slow
class BaseBnb4BitSerializationTests(Base4bitTests):
def tearDown(self):
@@ -797,3 +869,27 @@ def test_fp4_double_unsafe(self):
def test_fp4_double_safe(self):
self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True)
+
+
+@require_torch_version_greater("2.7.1")
+@require_bitsandbytes_version_greater("0.45.5")
+class Bnb4BitCompileTests(QuantCompileTests, unittest.TestCase):
+ @property
+ def quantization_config(self):
+ return PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_4bit",
+ quant_kwargs={
+ "load_in_4bit": True,
+ "bnb_4bit_quant_type": "nf4",
+ "bnb_4bit_compute_dtype": torch.bfloat16,
+ },
+ components_to_quantize=["transformer", "text_encoder_2"],
+ )
+
+ @require_bitsandbytes_version_greater("0.46.1")
+ def test_torch_compile(self):
+ torch._dynamo.config.capture_dynamic_output_shape_ops = True
+ super().test_torch_compile()
+
+ def test_torch_compile_with_group_offload_leaf(self):
+ super()._test_torch_compile_with_group_offload_leaf(use_stream=True)
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index 8809bac25f58..fde3966dec97 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Team Inc.
+# Copyright 2025 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,17 +19,21 @@
import numpy as np
import pytest
from huggingface_hub import hf_hub_download
+from PIL import Image
from diffusers import (
BitsAndBytesConfig,
DiffusionPipeline,
+ FluxControlPipeline,
FluxTransformer2DModel,
SanaTransformer2DModel,
SD3Transformer2DModel,
logging,
)
+from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import is_accelerate_version
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
CaptureLogger,
backend_empty_cache,
is_bitsandbytes_available,
@@ -39,13 +43,16 @@
numpy_cosine_similarity_distance,
require_accelerate,
require_bitsandbytes_version_greater,
+ require_peft_backend,
require_peft_version_greater,
require_torch,
require_torch_accelerator,
+ require_torch_version_greater_equal,
require_transformers_version_greater,
slow,
torch_device,
)
+from ..test_torch_compile_utils import QuantCompileTests
def get_some_linear_layer(model):
@@ -68,6 +75,8 @@ def get_some_linear_layer(model):
if is_bitsandbytes_available():
import bitsandbytes as bnb
+ from diffusers.quantizers.bitsandbytes import replace_with_bnb_linear
+
@require_bitsandbytes_version_greater("0.43.2")
@require_accelerate
@@ -88,6 +97,17 @@ class Base8bitTests(unittest.TestCase):
num_inference_steps = 10
seed = 0
+ @classmethod
+ def setUpClass(cls):
+ cls.is_deterministic_enabled = torch.are_deterministic_algorithms_enabled()
+ if not cls.is_deterministic_enabled:
+ torch.use_deterministic_algorithms(True)
+
+ @classmethod
+ def tearDownClass(cls):
+ if not cls.is_deterministic_enabled:
+ torch.use_deterministic_algorithms(False)
+
def get_dummy_inputs(self):
prompt_embeds = load_pt(
"https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
@@ -193,7 +213,7 @@ def test_model_memory_usage(self):
def test_original_dtype(self):
r"""
- A simple test to check if the model succesfully stores the original dtype
+ A simple test to check if the model successfully stores the original dtype
"""
self.assertTrue("_pre_quantization_dtype" in self.model_8bit.config)
self.assertFalse("_pre_quantization_dtype" in self.model_fp16.config)
@@ -317,6 +337,18 @@ def test_device_and_dtype_assignment(self):
# Check that this does not throw an error
_ = self.model_fp16.to(torch_device)
+ def test_bnb_8bit_logs_warning_for_no_quantization(self):
+ model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU())
+ quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+ logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils")
+ logger.setLevel(30)
+ with CaptureLogger(logger) as cap_logger:
+ _ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config)
+ assert (
+ "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+ in cap_logger.out
+ )
+
class Bnb8bitDeviceTests(Base8bitTests):
def setUp(self) -> None:
@@ -379,7 +411,7 @@ def test_training(self):
model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
# Step 4: Check if the gradient is not None
- with torch.amp.autocast("cuda", dtype=torch.float16):
+ with torch.amp.autocast(torch_device, dtype=torch.float16):
out = self.model_8bit(**model_inputs)[0]
out.norm().backward()
@@ -478,7 +510,7 @@ def test_generate_quality_dequantize(self):
self.assertTrue(max_diff < 1e-2)
# 8bit models cannot be offloaded to CPU.
- self.assertTrue(self.pipeline_8bit.transformer.device.type == "cuda")
+ self.assertTrue(self.pipeline_8bit.transformer.device.type == torch_device)
# calling it again shouldn't be a problem
_ = self.pipeline_8bit(
prompt=self.prompt,
@@ -509,16 +541,18 @@ def test_pipeline_cuda_placement_works_with_mixed_int8(self):
torch_dtype=torch.float16,
device_map=torch_device,
)
+
# CUDA device placement works.
+ device = torch_device if torch_device != "rocm" else "cuda"
pipeline_8bit = DiffusionPipeline.from_pretrained(
self.model_name,
transformer=transformer_8bit,
text_encoder_3=text_encoder_3_8bit,
torch_dtype=torch.float16,
- ).to("cuda")
+ ).to(device)
# Check if inference works.
- _ = pipeline_8bit("table", max_sequence_length=20, num_inference_steps=2)
+ _ = pipeline_8bit(self.prompt, max_sequence_length=20, num_inference_steps=2)
del pipeline_8bit
@@ -680,6 +714,50 @@ def test_lora_loading(self):
self.assertTrue(max_diff < 1e-3)
+@require_transformers_version_greater("4.44.0")
+@require_peft_backend
+class SlowBnb4BitFluxControlWithLoraTests(Base8bitTests):
+ def setUp(self) -> None:
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ self.pipeline_8bit = FluxControlPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ quantization_config=PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_8bit",
+ quant_kwargs={"load_in_8bit": True},
+ components_to_quantize=["transformer", "text_encoder_2"],
+ ),
+ torch_dtype=torch.float16,
+ )
+ self.pipeline_8bit.enable_model_cpu_offload()
+
+ def tearDown(self):
+ del self.pipeline_8bit
+
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def test_lora_loading(self):
+ self.pipeline_8bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
+
+ output = self.pipeline_8bit(
+ prompt=self.prompt,
+ control_image=Image.new(mode="RGB", size=(256, 256)),
+ height=256,
+ width=256,
+ max_sequence_length=64,
+ output_type="np",
+ num_inference_steps=8,
+ generator=torch.Generator().manual_seed(42),
+ ).images
+ out_slice = output[0, -3:, -3:, -1].flatten()
+ expected_slice = np.array([0.2029, 0.2136, 0.2268, 0.1921, 0.1997, 0.2185, 0.2021, 0.2183, 0.2292])
+
+ max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+ self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")
+
+
@slow
class BaseBnb8bitSerializationTests(Base8bitTests):
def setUp(self):
@@ -756,3 +834,30 @@ def test_serialization_sharded(self):
out_0 = self.model_0(**inputs)[0]
out_1 = model_1(**inputs)[0]
self.assertTrue(torch.equal(out_0, out_1))
+
+
+@require_torch_version_greater_equal("2.6.0")
+@require_bitsandbytes_version_greater("0.45.5")
+class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
+ @property
+ def quantization_config(self):
+ return PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_8bit",
+ quant_kwargs={"load_in_8bit": True},
+ components_to_quantize=["transformer", "text_encoder_2"],
+ )
+
+ @pytest.mark.xfail(
+ reason="Test fails because of an offloading problem from Accelerate with confusion in hooks."
+ " Test passes without recompilation context manager. Refer to https://github.com/huggingface/diffusers/pull/12002/files#r2240462757 for details."
+ )
+ def test_torch_compile(self):
+ torch._dynamo.config.capture_dynamic_output_shape_ops = True
+ super()._test_torch_compile(torch_dtype=torch.float16)
+
+ def test_torch_compile_with_cpu_offload(self):
+ super()._test_torch_compile_with_cpu_offload(torch_dtype=torch.float16)
+
+ @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
+ def test_torch_compile_with_group_offload_leaf(self):
+ super()._test_torch_compile_with_group_offload_leaf(torch_dtype=torch.float16, use_stream=True)
diff --git a/tests/quantization/gguf/__init__.py b/tests/quantization/gguf/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py
index 5e3875c7c9cb..b42764be10d6 100644
--- a/tests/quantization/gguf/test_gguf.py
+++ b/tests/quantization/gguf/test_gguf.py
@@ -8,29 +8,104 @@
from diffusers import (
AuraFlowPipeline,
AuraFlowTransformer2DModel,
+ DiffusionPipeline,
+ FluxControlPipeline,
FluxPipeline,
FluxTransformer2DModel,
GGUFQuantizationConfig,
+ HiDreamImageTransformer2DModel,
SD3Transformer2DModel,
StableDiffusion3Pipeline,
+ WanAnimateTransformer3DModel,
+ WanTransformer3DModel,
+ WanVACETransformer3DModel,
)
-from diffusers.utils.testing_utils import (
+from diffusers.utils import load_image
+
+from ...testing_utils import (
+ Expectations,
+ backend_empty_cache,
+ backend_max_memory_allocated,
+ backend_reset_peak_memory_stats,
+ enable_full_determinism,
is_gguf_available,
nightly,
numpy_cosine_similarity_distance,
require_accelerate,
- require_big_gpu_with_torch_cuda,
+ require_accelerator,
+ require_big_accelerator,
require_gguf_version_greater_or_equal,
+ require_kernels_version_greater_or_equal,
+ require_peft_backend,
+ require_torch_version_greater,
torch_device,
)
+from ..test_torch_compile_utils import QuantCompileTests
if is_gguf_available():
+ import gguf
+
from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
+enable_full_determinism()
+
@nightly
-@require_big_gpu_with_torch_cuda
+@require_accelerate
+@require_accelerator
+@require_gguf_version_greater_or_equal("0.10.0")
+@require_kernels_version_greater_or_equal("0.9.0")
+class GGUFCudaKernelsTests(unittest.TestCase):
+ def setUp(self):
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def test_cuda_kernels_vs_native(self):
+ if torch_device != "cuda":
+ self.skipTest("CUDA kernels test requires CUDA device")
+
+ from diffusers.quantizers.gguf.utils import GGUFLinear, can_use_cuda_kernels
+
+ if not can_use_cuda_kernels:
+ self.skipTest("CUDA kernels not available (compute capability < 7 or kernels not installed)")
+
+ test_quant_types = ["Q4_0", "Q4_K"]
+ test_shape = (1, 64, 512) # batch, seq_len, hidden_dim
+ compute_dtype = torch.bfloat16
+
+ for quant_type in test_quant_types:
+ qtype = getattr(gguf.GGMLQuantizationType, quant_type)
+ in_features, out_features = 512, 512
+
+ torch.manual_seed(42)
+ float_weight = torch.randn(out_features, in_features, dtype=torch.float32)
+ quantized_data = gguf.quants.quantize(float_weight.numpy(), qtype)
+ weight_data = torch.from_numpy(quantized_data).to(device=torch_device)
+ weight = GGUFParameter(weight_data, quant_type=qtype)
+
+ x = torch.randn(test_shape, dtype=compute_dtype, device=torch_device)
+
+ linear = GGUFLinear(in_features, out_features, bias=True, compute_dtype=compute_dtype)
+ linear.weight = weight
+ linear.bias = nn.Parameter(torch.randn(out_features, dtype=compute_dtype))
+ linear = linear.to(torch_device)
+
+ with torch.no_grad():
+ output_native = linear.forward_native(x)
+ output_cuda = linear.forward_cuda(x)
+
+ assert torch.allclose(output_native, output_cuda, 1e-2), (
+ f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
+ )
+
+
+@nightly
+@require_big_accelerator
@require_accelerate
@require_gguf_version_greater_or_equal("0.10.0")
class GGUFSingleFileTesterMixin:
@@ -65,15 +140,15 @@ def test_gguf_memory_usage(self):
model = self.model_cls.from_single_file(
self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
)
- model.to("cuda")
+ model.to(torch_device)
assert (model.get_memory_footprint() / 1024**3) < self.expected_memory_use_in_gb
inputs = self.get_dummy_inputs()
- torch.cuda.reset_peak_memory_stats()
- torch.cuda.empty_cache()
+ backend_reset_peak_memory_stats(torch_device)
+ backend_empty_cache(torch_device)
with torch.no_grad():
model(**inputs)
- max_memory = torch.cuda.max_memory_allocated()
+ max_memory = backend_max_memory_allocated(torch_device)
assert (max_memory / 1024**3) < self.expected_memory_use_in_gb
def test_keep_modules_in_fp32(self):
@@ -103,7 +178,8 @@ def test_dtype_assignment(self):
with self.assertRaises(ValueError):
# Tries with a `device` and `dtype`
- model.to(device="cuda:0", dtype=torch.float16)
+ device_0 = f"{torch_device}:0"
+ model.to(device=device_0, dtype=torch.float16)
with self.assertRaises(ValueError):
# Tries with a cast
@@ -114,7 +190,7 @@ def test_dtype_assignment(self):
model.half()
# This should work
- model.to("cuda")
+ model.to(torch_device)
def test_dequantize_model(self):
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
@@ -137,17 +213,18 @@ def _check_for_gguf_linear(model):
class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
+ diffusers_ckpt_path = "https://huggingface.co/sayakpaul/flux-diffusers-gguf/blob/main/model-Q4_0.gguf"
torch_dtype = torch.bfloat16
model_cls = FluxTransformer2DModel
expected_memory_use_in_gb = 5
def setUp(self):
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_dummy_inputs(self):
return {
@@ -221,6 +298,16 @@ def test_pipeline_inference(self):
max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
assert max_diff < 1e-4
+ def test_loading_gguf_diffusers_format(self):
+ model = self.model_cls.from_single_file(
+ self.diffusers_ckpt_path,
+ subfolder="transformer",
+ quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+ config="black-forest-labs/FLUX.1-dev",
+ )
+ model.to(torch_device)
+ model(**self.get_dummy_inputs())
+
class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-large-gguf/blob/main/sd3.5_large-Q4_0.gguf"
@@ -230,11 +317,11 @@ class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase)
def setUp(self):
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_dummy_inputs(self):
return {
@@ -264,40 +351,79 @@ def test_pipeline_inference(self):
prompt = "a cat holding a sign that says hello"
output = pipe(
- prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np"
+ prompt=prompt,
+ num_inference_steps=2,
+ generator=torch.Generator("cpu").manual_seed(0),
+ output_type="np",
).images[0]
output_slice = output[:3, :3, :].flatten()
- expected_slice = np.array(
- [
- 0.17578125,
- 0.27539062,
- 0.27734375,
- 0.11914062,
- 0.26953125,
- 0.25390625,
- 0.109375,
- 0.25390625,
- 0.25,
- 0.15039062,
- 0.26171875,
- 0.28515625,
- 0.13671875,
- 0.27734375,
- 0.28515625,
- 0.12109375,
- 0.26757812,
- 0.265625,
- 0.16210938,
- 0.29882812,
- 0.28515625,
- 0.15625,
- 0.30664062,
- 0.27734375,
- 0.14648438,
- 0.29296875,
- 0.26953125,
- ]
+ expected_slices = Expectations(
+ {
+ ("xpu", 3): np.array(
+ [
+ 0.16796875,
+ 0.27929688,
+ 0.28320312,
+ 0.11328125,
+ 0.27539062,
+ 0.26171875,
+ 0.10742188,
+ 0.26367188,
+ 0.26171875,
+ 0.1484375,
+ 0.2734375,
+ 0.296875,
+ 0.13476562,
+ 0.2890625,
+ 0.30078125,
+ 0.1171875,
+ 0.28125,
+ 0.28125,
+ 0.16015625,
+ 0.31445312,
+ 0.30078125,
+ 0.15625,
+ 0.32421875,
+ 0.296875,
+ 0.14453125,
+ 0.30859375,
+ 0.2890625,
+ ]
+ ),
+ ("cuda", 7): np.array(
+ [
+ 0.17578125,
+ 0.27539062,
+ 0.27734375,
+ 0.11914062,
+ 0.26953125,
+ 0.25390625,
+ 0.109375,
+ 0.25390625,
+ 0.25,
+ 0.15039062,
+ 0.26171875,
+ 0.28515625,
+ 0.13671875,
+ 0.27734375,
+ 0.28515625,
+ 0.12109375,
+ 0.26757812,
+ 0.265625,
+ 0.16210938,
+ 0.29882812,
+ 0.28515625,
+ 0.15625,
+ 0.30664062,
+ 0.27734375,
+ 0.14648438,
+ 0.29296875,
+ 0.26953125,
+ ]
+ ),
+ }
)
+ expected_slice = expected_slices.get_expectation()
max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
assert max_diff < 1e-4
@@ -310,11 +436,11 @@ class SD35MediumGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase
def setUp(self):
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_dummy_inputs(self):
return {
@@ -390,11 +516,11 @@ class AuraFlowGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
def setUp(self):
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def tearDown(self):
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_dummy_inputs(self):
return {
@@ -456,3 +582,187 @@ def test_pipeline_inference(self):
)
max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
assert max_diff < 1e-4
+
+
+@require_peft_backend
+@nightly
+@require_big_accelerator
+@require_accelerate
+@require_gguf_version_greater_or_equal("0.10.0")
+class FluxControlLoRAGGUFTests(unittest.TestCase):
+ def test_lora_loading(self):
+ ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
+ transformer = FluxTransformer2DModel.from_single_file(
+ ckpt_path,
+ quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+ torch_dtype=torch.bfloat16,
+ )
+ pipe = FluxControlPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ transformer=transformer,
+ torch_dtype=torch.bfloat16,
+ ).to(torch_device)
+ pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
+
+ prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
+ control_image = load_image(
+ "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/control_image_robot_canny.png"
+ )
+
+ output = pipe(
+ prompt=prompt,
+ control_image=control_image,
+ height=256,
+ width=256,
+ num_inference_steps=10,
+ guidance_scale=30.0,
+ output_type="np",
+ generator=torch.manual_seed(0),
+ ).images
+
+ out_slice = output[0, -3:, -3:, -1].flatten()
+ expected_slice = np.array([0.8047, 0.8359, 0.8711, 0.6875, 0.7070, 0.7383, 0.5469, 0.5820, 0.6641])
+
+ max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+ self.assertTrue(max_diff < 1e-3)
+
+
+class HiDreamGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+ ckpt_path = "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = HiDreamImageTransformer2DModel
+ expected_memory_use_in_gb = 8
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 16, 128, 128), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states_t5": torch.randn(
+ (1, 128, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "encoder_hidden_states_llama3": torch.randn(
+ (32, 1, 128, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "pooled_embeds": torch.randn(
+ (1, 2048),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "timesteps": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ }
+
+
+class WanGGUFTexttoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+ ckpt_path = "https://huggingface.co/city96/Wan2.1-T2V-14B-gguf/blob/main/wan2.1-t2v-14b-Q3_K_S.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = WanTransformer3DModel
+ expected_memory_use_in_gb = 9
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ }
+
+
+class WanGGUFImagetoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+ ckpt_path = "https://huggingface.co/city96/Wan2.1-I2V-14B-480P-gguf/blob/main/wan2.1-i2v-14b-480p-Q3_K_S.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = WanTransformer3DModel
+ expected_memory_use_in_gb = 9
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 36, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "encoder_hidden_states_image": torch.randn(
+ (1, 257, 1280), generator=torch.Generator("cpu").manual_seed(0)
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ }
+
+
+class WanVACEGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+ ckpt_path = "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = WanVACETransformer3DModel
+ expected_memory_use_in_gb = 9
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "control_hidden_states": torch.randn(
+ (1, 96, 2, 64, 64),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "control_hidden_states_scale": torch.randn(
+ (8,),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ }
+
+
+class WanAnimateGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+ ckpt_path = "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q3_K_S.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = WanAnimateTransformer3DModel
+ expected_memory_use_in_gb = 9
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "control_hidden_states": torch.randn(
+ (1, 96, 2, 64, 64),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "control_hidden_states_scale": torch.randn(
+ (8,),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ }
+
+
+@require_torch_version_greater("2.7.1")
+class GGUFCompileTests(QuantCompileTests, unittest.TestCase):
+ torch_dtype = torch.bfloat16
+ gguf_ckpt = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
+
+ @property
+ def quantization_config(self):
+ return GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
+
+ def _init_pipeline(self, *args, **kwargs):
+ transformer = FluxTransformer2DModel.from_single_file(
+ self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype
+ )
+ pipe = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=self.torch_dtype
+ )
+ return pipe
diff --git a/tests/quantization/modelopt/__init__.py b/tests/quantization/modelopt/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/quantization/modelopt/test_modelopt.py b/tests/quantization/modelopt/test_modelopt.py
new file mode 100644
index 000000000000..6b0624a28083
--- /dev/null
+++ b/tests/quantization/modelopt/test_modelopt.py
@@ -0,0 +1,306 @@
+import gc
+import tempfile
+import unittest
+
+from diffusers import NVIDIAModelOptConfig, SD3Transformer2DModel, StableDiffusion3Pipeline
+from diffusers.utils import is_nvidia_modelopt_available, is_torch_available
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ backend_reset_peak_memory_stats,
+ enable_full_determinism,
+ nightly,
+ numpy_cosine_similarity_distance,
+ require_accelerate,
+ require_big_accelerator,
+ require_modelopt_version_greater_or_equal,
+ require_torch_cuda_compatibility,
+ torch_device,
+)
+
+
+if is_nvidia_modelopt_available():
+ import modelopt.torch.quantization as mtq
+
+if is_torch_available():
+ import torch
+
+ from ..utils import LoRALayer, get_memory_consumption_stat
+
+enable_full_determinism()
+
+
+@nightly
+@require_big_accelerator
+@require_accelerate
+@require_modelopt_version_greater_or_equal("0.33.1")
+class ModelOptBaseTesterMixin:
+ model_id = "hf-internal-testing/tiny-sd3-pipe"
+ model_cls = SD3Transformer2DModel
+ pipeline_cls = StableDiffusion3Pipeline
+ torch_dtype = torch.bfloat16
+ expected_memory_reduction = 0.0
+ keep_in_fp32_module = ""
+ modules_to_not_convert = ""
+ _test_torch_compile = False
+
+ def setUp(self):
+ backend_reset_peak_memory_stats(torch_device)
+ backend_empty_cache(torch_device)
+ gc.collect()
+
+ def tearDown(self):
+ backend_reset_peak_memory_stats(torch_device)
+ backend_empty_cache(torch_device)
+ gc.collect()
+
+ def get_dummy_init_kwargs(self):
+ return {"quant_type": "FP8"}
+
+ def get_dummy_model_init_kwargs(self):
+ return {
+ "pretrained_model_name_or_path": self.model_id,
+ "torch_dtype": self.torch_dtype,
+ "quantization_config": NVIDIAModelOptConfig(**self.get_dummy_init_kwargs()),
+ "subfolder": "transformer",
+ }
+
+ def test_modelopt_layers(self):
+ model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+ for name, module in model.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ assert mtq.utils.is_quantized(module)
+
+ def test_modelopt_memory_usage(self):
+ inputs = self.get_dummy_inputs()
+ inputs = {
+ k: v.to(device=torch_device, dtype=torch.bfloat16) for k, v in inputs.items() if not isinstance(v, bool)
+ }
+
+ unquantized_model = self.model_cls.from_pretrained(
+ self.model_id, torch_dtype=self.torch_dtype, subfolder="transformer"
+ )
+ unquantized_model.to(torch_device)
+ unquantized_model_memory = get_memory_consumption_stat(unquantized_model, inputs)
+
+ quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+ quantized_model.to(torch_device)
+ quantized_model_memory = get_memory_consumption_stat(quantized_model, inputs)
+
+ assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_reduction
+
+ def test_keep_modules_in_fp32(self):
+ _keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules
+ self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module
+
+ model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+ model.to(torch_device)
+
+ for name, module in model.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ if name in model._keep_in_fp32_modules:
+ assert module.weight.dtype == torch.float32
+ self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules
+
+ def test_modules_to_not_convert(self):
+ init_kwargs = self.get_dummy_model_init_kwargs()
+ quantization_config_kwargs = self.get_dummy_init_kwargs()
+ quantization_config_kwargs.update({"modules_to_not_convert": self.modules_to_not_convert})
+ quantization_config = NVIDIAModelOptConfig(**quantization_config_kwargs)
+ init_kwargs.update({"quantization_config": quantization_config})
+
+ model = self.model_cls.from_pretrained(**init_kwargs)
+ model.to(torch_device)
+
+ for name, module in model.named_modules():
+ if name in self.modules_to_not_convert:
+ assert not mtq.utils.is_quantized(module)
+
+ def test_dtype_assignment(self):
+ model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+
+ with self.assertRaises(ValueError):
+ model.to(torch.float16)
+
+ with self.assertRaises(ValueError):
+ device_0 = f"{torch_device}:0"
+ model.to(device=device_0, dtype=torch.float16)
+
+ with self.assertRaises(ValueError):
+ model.float()
+
+ with self.assertRaises(ValueError):
+ model.half()
+
+ model.to(torch_device)
+
+ def test_serialization(self):
+ model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+ inputs = self.get_dummy_inputs()
+
+ model.to(torch_device)
+ with torch.no_grad():
+ model_output = model(**inputs)
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model.save_pretrained(tmp_dir)
+ saved_model = self.model_cls.from_pretrained(
+ tmp_dir,
+ torch_dtype=torch.bfloat16,
+ )
+
+ saved_model.to(torch_device)
+ with torch.no_grad():
+ saved_model_output = saved_model(**inputs)
+
+ assert torch.allclose(model_output.sample, saved_model_output.sample, rtol=1e-5, atol=1e-5)
+
+ def test_torch_compile(self):
+ if not self._test_torch_compile:
+ return
+
+ model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+ compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True, dynamic=False)
+
+ model.to(torch_device)
+ with torch.no_grad():
+ model_output = model(**self.get_dummy_inputs()).sample
+
+ compiled_model.to(torch_device)
+ with torch.no_grad():
+ compiled_model_output = compiled_model(**self.get_dummy_inputs()).sample
+
+ model_output = model_output.detach().float().cpu().numpy()
+ compiled_model_output = compiled_model_output.detach().float().cpu().numpy()
+
+ max_diff = numpy_cosine_similarity_distance(model_output.flatten(), compiled_model_output.flatten())
+ assert max_diff < 1e-3
+
+ def test_device_map_error(self):
+ with self.assertRaises(ValueError):
+ _ = self.model_cls.from_pretrained(
+ **self.get_dummy_model_init_kwargs(),
+ device_map={0: "8GB", "cpu": "16GB"},
+ )
+
+ def get_dummy_inputs(self):
+ batch_size = 1
+ seq_len = 16
+ height = width = 32
+ num_latent_channels = 4
+ caption_channels = 8
+
+ torch.manual_seed(0)
+ hidden_states = torch.randn((batch_size, num_latent_channels, height, width)).to(
+ torch_device, dtype=torch.bfloat16
+ )
+ encoder_hidden_states = torch.randn((batch_size, seq_len, caption_channels)).to(
+ torch_device, dtype=torch.bfloat16
+ )
+ timestep = torch.tensor([1.0]).to(torch_device, dtype=torch.bfloat16).expand(batch_size)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "timestep": timestep,
+ }
+
+ def test_model_cpu_offload(self):
+ init_kwargs = self.get_dummy_init_kwargs()
+ transformer = self.model_cls.from_pretrained(
+ self.model_id,
+ quantization_config=NVIDIAModelOptConfig(**init_kwargs),
+ subfolder="transformer",
+ torch_dtype=torch.bfloat16,
+ )
+ pipe = self.pipeline_cls.from_pretrained(self.model_id, transformer=transformer, torch_dtype=torch.bfloat16)
+ pipe.enable_model_cpu_offload(device=torch_device)
+ _ = pipe("a cat holding a sign that says hello", num_inference_steps=2)
+
+ def test_training(self):
+ quantization_config = NVIDIAModelOptConfig(**self.get_dummy_init_kwargs())
+ quantized_model = self.model_cls.from_pretrained(
+ self.model_id,
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.bfloat16,
+ ).to(torch_device)
+
+ for param in quantized_model.parameters():
+ param.requires_grad = False
+ if param.ndim == 1:
+ param.data = param.data.to(torch.float32)
+
+ for _, module in quantized_model.named_modules():
+ if hasattr(module, "to_q"):
+ module.to_q = LoRALayer(module.to_q, rank=4)
+ if hasattr(module, "to_k"):
+ module.to_k = LoRALayer(module.to_k, rank=4)
+ if hasattr(module, "to_v"):
+ module.to_v = LoRALayer(module.to_v, rank=4)
+
+ with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16):
+ inputs = self.get_dummy_inputs()
+ output = quantized_model(**inputs)[0]
+ output.norm().backward()
+
+ for module in quantized_model.modules():
+ if isinstance(module, LoRALayer):
+ self.assertTrue(module.adapter[1].weight.grad is not None)
+
+
+class SanaTransformerFP8WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
+ expected_memory_reduction = 0.6
+
+ def get_dummy_init_kwargs(self):
+ return {"quant_type": "FP8"}
+
+
+class SanaTransformerINT8WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
+ expected_memory_reduction = 0.6
+ _test_torch_compile = True
+
+ def get_dummy_init_kwargs(self):
+ return {"quant_type": "INT8"}
+
+
+@require_torch_cuda_compatibility(8.0)
+class SanaTransformerINT4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
+ expected_memory_reduction = 0.55
+
+ def get_dummy_init_kwargs(self):
+ return {
+ "quant_type": "INT4",
+ "block_quantize": 128,
+ "channel_quantize": -1,
+ "disable_conv_quantization": True,
+ }
+
+
+@require_torch_cuda_compatibility(8.0)
+class SanaTransformerNF4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
+ expected_memory_reduction = 0.65
+
+ def get_dummy_init_kwargs(self):
+ return {
+ "quant_type": "NF4",
+ "block_quantize": 128,
+ "channel_quantize": -1,
+ "scale_block_quantize": 8,
+ "scale_channel_quantize": -1,
+ "modules_to_not_convert": ["conv"],
+ }
+
+
+@require_torch_cuda_compatibility(8.0)
+class SanaTransformerNVFP4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
+ expected_memory_reduction = 0.65
+
+ def get_dummy_init_kwargs(self):
+ return {
+ "quant_type": "NVFP4",
+ "block_quantize": 128,
+ "channel_quantize": -1,
+ "scale_block_quantize": 8,
+ "scale_channel_quantize": -1,
+ "modules_to_not_convert": ["conv"],
+ }
diff --git a/tests/quantization/quanto/test_quanto.py b/tests/quantization/quanto/test_quanto.py
index 9eb6958d2183..e3463f136f94 100644
--- a/tests/quantization/quanto/test_quanto.py
+++ b/tests/quantization/quanto/test_quanto.py
@@ -5,11 +5,15 @@
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
from diffusers.models.attention_processor import Attention
from diffusers.utils import is_optimum_quanto_available, is_torch_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ backend_empty_cache,
+ backend_reset_peak_memory_stats,
+ enable_full_determinism,
nightly,
numpy_cosine_similarity_distance,
require_accelerate,
- require_big_gpu_with_torch_cuda,
+ require_accelerator,
require_torch_cuda_compatibility,
torch_device,
)
@@ -23,9 +27,11 @@
from ..utils import LoRALayer, get_memory_consumption_stat
+enable_full_determinism()
+
@nightly
-@require_big_gpu_with_torch_cuda
+@require_accelerator
@require_accelerate
class QuantoBaseTesterMixin:
model_id = None
@@ -39,13 +45,13 @@ class QuantoBaseTesterMixin:
_test_torch_compile = False
def setUp(self):
- torch.cuda.reset_peak_memory_stats()
- torch.cuda.empty_cache()
+ backend_reset_peak_memory_stats(torch_device)
+ backend_empty_cache(torch_device)
gc.collect()
def tearDown(self):
- torch.cuda.reset_peak_memory_stats()
- torch.cuda.empty_cache()
+ backend_reset_peak_memory_stats(torch_device)
+ backend_empty_cache(torch_device)
gc.collect()
def get_dummy_init_kwargs(self):
@@ -89,7 +95,7 @@ def test_keep_modules_in_fp32(self):
self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module
model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
- model.to("cuda")
+ model.to(torch_device)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
@@ -107,7 +113,7 @@ def test_modules_to_not_convert(self):
init_kwargs.update({"quantization_config": quantization_config})
model = self.model_cls.from_pretrained(**init_kwargs)
- model.to("cuda")
+ model.to(torch_device)
for name, module in model.named_modules():
if name in self.modules_to_not_convert:
@@ -122,7 +128,8 @@ def test_dtype_assignment(self):
with self.assertRaises(ValueError):
# Tries with a `device` and `dtype`
- model.to(device="cuda:0", dtype=torch.float16)
+ device_0 = f"{torch_device}:0"
+ model.to(device=device_0, dtype=torch.float16)
with self.assertRaises(ValueError):
# Tries with a cast
@@ -133,7 +140,7 @@ def test_dtype_assignment(self):
model.half()
# This should work
- model.to("cuda")
+ model.to(torch_device)
def test_serialization(self):
model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
diff --git a/tests/quantization/test_pipeline_level_quantization.py b/tests/quantization/test_pipeline_level_quantization.py
new file mode 100644
index 000000000000..5f1a3de2e579
--- /dev/null
+++ b/tests/quantization/test_pipeline_level_quantization.py
@@ -0,0 +1,317 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Team Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a clone of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import tempfile
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from diffusers import BitsAndBytesConfig, DiffusionPipeline, QuantoConfig
+from diffusers.quantizers import PipelineQuantizationConfig
+from diffusers.utils import logging
+
+from ..testing_utils import (
+ CaptureLogger,
+ is_transformers_available,
+ require_accelerate,
+ require_bitsandbytes_version_greater,
+ require_quanto,
+ require_torch,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
+
+
+if is_transformers_available():
+ from transformers import BitsAndBytesConfig as TranBitsAndBytesConfig
+else:
+ TranBitsAndBytesConfig = None
+
+
+@require_bitsandbytes_version_greater("0.43.2")
+@require_quanto
+@require_accelerate
+@require_torch
+@require_torch_accelerator
+@slow
+class PipelineQuantizationTests(unittest.TestCase):
+ model_name = "hf-internal-testing/tiny-flux-pipe"
+ prompt = "a beautiful sunset amidst the mountains."
+ num_inference_steps = 10
+ seed = 0
+
+ def test_quant_config_set_correctly_through_kwargs(self):
+ components_to_quantize = ["transformer", "text_encoder_2"]
+ quant_config = PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_4bit",
+ quant_kwargs={
+ "load_in_4bit": True,
+ "bnb_4bit_quant_type": "nf4",
+ "bnb_4bit_compute_dtype": torch.bfloat16,
+ },
+ components_to_quantize=components_to_quantize,
+ )
+ pipe = DiffusionPipeline.from_pretrained(
+ self.model_name,
+ quantization_config=quant_config,
+ torch_dtype=torch.bfloat16,
+ ).to(torch_device)
+ for name, component in pipe.components.items():
+ if name in components_to_quantize:
+ self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
+ quantization_config = component.config.quantization_config
+ self.assertTrue(quantization_config.load_in_4bit)
+ self.assertTrue(quantization_config.quant_method == "bitsandbytes")
+
+ _ = pipe(self.prompt, num_inference_steps=self.num_inference_steps)
+
+ def test_quant_config_set_correctly_through_granular(self):
+ quant_config = PipelineQuantizationConfig(
+ quant_mapping={
+ "transformer": QuantoConfig(weights_dtype="int8"),
+ "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
+ }
+ )
+ components_to_quantize = list(quant_config.quant_mapping.keys())
+ pipe = DiffusionPipeline.from_pretrained(
+ self.model_name,
+ quantization_config=quant_config,
+ torch_dtype=torch.bfloat16,
+ ).to(torch_device)
+ for name, component in pipe.components.items():
+ if name in components_to_quantize:
+ self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
+ quantization_config = component.config.quantization_config
+
+ if name == "text_encoder_2":
+ self.assertTrue(quantization_config.load_in_4bit)
+ self.assertTrue(quantization_config.quant_method == "bitsandbytes")
+ else:
+ self.assertTrue(quantization_config.quant_method == "quanto")
+
+ _ = pipe(self.prompt, num_inference_steps=self.num_inference_steps)
+
+ def test_raises_error_for_invalid_config(self):
+ with self.assertRaises(ValueError) as err_context:
+ _ = PipelineQuantizationConfig(
+ quant_mapping={
+ "transformer": QuantoConfig(weights_dtype="int8"),
+ "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
+ },
+ quant_backend="bitsandbytes_4bit",
+ )
+
+ self.assertTrue(
+ str(err_context.exception)
+ == "Both `quant_backend` and `quant_mapping` cannot be specified at the same time."
+ )
+
+ def test_validation_for_kwargs(self):
+ components_to_quantize = ["transformer", "text_encoder_2"]
+ with self.assertRaises(ValueError) as err_context:
+ _ = PipelineQuantizationConfig(
+ quant_backend="quanto",
+ quant_kwargs={"weights_dtype": "int8"},
+ components_to_quantize=components_to_quantize,
+ )
+
+ self.assertTrue(
+ "The signatures of the __init__ methods of the quantization config classes" in str(err_context.exception)
+ )
+
+ def test_raises_error_for_wrong_config_class(self):
+ quant_config = {
+ "transformer": QuantoConfig(weights_dtype="int8"),
+ "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
+ }
+ with self.assertRaises(ValueError) as err_context:
+ _ = DiffusionPipeline.from_pretrained(
+ self.model_name,
+ quantization_config=quant_config,
+ torch_dtype=torch.bfloat16,
+ )
+ self.assertTrue(
+ str(err_context.exception) == "`quantization_config` must be an instance of `PipelineQuantizationConfig`."
+ )
+
+ def test_validation_for_mapping(self):
+ with self.assertRaises(ValueError) as err_context:
+ _ = PipelineQuantizationConfig(
+ quant_mapping={
+ "transformer": DiffusionPipeline(),
+ "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
+ }
+ )
+
+ self.assertTrue("Provided config for module_name=transformer could not be found" in str(err_context.exception))
+
+ def test_saving_loading(self):
+ quant_config = PipelineQuantizationConfig(
+ quant_mapping={
+ "transformer": QuantoConfig(weights_dtype="int8"),
+ "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
+ }
+ )
+ components_to_quantize = list(quant_config.quant_mapping.keys())
+ pipe = DiffusionPipeline.from_pretrained(
+ self.model_name,
+ quantization_config=quant_config,
+ torch_dtype=torch.bfloat16,
+ ).to(torch_device)
+
+ pipe_inputs = {"prompt": self.prompt, "num_inference_steps": self.num_inference_steps, "output_type": "latent"}
+ output_1 = pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir)
+ loaded_pipe = DiffusionPipeline.from_pretrained(tmpdir, torch_dtype=torch.bfloat16).to(torch_device)
+ for name, component in loaded_pipe.components.items():
+ if name in components_to_quantize:
+ self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
+ quantization_config = component.config.quantization_config
+
+ if name == "text_encoder_2":
+ self.assertTrue(quantization_config.load_in_4bit)
+ self.assertTrue(quantization_config.quant_method == "bitsandbytes")
+ else:
+ self.assertTrue(quantization_config.quant_method == "quanto")
+
+ output_2 = loaded_pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images
+
+ self.assertTrue(torch.allclose(output_1, output_2))
+
+ @parameterized.expand(["quant_kwargs", "quant_mapping"])
+ def test_warn_invalid_component(self, method):
+ invalid_component = "foo"
+ if method == "quant_kwargs":
+ components_to_quantize = ["transformer", invalid_component]
+ quant_config = PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_8bit",
+ quant_kwargs={"load_in_8bit": True},
+ components_to_quantize=components_to_quantize,
+ )
+ else:
+ quant_config = PipelineQuantizationConfig(
+ quant_mapping={
+ "transformer": QuantoConfig("int8"),
+ invalid_component: TranBitsAndBytesConfig(load_in_8bit=True),
+ }
+ )
+
+ logger = logging.get_logger("diffusers.pipelines.pipeline_loading_utils")
+ logger.setLevel(logging.WARNING)
+ with CaptureLogger(logger) as cap_logger:
+ _ = DiffusionPipeline.from_pretrained(
+ self.model_name,
+ quantization_config=quant_config,
+ torch_dtype=torch.bfloat16,
+ )
+ self.assertTrue(invalid_component in cap_logger.out)
+
+ @parameterized.expand(["quant_kwargs", "quant_mapping"])
+ def test_no_quantization_for_all_invalid_components(self, method):
+ invalid_component = "foo"
+ if method == "quant_kwargs":
+ components_to_quantize = [invalid_component]
+ quant_config = PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_8bit",
+ quant_kwargs={"load_in_8bit": True},
+ components_to_quantize=components_to_quantize,
+ )
+ else:
+ quant_config = PipelineQuantizationConfig(
+ quant_mapping={invalid_component: TranBitsAndBytesConfig(load_in_8bit=True)}
+ )
+
+ pipe = DiffusionPipeline.from_pretrained(
+ self.model_name,
+ quantization_config=quant_config,
+ torch_dtype=torch.bfloat16,
+ )
+ for name, component in pipe.components.items():
+ if isinstance(component, torch.nn.Module):
+ self.assertTrue(not hasattr(component.config, "quantization_config"))
+
+ @parameterized.expand(["quant_kwargs", "quant_mapping"])
+ def test_quant_config_repr(self, method):
+ component_name = "transformer"
+ if method == "quant_kwargs":
+ components_to_quantize = [component_name]
+ quant_config = PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_8bit",
+ quant_kwargs={"load_in_8bit": True},
+ components_to_quantize=components_to_quantize,
+ )
+ else:
+ quant_config = PipelineQuantizationConfig(
+ quant_mapping={component_name: BitsAndBytesConfig(load_in_8bit=True)}
+ )
+
+ pipe = DiffusionPipeline.from_pretrained(
+ self.model_name,
+ quantization_config=quant_config,
+ torch_dtype=torch.bfloat16,
+ )
+ self.assertTrue(getattr(pipe, "quantization_config", None) is not None)
+ retrieved_config = pipe.quantization_config
+ expected_config = """
+transformer BitsAndBytesConfig {
+ "_load_in_4bit": false,
+ "_load_in_8bit": true,
+ "bnb_4bit_compute_dtype": "float32",
+ "bnb_4bit_quant_storage": "uint8",
+ "bnb_4bit_quant_type": "fp4",
+ "bnb_4bit_use_double_quant": false,
+ "llm_int8_enable_fp32_cpu_offload": false,
+ "llm_int8_has_fp16_weight": false,
+ "llm_int8_skip_modules": null,
+ "llm_int8_threshold": 6.0,
+ "load_in_4bit": false,
+ "load_in_8bit": true,
+ "quant_method": "bitsandbytes"
+}
+
+"""
+ expected_data = self._parse_config_string(expected_config)
+ actual_data = self._parse_config_string(str(retrieved_config))
+ self.assertTrue(actual_data == expected_data)
+
+ def _parse_config_string(self, config_string: str) -> tuple[str, dict]:
+ first_brace = config_string.find("{")
+ if first_brace == -1:
+ raise ValueError("Could not find opening brace '{' in the string.")
+
+ json_part = config_string[first_brace:]
+ data = json.loads(json_part)
+
+ return data
+
+ def test_single_component_to_quantize(self):
+ component_to_quantize = "transformer"
+ quant_config = PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_8bit",
+ quant_kwargs={"load_in_8bit": True},
+ components_to_quantize=component_to_quantize,
+ )
+ pipe = DiffusionPipeline.from_pretrained(
+ self.model_name,
+ quantization_config=quant_config,
+ torch_dtype=torch.bfloat16,
+ )
+ for name, component in pipe.components.items():
+ if name == component_to_quantize:
+ self.assertTrue(hasattr(component.config, "quantization_config"))
diff --git a/tests/quantization/test_torch_compile_utils.py b/tests/quantization/test_torch_compile_utils.py
new file mode 100644
index 000000000000..29758cbdd735
--- /dev/null
+++ b/tests/quantization/test_torch_compile_utils.py
@@ -0,0 +1,106 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Team Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a clone of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import gc
+import inspect
+
+import torch
+
+from diffusers import DiffusionPipeline
+
+from ..testing_utils import backend_empty_cache, require_torch_accelerator, slow, torch_device
+
+
+@require_torch_accelerator
+@slow
+class QuantCompileTests:
+ @property
+ def quantization_config(self):
+ raise NotImplementedError(
+ "This property should be implemented in the subclass to return the appropriate quantization config."
+ )
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+ torch.compiler.reset()
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+ torch.compiler.reset()
+
+ def _init_pipeline(self, quantization_config, torch_dtype):
+ pipe = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-3-medium-diffusers",
+ quantization_config=quantization_config,
+ torch_dtype=torch_dtype,
+ )
+ return pipe
+
+ def _test_torch_compile(self, torch_dtype=torch.bfloat16):
+ pipe = self._init_pipeline(self.quantization_config, torch_dtype).to(torch_device)
+ # `fullgraph=True` ensures no graph breaks
+ pipe.transformer.compile(fullgraph=True)
+
+ # small resolutions to ensure speedy execution.
+ with torch._dynamo.config.patch(error_on_recompile=True):
+ pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
+
+ def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
+ pipe = self._init_pipeline(self.quantization_config, torch_dtype)
+ pipe.enable_model_cpu_offload()
+ # regional compilation is better for offloading.
+ # see: https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/
+ if getattr(pipe.transformer, "_repeated_blocks"):
+ pipe.transformer.compile_repeated_blocks(fullgraph=True)
+ else:
+ pipe.transformer.compile()
+
+ # small resolutions to ensure speedy execution.
+ pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
+
+ def _test_torch_compile_with_group_offload_leaf(self, torch_dtype=torch.bfloat16, *, use_stream: bool = False):
+ torch._dynamo.config.cache_size_limit = 1000
+
+ pipe = self._init_pipeline(self.quantization_config, torch_dtype)
+ group_offload_kwargs = {
+ "onload_device": torch.device(torch_device),
+ "offload_device": torch.device("cpu"),
+ "offload_type": "leaf_level",
+ "use_stream": use_stream,
+ }
+ pipe.transformer.enable_group_offload(**group_offload_kwargs)
+ pipe.transformer.compile()
+ for name, component in pipe.components.items():
+ if name != "transformer" and isinstance(component, torch.nn.Module):
+ if torch.device(component.device).type == "cpu":
+ component.to(torch_device)
+
+ # small resolutions to ensure speedy execution.
+ pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
+
+ def test_torch_compile(self):
+ self._test_torch_compile()
+
+ def test_torch_compile_with_cpu_offload(self):
+ self._test_torch_compile_with_cpu_offload()
+
+ def test_torch_compile_with_group_offload_leaf(self, use_stream=False):
+ for cls in inspect.getmro(self.__class__):
+ if "test_torch_compile_with_group_offload_leaf" in cls.__dict__ and cls is not QuantCompileTests:
+ return
+ self._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
diff --git a/tests/quantization/torchao/README.md b/tests/quantization/torchao/README.md
index fadc529e12fc..373593091ac0 100644
--- a/tests/quantization/torchao/README.md
+++ b/tests/quantization/torchao/README.md
@@ -29,7 +29,7 @@ The benchmark results for Flux and CogVideoX can be found in [this](https://gith
The tests, and the expected slices, were obtained from the `aws-g6e-xlarge-plus` GPU test runners. To run the slow tests, use the following command or an equivalent:
```bash
-HF_HUB_ENABLE_HF_TRANSFER=1 RUN_SLOW=1 pytest -s tests/quantization/torchao/test_torchao.py::SlowTorchAoTests
+HF_XET_HIGH_PERFORMANCE=1 RUN_SLOW=1 pytest -s tests/quantization/torchao/test_torchao.py::SlowTorchAoTests
```
`diffusers-cli`:
diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index 0e671307dd18..38997de17b12 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,11 +14,14 @@
# limitations under the License.
import gc
+import importlib.metadata
import tempfile
import unittest
from typing import List
import numpy as np
+from packaging import version
+from parameterized import parameterized
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import (
@@ -29,18 +32,23 @@
TorchAoConfig,
)
from diffusers.models.attention_processor import Attention
-from diffusers.utils.testing_utils import (
+from diffusers.quantizers import PipelineQuantizationConfig
+
+from ...testing_utils import (
+ backend_empty_cache,
+ backend_synchronize,
enable_full_determinism,
is_torch_available,
is_torchao_available,
nightly,
numpy_cosine_similarity_distance,
require_torch,
- require_torch_gpu,
+ require_torch_accelerator,
require_torchao_version_greater_or_equal,
slow,
torch_device,
)
+from ..test_torch_compile_utils import QuantCompileTests
enable_full_determinism()
@@ -59,9 +67,12 @@
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import get_model_size_in_bytes
+ if version.parse(importlib.metadata.version("torchao")) >= version.Version("0.9.0"):
+ from torchao.quantization import Int8WeightOnlyConfig
+
@require_torch
-@require_torch_gpu
+@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoConfigTest(unittest.TestCase):
def test_to_dict(self):
@@ -79,7 +90,7 @@ def test_post_init_check(self):
Test kwargs validations in TorchAoConfig
"""
_ = TorchAoConfig("int4_weight_only")
- with self.assertRaisesRegex(ValueError, "is not supported yet"):
+ with self.assertRaisesRegex(ValueError, "is not supported"):
_ = TorchAoConfig("uint8")
with self.assertRaisesRegex(ValueError, "does not support the following keyword arguments"):
@@ -119,12 +130,12 @@ def test_repr(self):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
-@require_torch_gpu
+@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoTest(unittest.TestCase):
def tearDown(self):
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_dummy_components(
self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe"
@@ -230,7 +241,7 @@ def test_quantization(self):
("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
]
- if TorchAoConfig._is_cuda_capability_atleast_8_9():
+ if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES_TO_TEST.extend([
("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
@@ -269,6 +280,7 @@ def test_int4wo_quant_bfloat16_conversion(self):
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
+ device_map=f"{torch_device}:0",
)
weight = quantized_model.transformer_blocks[0].ff.net[2].weight
@@ -338,7 +350,7 @@ def test_device_map(self):
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
- self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
+ self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3)
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
@@ -359,7 +371,7 @@ def test_device_map(self):
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
- self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
+ self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3)
def test_modules_to_not_convert(self):
quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
@@ -515,17 +527,26 @@ def test_sequential_cpu_offload(self):
inputs = self.get_dummy_inputs(torch_device)
_ = pipe(**inputs)
+ @require_torchao_version_greater_or_equal("0.9.0")
+ def test_aobase_config(self):
+ quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
+ components = self.get_dummy_components(quantization_config)
+ pipe = FluxPipeline(**components).to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ _ = pipe(**inputs)
+
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
-@require_torch_gpu
+@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoSerializationTest(unittest.TestCase):
model_name = "hf-internal-testing/tiny-flux-pipe"
def tearDown(self):
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
@@ -593,17 +614,17 @@ def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs,
)
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
- def test_int_a8w8_cuda(self):
+ def test_int_a8w8_accelerator(self):
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
- device = "cuda"
+ device = torch_device
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
- def test_int_a16w8_cuda(self):
+ def test_int_a16w8_accelerator(self):
quant_method, quant_method_kwargs = "int8_weight_only", {}
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
- device = "cuda"
+ device = torch_device
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
@@ -621,17 +642,69 @@ def test_int_a16w8_cpu(self):
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+ @require_torchao_version_greater_or_equal("0.9.0")
+ def test_aobase_config(self):
+ quant_method, quant_method_kwargs = Int8WeightOnlyConfig(), {}
+ expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
+ device = torch_device
+ self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+ self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+
+
+@require_torchao_version_greater_or_equal("0.7.0")
+class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
+ @property
+ def quantization_config(self):
+ return PipelineQuantizationConfig(
+ quant_mapping={
+ "transformer": TorchAoConfig(quant_type="int8_weight_only"),
+ },
+ )
+
+ @unittest.skip(
+ "Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
+ "when compiling."
+ )
+ def test_torch_compile_with_cpu_offload(self):
+ # RuntimeError: _apply(): Couldn't swap Linear.weight
+ super().test_torch_compile_with_cpu_offload()
+
+ @parameterized.expand([False, True])
+ @unittest.skip(
+ """
+ For `use_stream=False`:
+ - Changing the device of AQT tensor, with `param.data = param.data.to(device)` as done in group offloading implementation
+ is unsupported in TorchAO. When compiling, FakeTensor device mismatch causes failure.
+ For `use_stream=True`:
+ Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
+ """
+ )
+ def test_torch_compile_with_group_offload_leaf(self, use_stream):
+ # For use_stream=False:
+ # If we run group offloading without compilation, we will see:
+ # RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
+ # When running with compilation, the error ends up being different:
+ # Dynamo failed to run FX node with fake tensors: call_function (*(FakeTensor(..., device='cuda:0', size=(s0, 256), dtype=torch.bfloat16), AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=FakeTensor(..., size=(1536, 256), dtype=torch.int8)... , scale=FakeTensor(..., size=(1536,), dtype=torch.bfloat16)... , zero_point=FakeTensor(..., size=(1536,), dtype=torch.int64)... , _layout=PlainLayout()), block_size=(1, 256), shape=torch.Size([1536, 256]), device=cpu, dtype=torch.bfloat16, requires_grad=False), Parameter(FakeTensor(..., device='cuda:0', size=(1536,), dtype=torch.bfloat16,
+ # requires_grad=True))), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu')
+ # Looks like something that will have to be looked into upstream.
+ # for linear layers, weight.tensor_impl shows cuda... but:
+ # weight.tensor_impl.{data,scale,zero_point}.device will be cpu
+
+ # For use_stream=True:
+ # NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=, types=(,), arg_types=(,), kwarg_types={}
+ super()._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
+
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
-@require_torch_gpu
+@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
@slow
@nightly
class SlowTorchAoTests(unittest.TestCase):
def tearDown(self):
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_dummy_components(self, quantization_config: TorchAoConfig):
# This is just for convenience, so that we can modify it at one place for custom environments and locally testing
@@ -702,7 +775,7 @@ def test_quantization(self):
("int8dq", np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])),
]
- if TorchAoConfig._is_cuda_capability_atleast_8_9():
+ if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES_TO_TEST.extend([
("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
@@ -713,8 +786,8 @@ def test_quantization(self):
quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"])
self._test_quant_type(quantization_config, expected_slice)
gc.collect()
- torch.cuda.empty_cache()
- torch.cuda.synchronize()
+ backend_empty_cache(torch_device)
+ backend_synchronize(torch_device)
def test_serialization_int8wo(self):
quantization_config = TorchAoConfig("int8wo")
@@ -733,8 +806,8 @@ def test_serialization_int8wo(self):
pipe.remove_all_hooks()
del pipe.transformer
gc.collect()
- torch.cuda.empty_cache()
- torch.cuda.synchronize()
+ backend_empty_cache(torch_device)
+ backend_synchronize(torch_device)
transformer = FluxTransformer2DModel.from_pretrained(
tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
)
@@ -783,14 +856,14 @@ def test_memory_footprint_int8wo(self):
@require_torch
-@require_torch_gpu
+@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
@slow
@nightly
class SlowTorchAoPreserializedModelTests(unittest.TestCase):
def tearDown(self):
gc.collect()
- torch.cuda.empty_cache()
+ backend_empty_cache(torch_device)
def get_dummy_inputs(self, device: torch.device, seed: int = 0):
if str(device).startswith("mps"):
diff --git a/tests/quantization/utils.py b/tests/quantization/utils.py
index 04ebf9e159f4..a74ece5a3a3a 100644
--- a/tests/quantization/utils.py
+++ b/tests/quantization/utils.py
@@ -1,5 +1,12 @@
from diffusers.utils import is_torch_available
+from ..testing_utils import (
+ backend_empty_cache,
+ backend_max_memory_allocated,
+ backend_reset_peak_memory_stats,
+ torch_device,
+)
+
if is_torch_available():
import torch
@@ -30,9 +37,9 @@ def forward(self, input, *args, **kwargs):
@torch.no_grad()
@torch.inference_mode()
def get_memory_consumption_stat(model, inputs):
- torch.cuda.reset_peak_memory_stats()
- torch.cuda.empty_cache()
+ backend_reset_peak_memory_stats(torch_device)
+ backend_empty_cache(torch_device)
model(**inputs)
- max_memory_mem_allocated = torch.cuda.max_memory_allocated()
- return max_memory_mem_allocated
+ max_mem_allocated = backend_max_memory_allocated(torch_device)
+ return max_mem_allocated
diff --git a/tests/remote/test_remote_decode.py b/tests/remote/test_remote_decode.py
index cec96e729a48..27170cba0835 100644
--- a/tests/remote/test_remote_decode.py
+++ b/tests/remote/test_remote_decode.py
@@ -30,13 +30,14 @@
from diffusers.utils.remote_utils import (
remote_decode,
)
-from diffusers.utils.testing_utils import (
+from diffusers.video_processor import VideoProcessor
+
+from ..testing_utils import (
enable_full_determinism,
slow,
torch_all_close,
torch_device,
)
-from diffusers.video_processor import VideoProcessor
enable_full_determinism()
diff --git a/tests/remote/test_remote_encode.py b/tests/remote/test_remote_encode.py
index 62ed97ee8f49..4c0daf08fd8c 100644
--- a/tests/remote/test_remote_encode.py
+++ b/tests/remote/test_remote_encode.py
@@ -31,7 +31,8 @@
remote_decode,
remote_encode,
)
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
enable_full_determinism,
slow,
)
diff --git a/tests/schedulers/test_scheduler_ddim_parallel.py b/tests/schedulers/test_scheduler_ddim_parallel.py
index 5434d08b5628..3ce8034cfb95 100644
--- a/tests/schedulers/test_scheduler_ddim_parallel.py
+++ b/tests/schedulers/test_scheduler_ddim_parallel.py
@@ -1,4 +1,4 @@
-# Copyright 2024 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/schedulers/test_scheduler_ddpm_parallel.py b/tests/schedulers/test_scheduler_ddpm_parallel.py
index c358ad991c1d..377067071c25 100644
--- a/tests/schedulers/test_scheduler_ddpm_parallel.py
+++ b/tests/schedulers/test_scheduler_ddpm_parallel.py
@@ -1,4 +1,4 @@
-# Copyright 2024 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 ParaDiGMS authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py
index 55b3202ad0be..28c354709dc9 100644
--- a/tests/schedulers/test_scheduler_dpm_multi.py
+++ b/tests/schedulers/test_scheduler_dpm_multi.py
@@ -357,9 +357,9 @@ def test_custom_timesteps(self):
prediction_type=prediction_type,
final_sigmas_type=final_sigmas_type,
)
- assert (
- torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
- ), f"Scheduler outputs are not identical for algorithm_type: {algorithm_type}, prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
+ assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
+ f"Scheduler outputs are not identical for algorithm_type: {algorithm_type}, prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
+ )
def test_beta_sigmas(self):
self.check_over_configs(use_beta_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_dpm_sde.py b/tests/schedulers/test_scheduler_dpm_sde.py
index 69b611173423..e4dde67344ac 100644
--- a/tests/schedulers/test_scheduler_dpm_sde.py
+++ b/tests/schedulers/test_scheduler_dpm_sde.py
@@ -1,8 +1,8 @@
import torch
from diffusers import DPMSolverSDEScheduler
-from diffusers.utils.testing_utils import require_torchsde, torch_device
+from ..testing_utils import require_torchsde, torch_device
from .test_schedulers import SchedulerCommonTest
diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py
index 7cbaa5cc5e8d..0756a5ed71ff 100644
--- a/tests/schedulers/test_scheduler_dpm_single.py
+++ b/tests/schedulers/test_scheduler_dpm_single.py
@@ -345,9 +345,9 @@ def test_custom_timesteps(self):
lower_order_final=lower_order_final,
final_sigmas_type=final_sigmas_type,
)
- assert (
- torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
- ), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, lower_order_final: {lower_order_final} and final_sigmas_type: {final_sigmas_type}"
+ assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
+ f"Scheduler outputs are not identical for prediction_type: {prediction_type}, lower_order_final: {lower_order_final} and final_sigmas_type: {final_sigmas_type}"
+ )
def test_beta_sigmas(self):
self.check_over_configs(use_beta_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py b/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py
index e97d64ec5f1d..8525ce61c40d 100644
--- a/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py
+++ b/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py
@@ -188,9 +188,9 @@ def test_solver_order_and_type(self):
prediction_type=prediction_type,
algorithm_type=algorithm_type,
)
- assert (
- not torch.isnan(sample).any()
- ), f"Samples have nan numbers, {order}, {solver_type}, {prediction_type}, {algorithm_type}"
+ assert not torch.isnan(sample).any(), (
+ f"Samples have nan numbers, {order}, {solver_type}, {prediction_type}, {algorithm_type}"
+ )
def test_lower_order_final(self):
self.check_over_configs(lower_order_final=True)
diff --git a/tests/schedulers/test_scheduler_euler.py b/tests/schedulers/test_scheduler_euler.py
index 4c7e02442cd0..ee99465abfc3 100644
--- a/tests/schedulers/test_scheduler_euler.py
+++ b/tests/schedulers/test_scheduler_euler.py
@@ -1,8 +1,8 @@
import torch
from diffusers import EulerDiscreteScheduler
-from diffusers.utils.testing_utils import torch_device
+from ..testing_utils import torch_device
from .test_schedulers import SchedulerCommonTest
@@ -245,9 +245,9 @@ def test_custom_timesteps(self):
interpolation_type=interpolation_type,
final_sigmas_type=final_sigmas_type,
)
- assert (
- torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
- ), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, interpolation_type: {interpolation_type} and final_sigmas_type: {final_sigmas_type}"
+ assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
+ f"Scheduler outputs are not identical for prediction_type: {prediction_type}, interpolation_type: {interpolation_type} and final_sigmas_type: {final_sigmas_type}"
+ )
def test_custom_sigmas(self):
for prediction_type in ["epsilon", "sample", "v_prediction"]:
@@ -260,9 +260,9 @@ def test_custom_sigmas(self):
prediction_type=prediction_type,
final_sigmas_type=final_sigmas_type,
)
- assert (
- torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
- ), f"Scheduler outputs are not identical for prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
+ assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
+ f"Scheduler outputs are not identical for prediction_type: {prediction_type} and final_sigmas_type: {final_sigmas_type}"
+ )
def test_beta_sigmas(self):
self.check_over_configs(use_beta_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_euler_ancestral.py b/tests/schedulers/test_scheduler_euler_ancestral.py
index 9f22ab38ddaf..c4fe61bfc387 100644
--- a/tests/schedulers/test_scheduler_euler_ancestral.py
+++ b/tests/schedulers/test_scheduler_euler_ancestral.py
@@ -1,8 +1,8 @@
import torch
from diffusers import EulerAncestralDiscreteScheduler
-from diffusers.utils.testing_utils import torch_device
+from ..testing_utils import torch_device
from .test_schedulers import SchedulerCommonTest
diff --git a/tests/schedulers/test_scheduler_flax.py b/tests/schedulers/test_scheduler_flax.py
deleted file mode 100644
index 8ccb5f6594a5..000000000000
--- a/tests/schedulers/test_scheduler_flax.py
+++ /dev/null
@@ -1,920 +0,0 @@
-# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import inspect
-import tempfile
-import unittest
-from typing import Dict, List, Tuple
-
-from diffusers import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler
-from diffusers.utils import is_flax_available
-from diffusers.utils.testing_utils import require_flax
-
-
-if is_flax_available():
- import jax
- import jax.numpy as jnp
- from jax import random
-
- jax_device = jax.default_backend()
-
-
-@require_flax
-class FlaxSchedulerCommonTest(unittest.TestCase):
- scheduler_classes = ()
- forward_default_kwargs = ()
-
- @property
- def dummy_sample(self):
- batch_size = 4
- num_channels = 3
- height = 8
- width = 8
-
- key1, key2 = random.split(random.PRNGKey(0))
- sample = random.uniform(key1, (batch_size, num_channels, height, width))
-
- return sample, key2
-
- @property
- def dummy_sample_deter(self):
- batch_size = 4
- num_channels = 3
- height = 8
- width = 8
-
- num_elems = batch_size * num_channels * height * width
- sample = jnp.arange(num_elems)
- sample = sample.reshape(num_channels, height, width, batch_size)
- sample = sample / num_elems
- return jnp.transpose(sample, (3, 0, 1, 2))
-
- def get_scheduler_config(self):
- raise NotImplementedError
-
- def dummy_model(self):
- def model(sample, t, *args):
- return sample * t / (t + 1)
-
- return model
-
- def check_over_configs(self, time_step=0, **config):
- kwargs = dict(self.forward_default_kwargs)
-
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- sample, key = self.dummy_sample
- residual = 0.1 * sample
-
- scheduler_config = self.get_scheduler_config(**config)
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample
- new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample
-
- assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
-
- def check_over_forward(self, time_step=0, **forward_kwargs):
- kwargs = dict(self.forward_default_kwargs)
- kwargs.update(forward_kwargs)
-
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- sample, key = self.dummy_sample
- residual = 0.1 * sample
-
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample
- new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample
-
- assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
-
- def test_from_save_pretrained(self):
- kwargs = dict(self.forward_default_kwargs)
-
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- sample, key = self.dummy_sample
- residual = 0.1 * sample
-
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- output = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample
- new_output = new_scheduler.step(new_state, residual, 1, sample, key, **kwargs).prev_sample
-
- assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
-
- def test_step_shape(self):
- kwargs = dict(self.forward_default_kwargs)
-
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- sample, key = self.dummy_sample
- residual = 0.1 * sample
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- output_0 = scheduler.step(state, residual, 0, sample, key, **kwargs).prev_sample
- output_1 = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample
-
- self.assertEqual(output_0.shape, sample.shape)
- self.assertEqual(output_0.shape, output_1.shape)
-
- def test_scheduler_outputs_equivalence(self):
- def set_nan_tensor_to_zero(t):
- return t.at[t != t].set(0)
-
- def recursive_check(tuple_object, dict_object):
- if isinstance(tuple_object, (List, Tuple)):
- for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
- recursive_check(tuple_iterable_value, dict_iterable_value)
- elif isinstance(tuple_object, Dict):
- for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
- recursive_check(tuple_iterable_value, dict_iterable_value)
- elif tuple_object is None:
- return
- else:
- self.assertTrue(
- jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
- msg=(
- "Tuple and dict output are not equal. Difference:"
- f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
- f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has"
- f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}."
- ),
- )
-
- kwargs = dict(self.forward_default_kwargs)
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- sample, key = self.dummy_sample
- residual = 0.1 * sample
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- outputs_dict = scheduler.step(state, residual, 0, sample, key, **kwargs)
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- outputs_tuple = scheduler.step(state, residual, 0, sample, key, return_dict=False, **kwargs)
-
- recursive_check(outputs_tuple[0], outputs_dict.prev_sample)
-
- def test_deprecated_kwargs(self):
- for scheduler_class in self.scheduler_classes:
- has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters
- has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0
-
- if has_kwarg_in_model_class and not has_deprecated_kwarg:
- raise ValueError(
- f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated"
- " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if"
- " there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
- " []`"
- )
-
- if not has_kwarg_in_model_class and has_deprecated_kwarg:
- raise ValueError(
- f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated"
- " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`"
- f" argument to {self.model_class}.__init__ if there are deprecated arguments or remove the"
- " deprecated argument from `_deprecated_kwargs = []`"
- )
-
-
-@require_flax
-class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest):
- scheduler_classes = (FlaxDDPMScheduler,)
-
- def get_scheduler_config(self, **kwargs):
- config = {
- "num_train_timesteps": 1000,
- "beta_start": 0.0001,
- "beta_end": 0.02,
- "beta_schedule": "linear",
- "variance_type": "fixed_small",
- "clip_sample": True,
- }
-
- config.update(**kwargs)
- return config
-
- def test_timesteps(self):
- for timesteps in [1, 5, 100, 1000]:
- self.check_over_configs(num_train_timesteps=timesteps)
-
- def test_betas(self):
- for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
- self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
-
- def test_schedules(self):
- for schedule in ["linear", "squaredcos_cap_v2"]:
- self.check_over_configs(beta_schedule=schedule)
-
- def test_variance_type(self):
- for variance in ["fixed_small", "fixed_large", "other"]:
- self.check_over_configs(variance_type=variance)
-
- def test_clip_sample(self):
- for clip_sample in [True, False]:
- self.check_over_configs(clip_sample=clip_sample)
-
- def test_time_indices(self):
- for t in [0, 500, 999]:
- self.check_over_forward(time_step=t)
-
- def test_variance(self):
- scheduler_class = self.scheduler_classes[0]
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0) - 0.0)) < 1e-5
- assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487) - 0.00979)) < 1e-5
- assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999) - 0.02)) < 1e-5
-
- def test_full_loop_no_noise(self):
- scheduler_class = self.scheduler_classes[0]
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- num_trained_timesteps = len(scheduler)
-
- model = self.dummy_model()
- sample = self.dummy_sample_deter
- key1, key2 = random.split(random.PRNGKey(0))
-
- for t in reversed(range(num_trained_timesteps)):
- # 1. predict noise residual
- residual = model(sample, t)
-
- # 2. predict previous mean of sample x_t-1
- output = scheduler.step(state, residual, t, sample, key1)
- pred_prev_sample = output.prev_sample
- state = output.state
- key1, key2 = random.split(key2)
-
- # if t > 0:
- # noise = self.dummy_sample_deter
- # variance = scheduler.get_variance(t) ** (0.5) * noise
- #
- # sample = pred_prev_sample + variance
- sample = pred_prev_sample
-
- result_sum = jnp.sum(jnp.abs(sample))
- result_mean = jnp.mean(jnp.abs(sample))
-
- if jax_device == "tpu":
- assert abs(result_sum - 255.0714) < 1e-2
- assert abs(result_mean - 0.332124) < 1e-3
- else:
- assert abs(result_sum - 270.2) < 1e-1
- assert abs(result_mean - 0.3519494) < 1e-3
-
-
-@require_flax
-class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
- scheduler_classes = (FlaxDDIMScheduler,)
- forward_default_kwargs = (("num_inference_steps", 50),)
-
- def get_scheduler_config(self, **kwargs):
- config = {
- "num_train_timesteps": 1000,
- "beta_start": 0.0001,
- "beta_end": 0.02,
- "beta_schedule": "linear",
- }
-
- config.update(**kwargs)
- return config
-
- def full_loop(self, **config):
- scheduler_class = self.scheduler_classes[0]
- scheduler_config = self.get_scheduler_config(**config)
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
- key1, key2 = random.split(random.PRNGKey(0))
-
- num_inference_steps = 10
-
- model = self.dummy_model()
- sample = self.dummy_sample_deter
-
- state = scheduler.set_timesteps(state, num_inference_steps)
-
- for t in state.timesteps:
- residual = model(sample, t)
- output = scheduler.step(state, residual, t, sample)
- sample = output.prev_sample
- state = output.state
- key1, key2 = random.split(key2)
-
- return sample
-
- def check_over_configs(self, time_step=0, **config):
- kwargs = dict(self.forward_default_kwargs)
-
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- sample, _ = self.dummy_sample
- residual = 0.1 * sample
-
- scheduler_config = self.get_scheduler_config(**config)
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample
- new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample
-
- assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
-
- def test_from_save_pretrained(self):
- kwargs = dict(self.forward_default_kwargs)
-
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- sample, _ = self.dummy_sample
- residual = 0.1 * sample
-
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample
- new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample
-
- assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
-
- def check_over_forward(self, time_step=0, **forward_kwargs):
- kwargs = dict(self.forward_default_kwargs)
- kwargs.update(forward_kwargs)
-
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- sample, _ = self.dummy_sample
- residual = 0.1 * sample
-
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample
- new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample
-
- assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
-
- def test_scheduler_outputs_equivalence(self):
- def set_nan_tensor_to_zero(t):
- return t.at[t != t].set(0)
-
- def recursive_check(tuple_object, dict_object):
- if isinstance(tuple_object, (List, Tuple)):
- for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
- recursive_check(tuple_iterable_value, dict_iterable_value)
- elif isinstance(tuple_object, Dict):
- for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
- recursive_check(tuple_iterable_value, dict_iterable_value)
- elif tuple_object is None:
- return
- else:
- self.assertTrue(
- jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
- msg=(
- "Tuple and dict output are not equal. Difference:"
- f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
- f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has"
- f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}."
- ),
- )
-
- kwargs = dict(self.forward_default_kwargs)
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- sample, _ = self.dummy_sample
- residual = 0.1 * sample
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs)
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs)
-
- recursive_check(outputs_tuple[0], outputs_dict.prev_sample)
-
- def test_step_shape(self):
- kwargs = dict(self.forward_default_kwargs)
-
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- sample, _ = self.dummy_sample
- residual = 0.1 * sample
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- output_0 = scheduler.step(state, residual, 0, sample, **kwargs).prev_sample
- output_1 = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample
-
- self.assertEqual(output_0.shape, sample.shape)
- self.assertEqual(output_0.shape, output_1.shape)
-
- def test_timesteps(self):
- for timesteps in [100, 500, 1000]:
- self.check_over_configs(num_train_timesteps=timesteps)
-
- def test_steps_offset(self):
- for steps_offset in [0, 1]:
- self.check_over_configs(steps_offset=steps_offset)
-
- scheduler_class = self.scheduler_classes[0]
- scheduler_config = self.get_scheduler_config(steps_offset=1)
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
- state = scheduler.set_timesteps(state, 5)
- assert jnp.equal(state.timesteps, jnp.array([801, 601, 401, 201, 1])).all()
-
- def test_betas(self):
- for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
- self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
-
- def test_schedules(self):
- for schedule in ["linear", "squaredcos_cap_v2"]:
- self.check_over_configs(beta_schedule=schedule)
-
- def test_time_indices(self):
- for t in [1, 10, 49]:
- self.check_over_forward(time_step=t)
-
- def test_inference_steps(self):
- for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
- self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
-
- def test_variance(self):
- scheduler_class = self.scheduler_classes[0]
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5
- assert jnp.sum(jnp.abs(scheduler._get_variance(state, 420, 400) - 0.14771)) < 1e-5
- assert jnp.sum(jnp.abs(scheduler._get_variance(state, 980, 960) - 0.32460)) < 1e-5
- assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5
- assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487, 486) - 0.00979)) < 1e-5
- assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999, 998) - 0.02)) < 1e-5
-
- def test_full_loop_no_noise(self):
- sample = self.full_loop()
-
- result_sum = jnp.sum(jnp.abs(sample))
- result_mean = jnp.mean(jnp.abs(sample))
-
- assert abs(result_sum - 172.0067) < 1e-2
- assert abs(result_mean - 0.223967) < 1e-3
-
- def test_full_loop_with_set_alpha_to_one(self):
- # We specify different beta, so that the first alpha is 0.99
- sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
- result_sum = jnp.sum(jnp.abs(sample))
- result_mean = jnp.mean(jnp.abs(sample))
-
- if jax_device == "tpu":
- assert abs(result_sum - 149.8409) < 1e-2
- assert abs(result_mean - 0.1951) < 1e-3
- else:
- assert abs(result_sum - 149.8295) < 1e-2
- assert abs(result_mean - 0.1951) < 1e-3
-
- def test_full_loop_with_no_set_alpha_to_one(self):
- # We specify different beta, so that the first alpha is 0.99
- sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
- result_sum = jnp.sum(jnp.abs(sample))
- result_mean = jnp.mean(jnp.abs(sample))
-
- if jax_device == "tpu":
- pass
- # FIXME: both result_sum and result_mean are nan on TPU
- # assert jnp.isnan(result_sum)
- # assert jnp.isnan(result_mean)
- else:
- assert abs(result_sum - 149.0784) < 1e-2
- assert abs(result_mean - 0.1941) < 1e-3
-
- def test_prediction_type(self):
- for prediction_type in ["epsilon", "sample", "v_prediction"]:
- self.check_over_configs(prediction_type=prediction_type)
-
-
-@require_flax
-class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
- scheduler_classes = (FlaxPNDMScheduler,)
- forward_default_kwargs = (("num_inference_steps", 50),)
-
- def get_scheduler_config(self, **kwargs):
- config = {
- "num_train_timesteps": 1000,
- "beta_start": 0.0001,
- "beta_end": 0.02,
- "beta_schedule": "linear",
- }
-
- config.update(**kwargs)
- return config
-
- def check_over_configs(self, time_step=0, **config):
- kwargs = dict(self.forward_default_kwargs)
- num_inference_steps = kwargs.pop("num_inference_steps", None)
- sample, _ = self.dummy_sample
- residual = 0.1 * sample
- dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])
-
- for scheduler_class in self.scheduler_classes:
- scheduler_config = self.get_scheduler_config(**config)
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
- state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
- # copy over dummy past residuals
- state = state.replace(ets=dummy_past_residuals[:])
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
- new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
- # copy over dummy past residuals
- new_state = new_state.replace(ets=dummy_past_residuals[:])
-
- (prev_sample, state) = scheduler.step_prk(state, residual, time_step, sample, **kwargs)
- (new_prev_sample, new_state) = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs)
-
- assert jnp.sum(jnp.abs(prev_sample - new_prev_sample)) < 1e-5, "Scheduler outputs are not identical"
-
- output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs)
- new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs)
-
- assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
-
- @unittest.skip("Test not supported.")
- def test_from_save_pretrained(self):
- pass
-
- def test_scheduler_outputs_equivalence(self):
- def set_nan_tensor_to_zero(t):
- return t.at[t != t].set(0)
-
- def recursive_check(tuple_object, dict_object):
- if isinstance(tuple_object, (List, Tuple)):
- for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
- recursive_check(tuple_iterable_value, dict_iterable_value)
- elif isinstance(tuple_object, Dict):
- for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
- recursive_check(tuple_iterable_value, dict_iterable_value)
- elif tuple_object is None:
- return
- else:
- self.assertTrue(
- jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
- msg=(
- "Tuple and dict output are not equal. Difference:"
- f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
- f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has"
- f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}."
- ),
- )
-
- kwargs = dict(self.forward_default_kwargs)
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- sample, _ = self.dummy_sample
- residual = 0.1 * sample
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs)
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs)
-
- recursive_check(outputs_tuple[0], outputs_dict.prev_sample)
-
- def check_over_forward(self, time_step=0, **forward_kwargs):
- kwargs = dict(self.forward_default_kwargs)
- num_inference_steps = kwargs.pop("num_inference_steps", None)
- sample, _ = self.dummy_sample
- residual = 0.1 * sample
- dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])
-
- for scheduler_class in self.scheduler_classes:
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
- state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
-
- # copy over dummy past residuals (must be after setting timesteps)
- scheduler.ets = dummy_past_residuals[:]
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
- # copy over dummy past residuals
- new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
-
- # copy over dummy past residual (must be after setting timesteps)
- new_state.replace(ets=dummy_past_residuals[:])
-
- output, state = scheduler.step_prk(state, residual, time_step, sample, **kwargs)
- new_output, new_state = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs)
-
- assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
-
- output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs)
- new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs)
-
- assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
-
- def full_loop(self, **config):
- scheduler_class = self.scheduler_classes[0]
- scheduler_config = self.get_scheduler_config(**config)
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- num_inference_steps = 10
- model = self.dummy_model()
- sample = self.dummy_sample_deter
- state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
-
- for i, t in enumerate(state.prk_timesteps):
- residual = model(sample, t)
- sample, state = scheduler.step_prk(state, residual, t, sample)
-
- for i, t in enumerate(state.plms_timesteps):
- residual = model(sample, t)
- sample, state = scheduler.step_plms(state, residual, t, sample)
-
- return sample
-
- def test_step_shape(self):
- kwargs = dict(self.forward_default_kwargs)
-
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- sample, _ = self.dummy_sample
- residual = 0.1 * sample
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- # copy over dummy past residuals (must be done after set_timesteps)
- dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])
- state = state.replace(ets=dummy_past_residuals[:])
-
- output_0, state = scheduler.step_prk(state, residual, 0, sample, **kwargs)
- output_1, state = scheduler.step_prk(state, residual, 1, sample, **kwargs)
-
- self.assertEqual(output_0.shape, sample.shape)
- self.assertEqual(output_0.shape, output_1.shape)
-
- output_0, state = scheduler.step_plms(state, residual, 0, sample, **kwargs)
- output_1, state = scheduler.step_plms(state, residual, 1, sample, **kwargs)
-
- self.assertEqual(output_0.shape, sample.shape)
- self.assertEqual(output_0.shape, output_1.shape)
-
- def test_timesteps(self):
- for timesteps in [100, 1000]:
- self.check_over_configs(num_train_timesteps=timesteps)
-
- def test_steps_offset(self):
- for steps_offset in [0, 1]:
- self.check_over_configs(steps_offset=steps_offset)
-
- scheduler_class = self.scheduler_classes[0]
- scheduler_config = self.get_scheduler_config(steps_offset=1)
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
- state = scheduler.set_timesteps(state, 10, shape=())
- assert jnp.equal(
- state.timesteps,
- jnp.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]),
- ).all()
-
- def test_betas(self):
- for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):
- self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
-
- def test_schedules(self):
- for schedule in ["linear", "squaredcos_cap_v2"]:
- self.check_over_configs(beta_schedule=schedule)
-
- def test_time_indices(self):
- for t in [1, 5, 10]:
- self.check_over_forward(time_step=t)
-
- def test_inference_steps(self):
- for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
- self.check_over_forward(num_inference_steps=num_inference_steps)
-
- def test_pow_of_3_inference_steps(self):
- # earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3
- num_inference_steps = 27
-
- for scheduler_class in self.scheduler_classes:
- sample, _ = self.dummy_sample
- residual = 0.1 * sample
-
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
-
- # before power of 3 fix, would error on first step, so we only need to do two
- for i, t in enumerate(state.prk_timesteps[:2]):
- sample, state = scheduler.step_prk(state, residual, t, sample)
-
- def test_inference_plms_no_past_residuals(self):
- with self.assertRaises(ValueError):
- scheduler_class = self.scheduler_classes[0]
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- scheduler.step_plms(state, self.dummy_sample, 1, self.dummy_sample).prev_sample
-
- def test_full_loop_no_noise(self):
- sample = self.full_loop()
- result_sum = jnp.sum(jnp.abs(sample))
- result_mean = jnp.mean(jnp.abs(sample))
-
- if jax_device == "tpu":
- assert abs(result_sum - 198.1275) < 1e-2
- assert abs(result_mean - 0.2580) < 1e-3
- else:
- assert abs(result_sum - 198.1318) < 1e-2
- assert abs(result_mean - 0.2580) < 1e-3
-
- def test_full_loop_with_set_alpha_to_one(self):
- # We specify different beta, so that the first alpha is 0.99
- sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
- result_sum = jnp.sum(jnp.abs(sample))
- result_mean = jnp.mean(jnp.abs(sample))
-
- if jax_device == "tpu":
- assert abs(result_sum - 186.83226) < 1e-2
- assert abs(result_mean - 0.24327) < 1e-3
- else:
- assert abs(result_sum - 186.9466) < 1e-2
- assert abs(result_mean - 0.24342) < 1e-3
-
- def test_full_loop_with_no_set_alpha_to_one(self):
- # We specify different beta, so that the first alpha is 0.99
- sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
- result_sum = jnp.sum(jnp.abs(sample))
- result_mean = jnp.mean(jnp.abs(sample))
-
- if jax_device == "tpu":
- assert abs(result_sum - 186.83226) < 1e-2
- assert abs(result_mean - 0.24327) < 1e-3
- else:
- assert abs(result_sum - 186.9482) < 1e-2
- assert abs(result_mean - 0.2434) < 1e-3
diff --git a/tests/schedulers/test_scheduler_heun.py b/tests/schedulers/test_scheduler_heun.py
index 9e060c6d476f..97bef50048ba 100644
--- a/tests/schedulers/test_scheduler_heun.py
+++ b/tests/schedulers/test_scheduler_heun.py
@@ -1,8 +1,8 @@
import torch
from diffusers import HeunDiscreteScheduler
-from diffusers.utils.testing_utils import torch_device
+from ..testing_utils import torch_device
from .test_schedulers import SchedulerCommonTest
@@ -216,9 +216,9 @@ def test_custom_timesteps(self):
prediction_type=prediction_type,
timestep_spacing=timestep_spacing,
)
- assert (
- torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5
- ), f"Scheduler outputs are not identical for prediction_type: {prediction_type}, timestep_spacing: {timestep_spacing}"
+ assert torch.sum(torch.abs(sample - sample_custom_timesteps)) < 1e-5, (
+ f"Scheduler outputs are not identical for prediction_type: {prediction_type}, timestep_spacing: {timestep_spacing}"
+ )
def test_beta_sigmas(self):
self.check_over_configs(use_beta_sigmas=True)
diff --git a/tests/schedulers/test_scheduler_kdpm2_ancestral.py b/tests/schedulers/test_scheduler_kdpm2_ancestral.py
index fa85c2be45ed..135534db4536 100644
--- a/tests/schedulers/test_scheduler_kdpm2_ancestral.py
+++ b/tests/schedulers/test_scheduler_kdpm2_ancestral.py
@@ -1,8 +1,8 @@
import torch
from diffusers import KDPM2AncestralDiscreteScheduler
-from diffusers.utils.testing_utils import torch_device
+from ..testing_utils import torch_device
from .test_schedulers import SchedulerCommonTest
diff --git a/tests/schedulers/test_scheduler_kdpm2_discrete.py b/tests/schedulers/test_scheduler_kdpm2_discrete.py
index 4d8923b6946b..370ba2253ee2 100644
--- a/tests/schedulers/test_scheduler_kdpm2_discrete.py
+++ b/tests/schedulers/test_scheduler_kdpm2_discrete.py
@@ -1,8 +1,8 @@
import torch
from diffusers import KDPM2DiscreteScheduler
-from diffusers.utils.testing_utils import torch_device
+from ..testing_utils import torch_device
from .test_schedulers import SchedulerCommonTest
diff --git a/tests/schedulers/test_scheduler_lcm.py b/tests/schedulers/test_scheduler_lcm.py
index f3f6e9ba5837..f54970e0eba3 100644
--- a/tests/schedulers/test_scheduler_lcm.py
+++ b/tests/schedulers/test_scheduler_lcm.py
@@ -4,8 +4,8 @@
import torch
from diffusers import LCMScheduler
-from diffusers.utils.testing_utils import torch_device
+from ..testing_utils import torch_device
from .test_schedulers import SchedulerCommonTest
diff --git a/tests/schedulers/test_scheduler_lms.py b/tests/schedulers/test_scheduler_lms.py
index 3bfcd57c1b6d..c4abca3ac973 100644
--- a/tests/schedulers/test_scheduler_lms.py
+++ b/tests/schedulers/test_scheduler_lms.py
@@ -1,8 +1,8 @@
import torch
from diffusers import LMSDiscreteScheduler
-from diffusers.utils.testing_utils import torch_device
+from ..testing_utils import torch_device
from .test_schedulers import SchedulerCommonTest
diff --git a/tests/schedulers/test_scheduler_sasolver.py b/tests/schedulers/test_scheduler_sasolver.py
index baa2736b2fcc..2c2d2c0397bb 100644
--- a/tests/schedulers/test_scheduler_sasolver.py
+++ b/tests/schedulers/test_scheduler_sasolver.py
@@ -1,8 +1,8 @@
import torch
from diffusers import SASolverScheduler
-from diffusers.utils.testing_utils import require_torchsde, torch_device
+from ..testing_utils import require_torchsde, torch_device
from .test_schedulers import SchedulerCommonTest
diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py
index 42ca1bc54155..5a8380e659fc 100755
--- a/tests/schedulers/test_schedulers.py
+++ b/tests/schedulers/test_schedulers.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -41,9 +41,9 @@
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import logging
-from diffusers.utils.testing_utils import CaptureLogger, torch_device
from ..others.test_utils import TOKEN, USER, is_staging_test
+from ..testing_utils import CaptureLogger, torch_device
torch.backends.cuda.matmul.allow_tf32 = False
diff --git a/tests/single_file/single_file_testing_utils.py b/tests/single_file/single_file_testing_utils.py
index 4e7bc0af6842..52fd2f5bfc7f 100644
--- a/tests/single_file/single_file_testing_utils.py
+++ b/tests/single_file/single_file_testing_utils.py
@@ -1,3 +1,4 @@
+import gc
import tempfile
from io import BytesIO
@@ -7,8 +8,12 @@
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.models.attention_processor import AttnProcessor
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
+ backend_empty_cache,
+ nightly,
numpy_cosine_similarity_distance,
+ require_torch_accelerator,
torch_device,
)
@@ -46,6 +51,93 @@ def download_diffusers_config(repo_id, tmpdir):
return path
+@nightly
+@require_torch_accelerator
+class SingleFileModelTesterMixin:
+ def setup_method(self):
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def teardown_method(self):
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def test_single_file_model_config(self):
+ pretrained_kwargs = {}
+ single_file_kwargs = {}
+
+ if hasattr(self, "subfolder") and self.subfolder:
+ pretrained_kwargs["subfolder"] = self.subfolder
+
+ if hasattr(self, "torch_dtype") and self.torch_dtype:
+ pretrained_kwargs["torch_dtype"] = self.torch_dtype
+ single_file_kwargs["torch_dtype"] = self.torch_dtype
+
+ model = self.model_class.from_pretrained(self.repo_id, **pretrained_kwargs)
+ model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
+
+ PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
+ for param_name, param_value in model_single_file.config.items():
+ if param_name in PARAMS_TO_IGNORE:
+ continue
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
+
+ def test_single_file_model_parameters(self):
+ pretrained_kwargs = {}
+ single_file_kwargs = {}
+
+ if hasattr(self, "subfolder") and self.subfolder:
+ pretrained_kwargs["subfolder"] = self.subfolder
+
+ if hasattr(self, "torch_dtype") and self.torch_dtype:
+ pretrained_kwargs["torch_dtype"] = self.torch_dtype
+ single_file_kwargs["torch_dtype"] = self.torch_dtype
+
+ model = self.model_class.from_pretrained(self.repo_id, **pretrained_kwargs)
+ model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
+
+ state_dict = model.state_dict()
+ state_dict_single_file = model_single_file.state_dict()
+
+ assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
+ "Model parameters keys differ between pretrained and single file loading"
+ )
+
+ for key in state_dict.keys():
+ param = state_dict[key]
+ param_single_file = state_dict_single_file[key]
+
+ assert param.shape == param_single_file.shape, (
+ f"Parameter shape mismatch for {key}: "
+ f"pretrained {param.shape} vs single file {param_single_file.shape}"
+ )
+
+ assert torch.allclose(param, param_single_file, rtol=1e-5, atol=1e-5), (
+ f"Parameter values differ for {key}: "
+ f"max difference {torch.max(torch.abs(param - param_single_file)).item()}"
+ )
+
+ def test_checkpoint_altered_keys_loading(self):
+ # Test loading with checkpoints that have altered keys
+ if not hasattr(self, "alternate_keys_ckpt_paths") or not self.alternate_keys_ckpt_paths:
+ return
+
+ for ckpt_path in self.alternate_keys_ckpt_paths:
+ backend_empty_cache(torch_device)
+
+ single_file_kwargs = {}
+ if hasattr(self, "torch_dtype") and self.torch_dtype:
+ single_file_kwargs["torch_dtype"] = self.torch_dtype
+
+ model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)
+
+ del model
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+
class SDSingleFileTesterMixin:
single_file_kwargs = {}
@@ -72,9 +164,9 @@ def _compare_component_configs(self, pipe, single_file_pipe):
continue
assert component_name in pipe.components, f"single file {component_name} not found in pretrained pipeline"
- assert isinstance(
- component, pipe.components[component_name].__class__
- ), f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
+ assert isinstance(component, pipe.components[component_name].__class__), (
+ f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
+ )
for param_name, param_value in component.config.items():
if param_name in PARAMS_TO_IGNORE:
@@ -85,9 +177,9 @@ def _compare_component_configs(self, pipe, single_file_pipe):
if param_name == "upcast_attention" and pipe.components[component_name].config[param_name] is None:
pipe.components[component_name].config[param_name] = param_value
- assert (
- pipe.components[component_name].config[param_name] == param_value
- ), f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
+ assert pipe.components[component_name].config[param_name] == param_value, (
+ f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
+ )
def test_single_file_components(self, pipe=None, single_file_pipe=None):
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
@@ -253,9 +345,9 @@ def _compare_component_configs(self, pipe, single_file_pipe):
continue
assert component_name in pipe.components, f"single file {component_name} not found in pretrained pipeline"
- assert isinstance(
- component, pipe.components[component_name].__class__
- ), f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
+ assert isinstance(component, pipe.components[component_name].__class__), (
+ f"single file {component.__class__.__name__} and pretrained {pipe.components[component_name].__class__.__name__} are not the same"
+ )
for param_name, param_value in component.config.items():
if param_name in PARAMS_TO_IGNORE:
@@ -266,9 +358,9 @@ def _compare_component_configs(self, pipe, single_file_pipe):
if param_name == "upcast_attention" and pipe.components[component_name].config[param_name] is None:
pipe.components[component_name].config[param_name] = param_value
- assert (
- pipe.components[component_name].config[param_name] == param_value
- ), f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
+ assert pipe.components[component_name].config[param_name] == param_value, (
+ f"single file {param_name}: {param_value} differs from pretrained {pipe.components[component_name].config[param_name]}"
+ )
def test_single_file_components(self, pipe=None, single_file_pipe=None):
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
diff --git a/tests/single_file/test_lumina2_transformer.py b/tests/single_file/test_lumina2_transformer.py
index 78e68c4c2df0..bb5a0bf473b6 100644
--- a/tests/single_file/test_lumina2_transformer.py
+++ b/tests/single_file/test_lumina2_transformer.py
@@ -13,27 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import gc
-import unittest
-
-import torch
from diffusers import (
Lumina2Transformer2DModel,
)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
+
+from ..testing_utils import (
enable_full_determinism,
- require_torch_accelerator,
- torch_device,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@require_torch_accelerator
-class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):
+class TestLumina2Transformer2DModelSingleFile(SingleFileModelTesterMixin):
model_class = Lumina2Transformer2DModel
ckpt_path = "https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors"
alternate_keys_ckpt_paths = [
@@ -41,34 +35,4 @@ class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):
]
repo_id = "Alpha-VLLM/Lumina-Image-2.0"
-
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
- model_single_file = self.model_class.from_single_file(self.ckpt_path)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
-
- def test_checkpoint_loading(self):
- for ckpt_path in self.alternate_keys_ckpt_paths:
- torch.cuda.empty_cache()
- model = self.model_class.from_single_file(ckpt_path)
-
- del model
- gc.collect()
- torch.cuda.empty_cache()
+ subfolder = "transformer"
diff --git a/tests/single_file/test_model_autoencoder_dc_single_file.py b/tests/single_file/test_model_autoencoder_dc_single_file.py
index b1faeb78776b..444ca4046977 100644
--- a/tests/single_file/test_model_autoencoder_dc_single_file.py
+++ b/tests/single_file/test_model_autoencoder_dc_single_file.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,47 +13,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import gc
-import unittest
import torch
from diffusers import (
AutoencoderDC,
)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
+
+from ..testing_utils import (
enable_full_determinism,
load_hf_numpy,
numpy_cosine_similarity_distance,
- require_torch_accelerator,
- slow,
torch_device,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@slow
-@require_torch_accelerator
-class AutoencoderDCSingleFileTests(unittest.TestCase):
+class TestAutoencoderDCSingleFile(SingleFileModelTesterMixin):
model_class = AutoencoderDC
ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0/blob/main/model.safetensors"
repo_id = "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"
main_input_name = "sample"
base_precision = 1e-2
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
@@ -79,23 +64,11 @@ def test_single_file_inference_same_as_pretrained(self):
assert numpy_cosine_similarity_distance(output_slice_1, output_slice_2) < 1e-4
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id)
- model_single_file = self.model_class.from_single_file(self.ckpt_path)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
-
def test_single_file_in_type_variant_components(self):
# `in` variant checkpoints require passing in a `config` parameter
# in order to set the scaling factor correctly.
# `in` and `mix` variants have the same keys and we cannot automatically infer a scaling factor.
- # We default to using teh `mix` config
+ # We default to using the `mix` config
repo_id = "mit-han-lab/dc-ae-f128c512-in-1.0-diffusers"
ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0/blob/main/model.safetensors"
@@ -106,9 +79,9 @@ def test_single_file_in_type_variant_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
def test_single_file_mix_type_variant_components(self):
repo_id = "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"
@@ -121,6 +94,6 @@ def test_single_file_mix_type_variant_components(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
diff --git a/tests/single_file/test_model_controlnet_single_file.py b/tests/single_file/test_model_controlnet_single_file.py
index bfcb802380a6..2fa81fe3ae55 100644
--- a/tests/single_file/test_model_controlnet_single_file.py
+++ b/tests/single_file/test_model_controlnet_single_file.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,55 +13,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import gc
-import unittest
import torch
from diffusers import (
ControlNetModel,
)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
+
+from ..testing_utils import (
enable_full_determinism,
- require_torch_accelerator,
- slow,
- torch_device,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@slow
-@require_torch_accelerator
-class ControlNetModelSingleFileTests(unittest.TestCase):
+class TestControlNetModelSingleFile(SingleFileModelTesterMixin):
model_class = ControlNetModel
ckpt_path = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
repo_id = "lllyasviel/control_v11p_sd15_canny"
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id)
- model_single_file = self.model_class.from_single_file(self.ckpt_path)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
-
def test_single_file_arguments(self):
model_default = self.model_class.from_single_file(self.ckpt_path)
diff --git a/tests/single_file/test_model_flux_transformer_single_file.py b/tests/single_file/test_model_flux_transformer_single_file.py
index 0ec97db26a9e..0642a71c5756 100644
--- a/tests/single_file/test_model_flux_transformer_single_file.py
+++ b/tests/single_file/test_model_flux_transformer_single_file.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,59 +14,34 @@
# limitations under the License.
import gc
-import unittest
-
-import torch
from diffusers import (
FluxTransformer2DModel,
)
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
- require_torch_accelerator,
torch_device,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@require_torch_accelerator
-class FluxTransformer2DModelSingleFileTests(unittest.TestCase):
+class TestFluxTransformer2DModelSingleFile(SingleFileModelTesterMixin):
model_class = FluxTransformer2DModel
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
repo_id = "black-forest-labs/FLUX.1-dev"
+ subfolder = "transformer"
- def setUp(self):
- super().setUp()
- gc.collect()
+ def test_device_map_cuda(self):
backend_empty_cache(torch_device)
+ model = self.model_class.from_single_file(self.ckpt_path, device_map="cuda")
- def tearDown(self):
- super().tearDown()
+ del model
gc.collect()
backend_empty_cache(torch_device)
-
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
- model_single_file = self.model_class.from_single_file(self.ckpt_path)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
-
- def test_checkpoint_loading(self):
- for ckpt_path in self.alternate_keys_ckpt_paths:
- torch.cuda.empty_cache()
- model = self.model_class.from_single_file(ckpt_path)
-
- del model
- gc.collect()
- torch.cuda.empty_cache()
diff --git a/tests/single_file/test_model_motion_adapter_single_file.py b/tests/single_file/test_model_motion_adapter_single_file.py
index b195f25d094b..a047c81b47aa 100644
--- a/tests/single_file/test_model_motion_adapter_single_file.py
+++ b/tests/single_file/test_model_motion_adapter_single_file.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import unittest
from diffusers import (
MotionAdapter,
)
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
enable_full_determinism,
)
@@ -26,7 +26,7 @@
enable_full_determinism()
-class MotionAdapterSingleFileTests(unittest.TestCase):
+class MotionAdapterSingleFileTests:
model_class = MotionAdapter
def test_single_file_components_version_v1_5(self):
@@ -40,9 +40,9 @@ def test_single_file_components_version_v1_5(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
def test_single_file_components_version_v1_5_2(self):
ckpt_path = "https://huggingface.co/guoyww/animatediff/blob/main/mm_sd_v15_v2.ckpt"
@@ -55,9 +55,9 @@ def test_single_file_components_version_v1_5_2(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
def test_single_file_components_version_v1_5_3(self):
ckpt_path = "https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_mm.ckpt"
@@ -70,9 +70,9 @@ def test_single_file_components_version_v1_5_3(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
def test_single_file_components_version_sdxl_beta(self):
ckpt_path = "https://huggingface.co/guoyww/animatediff/blob/main/mm_sdxl_v10_beta.ckpt"
@@ -85,6 +85,6 @@ def test_single_file_components_version_sdxl_beta(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
diff --git a/tests/single_file/test_model_sd_cascade_unet_single_file.py b/tests/single_file/test_model_sd_cascade_unet_single_file.py
index 08b04e3cd7e8..7472122710eb 100644
--- a/tests/single_file/test_model_sd_cascade_unet_single_file.py
+++ b/tests/single_file/test_model_sd_cascade_unet_single_file.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,13 +14,13 @@
# limitations under the License.
import gc
-import unittest
import torch
from diffusers import StableCascadeUNet
from diffusers.utils import logging
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
@@ -36,14 +36,12 @@
@slow
@require_torch_accelerator
-class StableCascadeUNetSingleFileTest(unittest.TestCase):
- def setUp(self):
- super().setUp()
+class StableCascadeUNetSingleFileTest:
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@@ -60,9 +58,9 @@ def test_single_file_components_stage_b(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between single file loading and pretrained loading"
+ )
def test_single_file_components_stage_b_lite(self):
model_single_file = StableCascadeUNet.from_single_file(
@@ -77,9 +75,9 @@ def test_single_file_components_stage_b_lite(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between single file loading and pretrained loading"
+ )
def test_single_file_components_stage_c(self):
model_single_file = StableCascadeUNet.from_single_file(
@@ -94,9 +92,9 @@ def test_single_file_components_stage_c(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between single file loading and pretrained loading"
+ )
def test_single_file_components_stage_c_lite(self):
model_single_file = StableCascadeUNet.from_single_file(
@@ -111,6 +109,6 @@ def test_single_file_components_stage_c_lite(self):
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between single file loading and pretrained loading"
+ )
diff --git a/tests/single_file/test_model_vae_single_file.py b/tests/single_file/test_model_vae_single_file.py
index 9db4cddb3c9d..9198d9b16337 100644
--- a/tests/single_file/test_model_vae_single_file.py
+++ b/tests/single_file/test_model_vae_single_file.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 HuggingFace Inc.
+# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,31 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import gc
-import unittest
import torch
from diffusers import (
AutoencoderKL,
)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
+
+from ..testing_utils import (
enable_full_determinism,
load_hf_numpy,
numpy_cosine_similarity_distance,
- require_torch_accelerator,
- slow,
torch_device,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@slow
-@require_torch_accelerator
-class AutoencoderKLSingleFileTests(unittest.TestCase):
+class TestAutoencoderKLSingleFile(SingleFileModelTesterMixin):
model_class = AutoencoderKL
ckpt_path = (
"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
@@ -46,16 +41,6 @@ class AutoencoderKLSingleFileTests(unittest.TestCase):
main_input_name = "sample"
base_precision = 1e-2
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
@@ -83,18 +68,6 @@ def test_single_file_inference_same_as_pretrained(self):
assert numpy_cosine_similarity_distance(output_slice_1, output_slice_2) < 1e-4
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id)
- model_single_file = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between pretrained loading and single file loading"
-
def test_single_file_arguments(self):
model_default = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id)
diff --git a/tests/single_file/test_model_wan_autoencoder_single_file.py b/tests/single_file/test_model_wan_autoencoder_single_file.py
index f5720ddd3964..0babf302348f 100644
--- a/tests/single_file/test_model_wan_autoencoder_single_file.py
+++ b/tests/single_file/test_model_wan_autoencoder_single_file.py
@@ -13,49 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import gc
-import unittest
from diffusers import (
AutoencoderKLWan,
)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
+
+from ..testing_utils import (
enable_full_determinism,
- require_torch_accelerator,
- torch_device,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@require_torch_accelerator
-class AutoencoderKLWanSingleFileTests(unittest.TestCase):
+class TestAutoencoderKLWanSingleFile(SingleFileModelTesterMixin):
model_class = AutoencoderKLWan
ckpt_path = (
"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors"
)
repo_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
-
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id, subfolder="vae")
- model_single_file = self.model_class.from_single_file(self.ckpt_path)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ subfolder = "vae"
diff --git a/tests/single_file/test_model_wan_transformer3d_single_file.py b/tests/single_file/test_model_wan_transformer3d_single_file.py
index 9b938aa1754c..b76909206073 100644
--- a/tests/single_file/test_model_wan_transformer3d_single_file.py
+++ b/tests/single_file/test_model_wan_transformer3d_single_file.py
@@ -13,81 +13,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import gc
-import unittest
import torch
from diffusers import (
WanTransformer3DModel,
)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
+
+from ..testing_utils import (
enable_full_determinism,
- require_big_gpu_with_torch_cuda,
- require_torch_accelerator,
- torch_device,
+ require_big_accelerator,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@require_torch_accelerator
-class WanTransformer3DModelText2VideoSingleFileTest(unittest.TestCase):
+class TestWanTransformer3DModelText2VideoSingleFile(SingleFileModelTesterMixin):
model_class = WanTransformer3DModel
ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors"
repo_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
-
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
- model_single_file = self.model_class.from_single_file(self.ckpt_path)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ subfolder = "transformer"
-@require_big_gpu_with_torch_cuda
-@require_torch_accelerator
-class WanTransformer3DModelImage2VideoSingleFileTest(unittest.TestCase):
+@require_big_accelerator
+class TestWanTransformer3DModelImage2VideoSingleFile(SingleFileModelTesterMixin):
model_class = WanTransformer3DModel
ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_i2v_480p_14B_fp8_e4m3fn.safetensors"
repo_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
torch_dtype = torch.float8_e4m3fn
-
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer", torch_dtype=self.torch_dtype)
- model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=self.torch_dtype)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
+ subfolder = "transformer"
diff --git a/tests/single_file/test_sana_transformer.py b/tests/single_file/test_sana_transformer.py
index 7695e1577711..9e2adb93bf2b 100644
--- a/tests/single_file/test_sana_transformer.py
+++ b/tests/single_file/test_sana_transformer.py
@@ -1,24 +1,17 @@
-import gc
-import unittest
-
-import torch
-
from diffusers import (
SanaTransformer2DModel,
)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
+
+from ..testing_utils import (
enable_full_determinism,
- require_torch_accelerator,
- torch_device,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@require_torch_accelerator
-class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
+class TestSanaTransformer2DModelSingleFile(SingleFileModelTesterMixin):
model_class = SanaTransformer2DModel
ckpt_path = (
"https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
@@ -28,34 +21,4 @@ class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
]
repo_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers"
-
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
- model_single_file = self.model_class.from_single_file(self.ckpt_path)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert (
- model.config[param_name] == param_value
- ), f"{param_name} differs between single file loading and pretrained loading"
-
- def test_checkpoint_loading(self):
- for ckpt_path in self.alternate_keys_ckpt_paths:
- torch.cuda.empty_cache()
- model = self.model_class.from_single_file(ckpt_path)
-
- del model
- gc.collect()
- torch.cuda.empty_cache()
+ subfolder = "transformer"
diff --git a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py
index 7589b48028c2..141748b084a0 100644
--- a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py
+++ b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py
@@ -1,13 +1,13 @@
import gc
import tempfile
-import unittest
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -15,7 +15,6 @@
slow,
torch_device,
)
-
from .single_file_testing_utils import (
SDSingleFileTesterMixin,
download_diffusers_config,
@@ -29,7 +28,7 @@
@slow
@require_torch_accelerator
-class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionControlNetPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetPipeline
ckpt_path = (
"https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
@@ -39,13 +38,11 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
)
repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py
index 1555831db6db..8238866cbfb3 100644
--- a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py
+++ b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py
@@ -1,13 +1,14 @@
import gc
import tempfile
-import unittest
+import pytest
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -15,7 +16,6 @@
slow,
torch_device,
)
-
from .single_file_testing_utils import (
SDSingleFileTesterMixin,
download_diffusers_config,
@@ -29,19 +29,17 @@
@slow
@require_torch_accelerator
-class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionControlNetInpaintPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetInpaintPipeline
ckpt_path = "https://huggingface.co/botp/stable-diffusion-v1-5-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
original_config = "https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml"
repo_id = "stable-diffusion-v1-5/stable-diffusion-inpainting"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@@ -115,7 +113,7 @@ def test_single_file_components_local_files_only(self):
super()._compare_component_configs(pipe, pipe_single_file)
- @unittest.skip("runwayml original config repo does not exist")
+ @pytest.mark.skip(reason="runwayml original config repo does not exist")
def test_single_file_components_with_original_config(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16")
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
@@ -125,7 +123,7 @@ def test_single_file_components_with_original_config(self):
super()._compare_component_configs(pipe, pipe_single_file)
- @unittest.skip("runwayml original config repo does not exist")
+ @pytest.mark.skip(reason="runwayml original config repo does not exist")
def test_single_file_components_with_original_config_local_files_only(self):
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
diff --git a/tests/single_file/test_stable_diffusion_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_single_file.py
index 2c1e414e5e36..80ef6c2574c2 100644
--- a/tests/single_file/test_stable_diffusion_controlnet_single_file.py
+++ b/tests/single_file/test_stable_diffusion_controlnet_single_file.py
@@ -1,13 +1,13 @@
import gc
import tempfile
-import unittest
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -15,7 +15,6 @@
slow,
torch_device,
)
-
from .single_file_testing_utils import (
SDSingleFileTesterMixin,
download_diffusers_config,
@@ -29,7 +28,7 @@
@slow
@require_torch_accelerator
-class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionControlNetPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetPipeline
ckpt_path = (
"https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
@@ -39,13 +38,11 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
)
repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_img2img_single_file.py b/tests/single_file/test_stable_diffusion_img2img_single_file.py
index 9ad935582409..e76846c800a8 100644
--- a/tests/single_file/test_stable_diffusion_img2img_single_file.py
+++ b/tests/single_file/test_stable_diffusion_img2img_single_file.py
@@ -1,5 +1,4 @@
import gc
-import unittest
import torch
@@ -7,14 +6,14 @@
StableDiffusionImg2ImgPipeline,
)
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
-
from .single_file_testing_utils import SDSingleFileTesterMixin
@@ -23,7 +22,7 @@
@slow
@require_torch_accelerator
-class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionImg2ImgPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionImg2ImgPipeline
ckpt_path = (
"https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
@@ -33,13 +32,11 @@ class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSin
)
repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@@ -66,19 +63,17 @@ def test_single_file_format_inference_is_same_as_pretrained(self):
@slow
@require_torch_accelerator
-class StableDiffusion21Img2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusion21Img2ImgPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionImg2ImgPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"
original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
repo_id = "stabilityai/stable-diffusion-2-1"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_inpaint_single_file.py b/tests/single_file/test_stable_diffusion_inpaint_single_file.py
index b05a098c0bcb..6e5d27cdffef 100644
--- a/tests/single_file/test_stable_diffusion_inpaint_single_file.py
+++ b/tests/single_file/test_stable_diffusion_inpaint_single_file.py
@@ -1,20 +1,20 @@
import gc
-import unittest
+import pytest
import torch
from diffusers import (
StableDiffusionInpaintPipeline,
)
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
-
from .single_file_testing_utils import SDSingleFileTesterMixin
@@ -23,19 +23,17 @@
@slow
@require_torch_accelerator
-class StableDiffusionInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionInpaintPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionInpaintPipeline
ckpt_path = "https://huggingface.co/botp/stable-diffusion-v1-5-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
original_config = "https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml"
repo_id = "botp/stable-diffusion-v1-5-inpainting"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@@ -70,18 +68,18 @@ def test_single_file_loading_4_channel_unet(self):
assert pipe.unet.config.in_channels == 4
- @unittest.skip("runwayml original config has been removed")
+ @pytest.mark.skip(reason="runwayml original config has been removed")
def test_single_file_components_with_original_config(self):
return
- @unittest.skip("runwayml original config has been removed")
+ @pytest.mark.skip(reason="runwayml original config has been removed")
def test_single_file_components_with_original_config_local_files_only(self):
return
@slow
@require_torch_accelerator
-class StableDiffusion21InpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusion21InpaintPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionInpaintPipeline
ckpt_path = (
"https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/blob/main/512-inpainting-ema.safetensors"
@@ -89,13 +87,11 @@ class StableDiffusion21InpaintPipelineSingleFileSlowTests(unittest.TestCase, SDS
original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inpainting-inference.yaml"
repo_id = "stabilityai/stable-diffusion-2-inpainting"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_single_file.py b/tests/single_file/test_stable_diffusion_single_file.py
index 78baeb94929c..377dedbc5731 100644
--- a/tests/single_file/test_stable_diffusion_single_file.py
+++ b/tests/single_file/test_stable_diffusion_single_file.py
@@ -1,13 +1,13 @@
import gc
import tempfile
-import unittest
import torch
from diffusers import EulerDiscreteScheduler, StableDiffusionInstructPix2PixPipeline, StableDiffusionPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
nightly,
@@ -15,7 +15,6 @@
slow,
torch_device,
)
-
from .single_file_testing_utils import (
SDSingleFileTesterMixin,
download_original_config,
@@ -28,7 +27,7 @@
@slow
@require_torch_accelerator
-class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionPipeline
ckpt_path = (
"https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
@@ -38,13 +37,11 @@ class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFile
)
repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@@ -90,19 +87,17 @@ def test_single_file_legacy_scaling_factor(self):
@slow
-class StableDiffusion21PipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusion21PipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"
original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
repo_id = "stabilityai/stable-diffusion-2-1"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@@ -125,7 +120,7 @@ def test_single_file_format_inference_is_same_as_pretrained(self):
@nightly
@slow
@require_torch_accelerator
-class StableDiffusionInstructPix2PixPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionInstructPix2PixPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionInstructPix2PixPipeline
ckpt_path = "https://huggingface.co/timbrooks/instruct-pix2pix/blob/main/instruct-pix2pix-00-22000.safetensors"
original_config = (
@@ -134,13 +129,11 @@ class StableDiffusionInstructPix2PixPipelineSingleFileSlowTests(unittest.TestCas
repo_id = "timbrooks/instruct-pix2pix"
single_file_kwargs = {"extract_ema": True}
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_upscale_single_file.py b/tests/single_file/test_stable_diffusion_upscale_single_file.py
index 398fc9ece359..ba4819fadf85 100644
--- a/tests/single_file/test_stable_diffusion_upscale_single_file.py
+++ b/tests/single_file/test_stable_diffusion_upscale_single_file.py
@@ -1,5 +1,4 @@
import gc
-import unittest
import pytest
import torch
@@ -8,7 +7,8 @@
StableDiffusionUpscalePipeline,
)
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -16,7 +16,6 @@
slow,
torch_device,
)
-
from .single_file_testing_utils import SDSingleFileTesterMixin
@@ -25,19 +24,17 @@
@slow
@require_torch_accelerator
-class StableDiffusionUpscalePipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionUpscalePipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionUpscalePipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/blob/main/x4-upscaler-ema.safetensors"
original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
repo_id = "stabilityai/stable-diffusion-x4-upscaler"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py
index fb5f8725b86e..3d124fa8c23c 100644
--- a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py
+++ b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py
@@ -1,6 +1,5 @@
import gc
import tempfile
-import unittest
import torch
@@ -10,7 +9,8 @@
)
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -18,7 +18,6 @@
slow,
torch_device,
)
-
from .single_file_testing_utils import (
SDXLSingleFileTesterMixin,
download_diffusers_config,
@@ -32,7 +31,7 @@
@slow
@require_torch_accelerator
-class StableDiffusionXLAdapterPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
+class TestStableDiffusionXLAdapterPipelineSingleFileSlow(SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLAdapterPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -40,13 +39,11 @@ class StableDiffusionXLAdapterPipelineSingleFileSlowTests(unittest.TestCase, SDX
"https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
)
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py
index 6d8c4369e1e1..6f503702610a 100644
--- a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py
+++ b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py
@@ -1,13 +1,13 @@
import gc
import tempfile
-import unittest
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -15,7 +15,6 @@
slow,
torch_device,
)
-
from .single_file_testing_utils import (
SDXLSingleFileTesterMixin,
download_diffusers_config,
@@ -28,7 +27,7 @@
@slow
@require_torch_accelerator
-class StableDiffusionXLControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
+class TestStableDiffusionXLControlNetPipelineSingleFileSlow(SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLControlNetPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -36,13 +35,11 @@ class StableDiffusionXLControlNetPipelineSingleFileSlowTests(unittest.TestCase,
"https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
)
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_xl_img2img_single_file.py b/tests/single_file/test_stable_diffusion_xl_img2img_single_file.py
index 7df8b84bc235..56657f37d912 100644
--- a/tests/single_file/test_stable_diffusion_xl_img2img_single_file.py
+++ b/tests/single_file/test_stable_diffusion_xl_img2img_single_file.py
@@ -1,5 +1,4 @@
import gc
-import unittest
import torch
@@ -8,7 +7,8 @@
StableDiffusionXLImg2ImgPipeline,
)
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -16,7 +16,6 @@
slow,
torch_device,
)
-
from .single_file_testing_utils import SDXLSingleFileTesterMixin
@@ -25,7 +24,7 @@
@slow
@require_torch_accelerator
-class StableDiffusionXLImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
+class TestStableDiffusionXLImg2ImgPipelineSingleFileSlow(SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLImg2ImgPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -33,13 +32,11 @@ class StableDiffusionXLImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDX
"https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
)
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@@ -66,7 +63,7 @@ def test_single_file_format_inference_is_same_as_pretrained(self):
@slow
@require_torch_accelerator
-class StableDiffusionXLImg2ImgRefinerPipelineSingleFileSlowTests(unittest.TestCase):
+class StableDiffusionXLImg2ImgRefinerPipelineSingleFileSlowTests:
pipeline_class = StableDiffusionXLImg2ImgPipeline
ckpt_path = (
"https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/main/sd_xl_refiner_1.0.safetensors"
diff --git a/tests/single_file/test_stable_diffusion_xl_instruct_pix2pix.py b/tests/single_file/test_stable_diffusion_xl_instruct_pix2pix.py
index 5a014638633b..d755b7010516 100644
--- a/tests/single_file/test_stable_diffusion_xl_instruct_pix2pix.py
+++ b/tests/single_file/test_stable_diffusion_xl_instruct_pix2pix.py
@@ -1,10 +1,10 @@
import gc
-import unittest
import torch
from diffusers import StableDiffusionXLInstructPix2PixPipeline
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
@@ -18,19 +18,17 @@
@slow
@require_torch_accelerator
-class StableDiffusionXLInstructPix2PixPipeline(unittest.TestCase):
+class StableDiffusionXLInstructPix2PixPipeline:
pipeline_class = StableDiffusionXLInstructPix2PixPipeline
ckpt_path = "https://huggingface.co/stabilityai/cosxl/blob/main/cosxl_edit.safetensors"
original_config = None
repo_id = "diffusers/sdxl-instructpix2pix-768"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_xl_single_file.py b/tests/single_file/test_stable_diffusion_xl_single_file.py
index 77f58d859209..4e5319ca25c7 100644
--- a/tests/single_file/test_stable_diffusion_xl_single_file.py
+++ b/tests/single_file/test_stable_diffusion_xl_single_file.py
@@ -1,19 +1,18 @@
import gc
-import unittest
import torch
from diffusers import (
StableDiffusionXLPipeline,
)
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
-
from .single_file_testing_utils import SDXLSingleFileTesterMixin
@@ -22,7 +21,7 @@
@slow
@require_torch_accelerator
-class StableDiffusionXLPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
+class TestStableDiffusionXLPipelineSingleFileSlow(SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -30,13 +29,11 @@ class StableDiffusionXLPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingle
"https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
)
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/testing_utils.py b/tests/testing_utils.py
new file mode 100644
index 000000000000..4550813259af
--- /dev/null
+++ b/tests/testing_utils.py
@@ -0,0 +1,1576 @@
+import functools
+import glob
+import importlib
+import importlib.metadata
+import inspect
+import io
+import logging
+import multiprocessing
+import os
+import random
+import re
+import struct
+import sys
+import tempfile
+import time
+import urllib.parse
+from collections import UserDict
+from contextlib import contextmanager
+from io import BytesIO, StringIO
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import PIL.ImageOps
+import pytest
+import requests
+from numpy.linalg import norm
+from packaging import version
+
+from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
+from diffusers.utils.import_utils import (
+ BACKENDS_MAPPING,
+ is_accelerate_available,
+ is_bitsandbytes_available,
+ is_compel_available,
+ is_flax_available,
+ is_gguf_available,
+ is_kernels_available,
+ is_note_seq_available,
+ is_onnx_available,
+ is_opencv_available,
+ is_optimum_quanto_available,
+ is_peft_available,
+ is_timm_available,
+ is_torch_available,
+ is_torch_version,
+ is_torchao_available,
+ is_torchsde_available,
+ is_transformers_available,
+)
+from diffusers.utils.logging import get_logger
+
+
+if is_torch_available():
+ import torch
+
+ IS_ROCM_SYSTEM = torch.version.hip is not None
+ IS_CUDA_SYSTEM = torch.version.cuda is not None
+ IS_XPU_SYSTEM = getattr(torch.version, "xpu", None) is not None
+else:
+ IS_ROCM_SYSTEM = False
+ IS_CUDA_SYSTEM = False
+ IS_XPU_SYSTEM = False
+
+IS_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" and os.getenv("DIFFUSERS_IS_CI") == "yes"
+
+global_rng = random.Random()
+
+logger = get_logger(__name__)
+
+_required_peft_version = is_peft_available() and version.parse(
+ version.parse(importlib.metadata.version("peft")).base_version
+) > version.parse("0.5")
+_required_transformers_version = is_transformers_available() and version.parse(
+ version.parse(importlib.metadata.version("transformers")).base_version
+) > version.parse("4.33")
+
+USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
+BIG_GPU_MEMORY = int(os.getenv("BIG_GPU_MEMORY", 40))
+
+if is_torch_available():
+ import torch
+
+ # Set a backend environment variable for any extra module import required for a custom accelerator
+ if "DIFFUSERS_TEST_BACKEND" in os.environ:
+ backend = os.environ["DIFFUSERS_TEST_BACKEND"]
+ try:
+ _ = importlib.import_module(backend)
+ except ModuleNotFoundError as e:
+ raise ModuleNotFoundError(
+ f"Failed to import `DIFFUSERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module \
+ to enable a specified backend.):\n{e}"
+ ) from e
+
+ if "DIFFUSERS_TEST_DEVICE" in os.environ:
+ torch_device = os.environ["DIFFUSERS_TEST_DEVICE"]
+ try:
+ # try creating device to see if provided device is valid
+ _ = torch.device(torch_device)
+ except RuntimeError as e:
+ raise RuntimeError(
+ f"Unknown testing device specified by environment variable `DIFFUSERS_TEST_DEVICE`: {torch_device}"
+ ) from e
+ logger.info(f"torch_device overrode to {torch_device}")
+ else:
+ if torch.cuda.is_available():
+ torch_device = "cuda"
+ elif torch.xpu.is_available():
+ torch_device = "xpu"
+ else:
+ torch_device = "cpu"
+ is_torch_higher_equal_than_1_12 = version.parse(
+ version.parse(torch.__version__).base_version
+ ) >= version.parse("1.12")
+
+ if is_torch_higher_equal_than_1_12:
+ # Some builds of torch 1.12 don't have the mps backend registered. See #892 for more details
+ mps_backend_registered = hasattr(torch.backends, "mps")
+ torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device
+
+ from diffusers.utils.torch_utils import get_torch_cuda_device_capability
+
+
+def torch_all_close(a, b, *args, **kwargs):
+ if not is_torch_available():
+ raise ValueError("PyTorch needs to be installed to use this function.")
+ if not torch.allclose(a, b, *args, **kwargs):
+ assert False, f"Max diff is absolute {(a - b).abs().max()}. Diff tensor is {(a - b).abs()}."
+ return True
+
+
+def numpy_cosine_similarity_distance(a, b):
+ similarity = np.dot(a, b) / (norm(a) * norm(b))
+ distance = 1.0 - similarity.mean()
+
+ return distance
+
+
+def check_if_dicts_are_equal(dict1, dict2):
+ dict1, dict2 = dict1.copy(), dict2.copy()
+
+ for key, value in dict1.items():
+ if isinstance(value, set):
+ dict1[key] = sorted(value)
+ for key, value in dict2.items():
+ if isinstance(value, set):
+ dict2[key] = sorted(value)
+
+ for key in dict1:
+ if key not in dict2:
+ return False
+ if dict1[key] != dict2[key]:
+ return False
+
+ for key in dict2:
+ if key not in dict1:
+ return False
+
+ return True
+
+
+def print_tensor_test(
+ tensor,
+ limit_to_slices=None,
+ max_torch_print=None,
+ filename="test_corrections.txt",
+ expected_tensor_name="expected_slice",
+):
+ if max_torch_print:
+ torch.set_printoptions(threshold=10_000)
+
+ test_name = os.environ.get("PYTEST_CURRENT_TEST")
+ if not torch.is_tensor(tensor):
+ tensor = torch.from_numpy(tensor)
+ if limit_to_slices:
+ tensor = tensor[0, -3:, -3:, -1]
+
+ tensor_str = str(tensor.detach().cpu().flatten().to(torch.float32)).replace("\n", "")
+ # format is usually:
+ # expected_slice = np.array([-0.5713, -0.3018, -0.9814, 0.04663, -0.879, 0.76, -1.734, 0.1044, 1.161])
+ output_str = tensor_str.replace("tensor", f"{expected_tensor_name} = np.array")
+ test_file, test_class, test_fn = test_name.split("::")
+ test_fn = test_fn.split()[0]
+ with open(filename, "a") as f:
+ print("::".join([test_file, test_class, test_fn, output_str]), file=f)
+
+
+def get_tests_dir(append_path=None):
+ """
+ Args:
+ append_path: optional path to append to the tests dir path
+ Return:
+ The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is
+ joined after the `tests` dir the former is provided.
+ """
+ # this function caller's __file__
+ caller__file__ = inspect.stack()[1][1]
+ tests_dir = os.path.abspath(os.path.dirname(caller__file__))
+
+ while not tests_dir.endswith("tests"):
+ tests_dir = os.path.dirname(tests_dir)
+
+ if append_path:
+ return Path(tests_dir, append_path).as_posix()
+ else:
+ return tests_dir
+
+
+# Taken from the following PR:
+# https://github.com/huggingface/accelerate/pull/1964
+def str_to_bool(value) -> int:
+ """
+ Converts a string representation of truth to `True` (1) or `False` (0). True values are `y`, `yes`, `t`, `true`,
+ `on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`;
+ """
+ value = value.lower()
+ if value in ("y", "yes", "t", "true", "on", "1"):
+ return 1
+ elif value in ("n", "no", "f", "false", "off", "0"):
+ return 0
+ else:
+ raise ValueError(f"invalid truth value {value}")
+
+
+def parse_flag_from_env(key, default=False):
+ try:
+ value = os.environ[key]
+ except KeyError:
+ # KEY isn't set, default to `default`.
+ _value = default
+ else:
+ # KEY is set, convert it to True or False.
+ try:
+ _value = str_to_bool(value)
+ except ValueError:
+ # More values are supported, but let's keep the message simple.
+ raise ValueError(f"If set, {key} must be yes or no.")
+ return _value
+
+
+_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
+_run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False)
+_run_compile_tests = parse_flag_from_env("RUN_COMPILE", default=False)
+
+
+def floats_tensor(shape, scale=1.0, rng=None, name=None):
+ """Creates a random float32 tensor"""
+ if rng is None:
+ rng = global_rng
+
+ total_dims = 1
+ for dim in shape:
+ total_dims *= dim
+
+ values = []
+ for _ in range(total_dims):
+ values.append(rng.random() * scale)
+
+ return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()
+
+
+def slow(test_case):
+ """
+ Decorator marking a test as slow.
+
+ Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
+
+ """
+ return pytest.mark.skipif(not _run_slow_tests, reason="test is slow")(test_case)
+
+
+def nightly(test_case):
+ """
+ Decorator marking a test that runs nightly in the diffusers CI.
+
+ Slow tests are skipped by default. Set the RUN_NIGHTLY environment variable to a truthy value to run them.
+
+ """
+ return pytest.mark.skipif(not _run_nightly_tests, reason="test is nightly")(test_case)
+
+
+def is_torch_compile(test_case):
+ """
+ Decorator marking a test that runs compile tests in the diffusers CI.
+
+ Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them.
+
+ """
+ return pytest.mark.skipif(not _run_compile_tests, reason="test is torch compile")(test_case)
+
+
+def require_torch(test_case):
+ """
+ Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed.
+ """
+ return pytest.mark.skipif(not is_torch_available(), reason="test requires PyTorch")(test_case)
+
+
+def require_torch_2(test_case):
+ """
+ Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed.
+ """
+ return pytest.mark.skipif(
+ not (is_torch_available() and is_torch_version(">=", "2.0.0")), reason="test requires PyTorch 2"
+ )(test_case)
+
+
+def require_torch_version_greater_equal(torch_version):
+ """Decorator marking a test that requires torch with a specific version or greater."""
+
+ def decorator(test_case):
+ correct_torch_version = is_torch_available() and is_torch_version(">=", torch_version)
+ return pytest.mark.skipif(
+ not correct_torch_version,
+ reason=f"test requires torch with the version greater than or equal to {torch_version}",
+ )(test_case)
+
+ return decorator
+
+
+def require_torch_version_greater(torch_version):
+ """Decorator marking a test that requires torch with a specific version greater."""
+
+ def decorator(test_case):
+ correct_torch_version = is_torch_available() and is_torch_version(">", torch_version)
+ return pytest.mark.skipif(
+ not correct_torch_version, reason=f"test requires torch with the version greater than {torch_version}"
+ )(test_case)
+
+ return decorator
+
+
+def require_torch_gpu(test_case):
+ """Decorator marking a test that requires CUDA and PyTorch."""
+ return pytest.mark.skipif(torch_device != "cuda", reason="test requires PyTorch+CUDA")(test_case)
+
+
+def require_torch_cuda_compatibility(expected_compute_capability):
+ def decorator(test_case):
+ if torch.cuda.is_available():
+ current_compute_capability = get_torch_cuda_device_capability()
+ return pytest.mark.skipif(
+ float(current_compute_capability) != float(expected_compute_capability),
+ reason="Test not supported for this compute capability.",
+ )(test_case)
+ return test_case
+
+ return decorator
+
+
+# These decorators are for accelerator-specific behaviours that are not GPU-specific
+def require_torch_accelerator(test_case):
+ """Decorator marking a test that requires an accelerator backend and PyTorch."""
+ return pytest.mark.skipif(torch_device == "cpu", reason="test requires accelerator+PyTorch")(test_case)
+
+
+def require_torch_multi_gpu(test_case):
+ """
+ Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without
+ multiple GPUs. To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests
+ -k "multi_gpu"
+ """
+ if not is_torch_available():
+ return pytest.mark.skip(reason="test requires PyTorch")(test_case)
+
+ import torch
+
+ return pytest.mark.skipif(torch.cuda.device_count() <= 1, reason="test requires multiple GPUs")(test_case)
+
+
+def require_torch_multi_accelerator(test_case):
+ """
+ Decorator marking a test that requires a multi-accelerator setup (in PyTorch). These tests are skipped on a machine
+ without multiple hardware accelerators.
+ """
+ if not is_torch_available():
+ return pytest.mark.skip(reason="test requires PyTorch")(test_case)
+
+ import torch
+
+ return pytest.mark.skipif(
+ not (torch.cuda.device_count() > 1 or torch.xpu.device_count() > 1),
+ reason="test requires multiple hardware accelerators",
+ )(test_case)
+
+
+def require_torch_accelerator_with_fp16(test_case):
+ """Decorator marking a test that requires an accelerator with support for the FP16 data type."""
+ return pytest.mark.skipif(
+ not _is_torch_fp16_available(torch_device), reason="test requires accelerator with fp16 support"
+ )(test_case)
+
+
+def require_torch_accelerator_with_fp64(test_case):
+ """Decorator marking a test that requires an accelerator with support for the FP64 data type."""
+ return pytest.mark.skipif(
+ not _is_torch_fp64_available(torch_device), reason="test requires accelerator with fp64 support"
+ )(test_case)
+
+
+def require_big_gpu_with_torch_cuda(test_case):
+ """
+ Decorator marking a test that requires a bigger GPU (24GB) for execution. Some example pipelines: Flux, SD3, Cog,
+ etc.
+ """
+ if not is_torch_available():
+ return pytest.mark.skip(reason="test requires PyTorch")(test_case)
+
+ import torch
+
+ if not torch.cuda.is_available():
+ return pytest.mark.skip(reason="test requires PyTorch CUDA")(test_case)
+
+ device_properties = torch.cuda.get_device_properties(0)
+ total_memory = device_properties.total_memory / (1024**3)
+ return pytest.mark.skipif(
+ total_memory < BIG_GPU_MEMORY, reason=f"test requires a GPU with at least {BIG_GPU_MEMORY} GB memory"
+ )(test_case)
+
+
+def require_big_accelerator(test_case):
+ """
+ Decorator marking a test that requires a bigger hardware accelerator (24GB) for execution. Some example pipelines:
+ Flux, SD3, Cog, etc.
+ """
+ import pytest
+
+ test_case = pytest.mark.big_accelerator(test_case)
+
+ if not is_torch_available():
+ return pytest.mark.skip(reason="test requires PyTorch")(test_case)
+
+ import torch
+
+ if not (torch.cuda.is_available() or torch.xpu.is_available()):
+ return pytest.mark.skip(reason="test requires PyTorch CUDA")(test_case)
+
+ if torch.xpu.is_available():
+ device_properties = torch.xpu.get_device_properties(0)
+ else:
+ device_properties = torch.cuda.get_device_properties(0)
+
+ total_memory = device_properties.total_memory / (1024**3)
+ return pytest.mark.skipif(
+ total_memory < BIG_GPU_MEMORY,
+ reason=f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory",
+ )(test_case)
+
+
+def require_torch_accelerator_with_training(test_case):
+ """Decorator marking a test that requires an accelerator with support for training."""
+ return pytest.mark.skipif(
+ not (is_torch_available() and backend_supports_training(torch_device)),
+ reason="test requires accelerator with training support",
+ )(test_case)
+
+
+def skip_mps(test_case):
+ """Decorator marking a test to skip if torch_device is 'mps'"""
+ return pytest.mark.skipif(torch_device == "mps", reason="test requires non 'mps' device")(test_case)
+
+
+def require_flax(test_case):
+ """
+ Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
+ """
+ return pytest.mark.skipif(not is_flax_available(), reason="test requires JAX & Flax")(test_case)
+
+
+def require_compel(test_case):
+ """
+ Decorator marking a test that requires compel: https://github.com/damian0815/compel. These tests are skipped when
+ the library is not installed.
+ """
+ return pytest.mark.skipif(not is_compel_available(), reason="test requires compel")(test_case)
+
+
+def require_onnxruntime(test_case):
+ """
+ Decorator marking a test that requires onnxruntime. These tests are skipped when onnxruntime isn't installed.
+ """
+ return pytest.mark.skipif(not is_onnx_available(), reason="test requires onnxruntime")(test_case)
+
+
+def require_note_seq(test_case):
+ """
+ Decorator marking a test that requires note_seq. These tests are skipped when note_seq isn't installed.
+ """
+ return pytest.mark.skipif(not is_note_seq_available(), reason="test requires note_seq")(test_case)
+
+
+def require_accelerator(test_case):
+ """
+ Decorator marking a test that requires a hardware accelerator backend. These tests are skipped when there are no
+ hardware accelerator available.
+ """
+ return pytest.mark.skipif(torch_device == "cpu", reason="test requires a hardware accelerator")(test_case)
+
+
+def require_torchsde(test_case):
+ """
+ Decorator marking a test that requires torchsde. These tests are skipped when torchsde isn't installed.
+ """
+ return pytest.mark.skipif(not is_torchsde_available(), reason="test requires torchsde")(test_case)
+
+
+def require_peft_backend(test_case):
+ """
+ Decorator marking a test that requires PEFT backend, this would require some specific versions of PEFT and
+ transformers.
+ """
+ return pytest.mark.skipif(not USE_PEFT_BACKEND, reason="test requires PEFT backend")(test_case)
+
+
+def require_timm(test_case):
+ """
+ Decorator marking a test that requires timm. These tests are skipped when timm isn't installed.
+ """
+ return pytest.mark.skipif(not is_timm_available(), reason="test requires timm")(test_case)
+
+
+def require_bitsandbytes(test_case):
+ """
+ Decorator marking a test that requires bitsandbytes. These tests are skipped when bitsandbytes isn't installed.
+ """
+ return pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes")(test_case)
+
+
+def require_quanto(test_case):
+ """
+ Decorator marking a test that requires quanto. These tests are skipped when quanto isn't installed.
+ """
+ return pytest.mark.skipif(not is_optimum_quanto_available(), reason="test requires quanto")(test_case)
+
+
+def require_accelerate(test_case):
+ """
+ Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
+ """
+ return pytest.mark.skipif(not is_accelerate_available(), reason="test requires accelerate")(test_case)
+
+
+def require_peft_version_greater(peft_version):
+ """
+ Decorator marking a test that requires PEFT backend with a specific version, this would require some specific
+ versions of PEFT and transformers.
+ """
+
+ def decorator(test_case):
+ correct_peft_version = is_peft_available() and version.parse(
+ version.parse(importlib.metadata.version("peft")).base_version
+ ) > version.parse(peft_version)
+ return pytest.mark.skipif(
+ not correct_peft_version, reason=f"test requires PEFT backend with the version greater than {peft_version}"
+ )(test_case)
+
+ return decorator
+
+
+def require_transformers_version_greater(transformers_version):
+ """
+ Decorator marking a test that requires transformers with a specific version, this would require some specific
+ versions of PEFT and transformers.
+ """
+
+ def decorator(test_case):
+ correct_transformers_version = is_transformers_available() and version.parse(
+ version.parse(importlib.metadata.version("transformers")).base_version
+ ) > version.parse(transformers_version)
+ return pytest.mark.skipif(
+ not correct_transformers_version,
+ reason=f"test requires transformers with the version greater than {transformers_version}",
+ )(test_case)
+
+ return decorator
+
+
+def require_accelerate_version_greater(accelerate_version):
+ def decorator(test_case):
+ correct_accelerate_version = is_accelerate_available() and version.parse(
+ version.parse(importlib.metadata.version("accelerate")).base_version
+ ) > version.parse(accelerate_version)
+ return pytest.mark.skipif(
+ not correct_accelerate_version,
+ reason=f"Test requires accelerate with the version greater than {accelerate_version}.",
+ )(test_case)
+
+ return decorator
+
+
+def require_bitsandbytes_version_greater(bnb_version):
+ def decorator(test_case):
+ correct_bnb_version = is_bitsandbytes_available() and version.parse(
+ version.parse(importlib.metadata.version("bitsandbytes")).base_version
+ ) > version.parse(bnb_version)
+ return pytest.mark.skipif(
+ not correct_bnb_version, reason=f"Test requires bitsandbytes with the version greater than {bnb_version}."
+ )(test_case)
+
+ return decorator
+
+
+def require_hf_hub_version_greater(hf_hub_version):
+ def decorator(test_case):
+ correct_hf_hub_version = version.parse(
+ version.parse(importlib.metadata.version("huggingface_hub")).base_version
+ ) > version.parse(hf_hub_version)
+ return pytest.mark.skipif(
+ not correct_hf_hub_version,
+ reason=f"Test requires huggingface_hub with the version greater than {hf_hub_version}.",
+ )(test_case)
+
+ return decorator
+
+
+def require_gguf_version_greater_or_equal(gguf_version):
+ def decorator(test_case):
+ correct_gguf_version = is_gguf_available() and version.parse(
+ version.parse(importlib.metadata.version("gguf")).base_version
+ ) >= version.parse(gguf_version)
+ return pytest.mark.skipif(
+ not correct_gguf_version, reason=f"Test requires gguf with the version greater than {gguf_version}."
+ )(test_case)
+
+ return decorator
+
+
+def require_torchao_version_greater_or_equal(torchao_version):
+ def decorator(test_case):
+ correct_torchao_version = is_torchao_available() and version.parse(
+ version.parse(importlib.metadata.version("torchao")).base_version
+ ) >= version.parse(torchao_version)
+ return pytest.mark.skipif(
+ not correct_torchao_version, reason=f"Test requires torchao with version greater than {torchao_version}."
+ )(test_case)
+
+ return decorator
+
+
+def require_kernels_version_greater_or_equal(kernels_version):
+ def decorator(test_case):
+ correct_kernels_version = is_kernels_available() and version.parse(
+ version.parse(importlib.metadata.version("kernels")).base_version
+ ) >= version.parse(kernels_version)
+ return pytest.mark.skipif(
+ not correct_kernels_version, reason=f"Test requires kernels with version greater than {kernels_version}."
+ )(test_case)
+
+ return decorator
+
+
+def deprecate_after_peft_backend(test_case):
+ """
+ Decorator marking a test that will be skipped after PEFT backend
+ """
+ return pytest.mark.skipif(USE_PEFT_BACKEND, reason="test skipped in favor of PEFT backend")(test_case)
+
+
+def get_python_version():
+ sys_info = sys.version_info
+ major, minor = sys_info.major, sys_info.minor
+ return major, minor
+
+
+def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray:
+ if isinstance(arry, str):
+ if local_path is not None:
+ # local_path can be passed to correct images of tests
+ return Path(local_path, arry.split("/")[-5], arry.split("/")[-2], arry.split("/")[-1]).as_posix()
+ elif arry.startswith("http://") or arry.startswith("https://"):
+ response = requests.get(arry, timeout=DIFFUSERS_REQUEST_TIMEOUT)
+ response.raise_for_status()
+ arry = np.load(BytesIO(response.content))
+ elif os.path.isfile(arry):
+ arry = np.load(arry)
+ else:
+ raise ValueError(
+ f"Incorrect path or url, URLs must start with `http://` or `https://`, and {arry} is not a valid path"
+ )
+ elif isinstance(arry, np.ndarray):
+ pass
+ else:
+ raise ValueError(
+ "Incorrect format used for numpy ndarray. Should be an url linking to an image, a local path, or a"
+ " ndarray."
+ )
+
+ return arry
+
+
+def load_pt(url: str, map_location: Optional[str] = None, weights_only: Optional[bool] = True):
+ response = requests.get(url, timeout=DIFFUSERS_REQUEST_TIMEOUT)
+ response.raise_for_status()
+ arry = torch.load(BytesIO(response.content), map_location=map_location, weights_only=weights_only)
+ return arry
+
+
+def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
+ """
+ Loads `image` to a PIL Image.
+
+ Args:
+ image (`str` or `PIL.Image.Image`):
+ The image to convert to the PIL Image format.
+ Returns:
+ `PIL.Image.Image`:
+ A PIL Image.
+ """
+ if isinstance(image, str):
+ if image.startswith("http://") or image.startswith("https://"):
+ image = PIL.Image.open(requests.get(image, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw)
+ elif os.path.isfile(image):
+ image = PIL.Image.open(image)
+ else:
+ raise ValueError(
+ f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
+ )
+ elif isinstance(image, PIL.Image.Image):
+ image = image
+ else:
+ raise ValueError(
+ "Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
+ )
+ image = PIL.ImageOps.exif_transpose(image)
+ image = image.convert("RGB")
+ return image
+
+
+def preprocess_image(image: PIL.Image, batch_size: int):
+ w, h = image.size
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
+ image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+
+def export_to_gif(image: List[PIL.Image.Image], output_gif_path: str = None) -> str:
+ if output_gif_path is None:
+ output_gif_path = tempfile.NamedTemporaryFile(suffix=".gif").name
+
+ image[0].save(
+ output_gif_path,
+ save_all=True,
+ append_images=image[1:],
+ optimize=False,
+ duration=100,
+ loop=0,
+ )
+ return output_gif_path
+
+
+@contextmanager
+def buffered_writer(raw_f):
+ f = io.BufferedWriter(raw_f)
+ yield f
+ f.flush()
+
+
+def export_to_ply(mesh, output_ply_path: str = None):
+ """
+ Write a PLY file for a mesh.
+ """
+ if output_ply_path is None:
+ output_ply_path = tempfile.NamedTemporaryFile(suffix=".ply").name
+
+ coords = mesh.verts.detach().cpu().numpy()
+ faces = mesh.faces.cpu().numpy()
+ rgb = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1)
+
+ with buffered_writer(open(output_ply_path, "wb")) as f:
+ f.write(b"ply\n")
+ f.write(b"format binary_little_endian 1.0\n")
+ f.write(bytes(f"element vertex {len(coords)}\n", "ascii"))
+ f.write(b"property float x\n")
+ f.write(b"property float y\n")
+ f.write(b"property float z\n")
+ if rgb is not None:
+ f.write(b"property uchar red\n")
+ f.write(b"property uchar green\n")
+ f.write(b"property uchar blue\n")
+ if faces is not None:
+ f.write(bytes(f"element face {len(faces)}\n", "ascii"))
+ f.write(b"property list uchar int vertex_index\n")
+ f.write(b"end_header\n")
+
+ if rgb is not None:
+ rgb = (rgb * 255.499).round().astype(int)
+ vertices = [
+ (*coord, *rgb)
+ for coord, rgb in zip(
+ coords.tolist(),
+ rgb.tolist(),
+ )
+ ]
+ format = struct.Struct("<3f3B")
+ for item in vertices:
+ f.write(format.pack(*item))
+ else:
+ format = struct.Struct("<3f")
+ for vertex in coords.tolist():
+ f.write(format.pack(*vertex))
+
+ if faces is not None:
+ format = struct.Struct(" str:
+ if is_opencv_available():
+ import cv2
+ else:
+ raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video"))
+ if output_video_path is None:
+ output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
+
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+ h, w, c = video_frames[0].shape
+ video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h))
+ for i in range(len(video_frames)):
+ img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
+ video_writer.write(img)
+ return output_video_path
+
+
+def load_hf_numpy(path) -> np.ndarray:
+ base_url = "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main"
+
+ if not path.startswith("http://") and not path.startswith("https://"):
+ path = os.path.join(base_url, urllib.parse.quote(path))
+
+ return load_numpy(path)
+
+
+# --- pytest conf functions --- #
+
+# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once
+pytest_opt_registered = {}
+
+
+def pytest_addoption_shared(parser):
+ """
+ This function is to be called from `conftest.py` via `pytest_addoption` wrapper that has to be defined there.
+
+ It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest`
+ option.
+
+ """
+ option = "--make-reports"
+ if option not in pytest_opt_registered:
+ parser.addoption(
+ option,
+ action="store",
+ default=False,
+ help="generate report files. The value of this option is used as a prefix to report names",
+ )
+ pytest_opt_registered[option] = 1
+
+
+def pytest_terminal_summary_main(tr, id):
+ """
+ Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the current
+ directory. The report files are prefixed with the test suite name.
+
+ This function emulates --duration and -rA pytest arguments.
+
+ This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined
+ there.
+
+ Args:
+ - tr: `terminalreporter` passed from `conftest.py`
+ - id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is
+ needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other.
+
+ NB: this functions taps into a private _pytest API and while unlikely, it could break should
+ pytest do internal changes - also it calls default internal methods of terminalreporter which
+ can be hijacked by various `pytest-` plugins and interfere.
+
+ """
+ from _pytest.config import create_terminal_writer
+
+ if not len(id):
+ id = "tests"
+
+ config = tr.config
+ orig_writer = config.get_terminal_writer()
+ orig_tbstyle = config.option.tbstyle
+ orig_reportchars = tr.reportchars
+
+ dir = "reports"
+ Path(dir).mkdir(parents=True, exist_ok=True)
+ report_files = {
+ k: f"{dir}/{id}_{k}.txt"
+ for k in [
+ "durations",
+ "errors",
+ "failures_long",
+ "failures_short",
+ "failures_line",
+ "passes",
+ "stats",
+ "summary_short",
+ "warnings",
+ ]
+ }
+
+ # custom durations report
+ # note: there is no need to call pytest --durations=XX to get this separate report
+ # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66
+ dlist = []
+ for replist in tr.stats.values():
+ for rep in replist:
+ if hasattr(rep, "duration"):
+ dlist.append(rep)
+ if dlist:
+ dlist.sort(key=lambda x: x.duration, reverse=True)
+ with open(report_files["durations"], "w") as f:
+ durations_min = 0.05 # sec
+ f.write("slowest durations\n")
+ for i, rep in enumerate(dlist):
+ if rep.duration < durations_min:
+ f.write(f"{len(dlist) - i} durations < {durations_min} secs were omitted")
+ break
+ f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")
+
+ def summary_failures_short(tr):
+ # expecting that the reports were --tb=long (default) so we chop them off here to the last frame
+ reports = tr.getreports("failed")
+ if not reports:
+ return
+ tr.write_sep("=", "FAILURES SHORT STACK")
+ for rep in reports:
+ msg = tr._getfailureheadline(rep)
+ tr.write_sep("_", msg, red=True, bold=True)
+ # chop off the optional leading extra frames, leaving only the last one
+ longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S)
+ tr._tw.line(longrepr)
+ # note: not printing out any rep.sections to keep the report short
+
+ # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each
+ # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814
+ # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g.
+ # pytest-instafail does that)
+
+ # report failures with line/short/long styles
+ config.option.tbstyle = "auto" # full tb
+ with open(report_files["failures_long"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_failures()
+
+ # config.option.tbstyle = "short" # short tb
+ with open(report_files["failures_short"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ summary_failures_short(tr)
+
+ config.option.tbstyle = "line" # one line per error
+ with open(report_files["failures_line"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_failures()
+
+ with open(report_files["errors"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_errors()
+
+ with open(report_files["warnings"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_warnings() # normal warnings
+ tr.summary_warnings() # final warnings
+
+ tr.reportchars = "wPpsxXEf" # emulate -rA (used in summary_passes() and short_test_summary())
+ with open(report_files["passes"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_passes()
+
+ with open(report_files["summary_short"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.short_test_summary()
+
+ with open(report_files["stats"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_stats()
+
+ # restore:
+ tr._tw = orig_writer
+ tr.reportchars = orig_reportchars
+ config.option.tbstyle = orig_tbstyle
+
+
+# Adapted from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers..testing_utils.py#L1905
+def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
+ """
+ To decorate flaky tests (methods or entire classes). They will be retried on failures.
+
+ Args:
+ max_attempts (`int`, *optional*, defaults to 5):
+ The maximum number of attempts to retry the flaky test.
+ wait_before_retry (`float`, *optional*):
+ If provided, will wait that number of seconds before retrying the test.
+ description (`str`, *optional*):
+ A string to describe the situation (what / where / why is flaky, link to GH issue/PR comments, errors,
+ etc.)
+ """
+
+ def decorator(obj):
+ # If decorating a class, wrap each test method on it
+ if inspect.isclass(obj):
+ for attr_name, attr_value in list(obj.__dict__.items()):
+ if callable(attr_value) and attr_name.startswith("test"):
+ # recursively decorate the method
+ setattr(obj, attr_name, decorator(attr_value))
+ return obj
+
+ # Otherwise we're decorating a single test function / method
+ @functools.wraps(obj)
+ def wrapper(*args, **kwargs):
+ retry_count = 1
+ while retry_count < max_attempts:
+ try:
+ return obj(*args, **kwargs)
+ except Exception as err:
+ msg = (
+ f"[FLAKY] {description or obj.__name__!r} "
+ f"failed on attempt {retry_count}/{max_attempts}: {err}"
+ )
+ print(msg, file=sys.stderr)
+ if wait_before_retry is not None:
+ time.sleep(wait_before_retry)
+ retry_count += 1
+
+ return obj(*args, **kwargs)
+
+ return wrapper
+
+ return decorator
+
+
+# Taken from: https://github.com/huggingface/transformers/blob/3658488ff77ff8d45101293e749263acf437f4d5/src/transformers..testing_utils.py#L1787
+def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
+ """
+ To run a test in a subprocess. In particular, this can avoid (GPU) memory issue.
+
+ Args:
+ test_case:
+ The test case object that will run `target_func`.
+ target_func (`Callable`):
+ The function implementing the actual testing logic.
+ inputs (`dict`, *optional*, defaults to `None`):
+ The inputs that will be passed to `target_func` through an (input) queue.
+ timeout (`int`, *optional*, defaults to `None`):
+ The timeout (in seconds) that will be passed to the input and output queues. If not specified, the env.
+ variable `PYTEST_TIMEOUT` will be checked. If still `None`, its value will be set to `600`.
+ """
+ if timeout is None:
+ timeout = int(os.environ.get("PYTEST_TIMEOUT", 600))
+
+ start_methohd = "spawn"
+ ctx = multiprocessing.get_context(start_methohd)
+
+ input_queue = ctx.Queue(1)
+ output_queue = ctx.JoinableQueue(1)
+
+ # We can't send test case objects to the child, otherwise we get issues regarding pickle.
+ input_queue.put(inputs, timeout=timeout)
+
+ process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout))
+ process.start()
+ # Kill the child process if we can't get outputs from it in time: otherwise, the hanging subprocess prevents
+ # the test to exit properly.
+ try:
+ results = output_queue.get(timeout=timeout)
+ output_queue.task_done()
+ except Exception as e:
+ process.terminate()
+ test_case.fail(e)
+ process.join(timeout=timeout)
+
+ if results["error"] is not None:
+ test_case.fail(f"{results['error']}")
+
+
+class CaptureLogger:
+ """
+ Args:
+ Context manager to capture `logging` streams
+ logger: 'logging` logger object
+ Returns:
+ The captured output is available via `self.out`
+ Example:
+ ```python
+ >>> from diffusers import logging
+ >>> from diffusers..testing_utils import CaptureLogger
+
+ >>> msg = "Testing 1, 2, 3"
+ >>> logging.set_verbosity_info()
+ >>> logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py")
+ >>> with CaptureLogger(logger) as cl:
+ ... logger.info(msg)
+ >>> assert cl.out, msg + "\n"
+ ```
+ """
+
+ def __init__(self, logger):
+ self.logger = logger
+ self.io = StringIO()
+ self.sh = logging.StreamHandler(self.io)
+ self.out = ""
+
+ def __enter__(self):
+ self.logger.addHandler(self.sh)
+ return self
+
+ def __exit__(self, *exc):
+ self.logger.removeHandler(self.sh)
+ self.out = self.io.getvalue()
+
+ def __repr__(self):
+ return f"captured: {self.out}\n"
+
+
+def enable_full_determinism():
+ """
+ Helper function for reproducible behavior during distributed training. See
+ - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
+ """
+ # Enable PyTorch deterministic mode. This potentially requires either the environment
+ # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
+ # depending on the CUDA version, so we set them both here
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+ torch.use_deterministic_algorithms(True)
+
+ # Enable CUDNN deterministic mode
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cuda.matmul.allow_tf32 = False
+
+
+def disable_full_determinism():
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
+ torch.use_deterministic_algorithms(False)
+
+
+# Utils for custom and alternative accelerator devices
+def _is_torch_fp16_available(device):
+ if not is_torch_available():
+ return False
+
+ import torch
+
+ device = torch.device(device)
+
+ try:
+ x = torch.zeros((2, 2), dtype=torch.float16).to(device)
+ _ = torch.mul(x, x)
+ return True
+
+ except Exception as e:
+ if device.type == "cuda":
+ raise ValueError(
+ f"You have passed a device of type 'cuda' which should work with 'fp16', but 'cuda' does not seem to be correctly installed on your machine: {e}"
+ )
+
+ return False
+
+
+def _is_torch_fp64_available(device):
+ if not is_torch_available():
+ return False
+
+ import torch
+
+ device = torch.device(device)
+
+ try:
+ x = torch.zeros((2, 2), dtype=torch.float64).to(device)
+ _ = torch.mul(x, x)
+ return True
+
+ except Exception as e:
+ if device.type == "cuda":
+ raise ValueError(
+ f"You have passed a device of type 'cuda' which should work with 'fp64', but 'cuda' does not seem to be correctly installed on your machine: {e}"
+ )
+
+ return False
+
+
+# Guard these lookups for when Torch is not used - alternative accelerator support is for PyTorch
+if is_torch_available():
+ # Behaviour flags
+ BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}
+
+ # Function definitions
+ BACKEND_EMPTY_CACHE = {
+ "cuda": torch.cuda.empty_cache,
+ "xpu": torch.xpu.empty_cache,
+ "cpu": None,
+ "mps": torch.mps.empty_cache,
+ "default": None,
+ }
+ BACKEND_DEVICE_COUNT = {
+ "cuda": torch.cuda.device_count,
+ "xpu": torch.xpu.device_count,
+ "cpu": lambda: 0,
+ "mps": lambda: 0,
+ "default": 0,
+ }
+ BACKEND_MANUAL_SEED = {
+ "cuda": torch.cuda.manual_seed,
+ "xpu": torch.xpu.manual_seed,
+ "cpu": torch.manual_seed,
+ "mps": torch.mps.manual_seed,
+ "default": torch.manual_seed,
+ }
+ BACKEND_RESET_PEAK_MEMORY_STATS = {
+ "cuda": torch.cuda.reset_peak_memory_stats,
+ "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
+ "cpu": None,
+ "mps": None,
+ "default": None,
+ }
+ BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
+ "cuda": torch.cuda.reset_max_memory_allocated,
+ "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
+ "cpu": None,
+ "mps": None,
+ "default": None,
+ }
+ BACKEND_MAX_MEMORY_ALLOCATED = {
+ "cuda": torch.cuda.max_memory_allocated,
+ "xpu": getattr(torch.xpu, "max_memory_allocated", None),
+ "cpu": 0,
+ "mps": 0,
+ "default": 0,
+ }
+ BACKEND_SYNCHRONIZE = {
+ "cuda": torch.cuda.synchronize,
+ "xpu": getattr(torch.xpu, "synchronize", None),
+ "cpu": None,
+ "mps": None,
+ "default": None,
+ }
+
+
+# This dispatches a defined function according to the accelerator from the function definitions.
+def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs):
+ if device not in dispatch_table:
+ return dispatch_table["default"](*args, **kwargs)
+
+ fn = dispatch_table[device]
+
+ # Some device agnostic functions return values. Need to guard against 'None' instead at
+ # user level
+ if not callable(fn):
+ return fn
+
+ return fn(*args, **kwargs)
+
+
+# These are callables which automatically dispatch the function specific to the accelerator
+def backend_manual_seed(device: str, seed: int):
+ return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
+
+
+def backend_synchronize(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)
+
+
+def backend_empty_cache(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
+
+
+def backend_device_count(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
+
+
+def backend_reset_peak_memory_stats(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)
+
+
+def backend_reset_max_memory_allocated(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
+
+
+def backend_max_memory_allocated(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
+
+
+# These are callables which return boolean behaviour flags and can be used to specify some
+# device agnostic alternative where the feature is unsupported.
+def backend_supports_training(device: str):
+ if not is_torch_available():
+ return False
+
+ if device not in BACKEND_SUPPORTS_TRAINING:
+ device = "default"
+
+ return BACKEND_SUPPORTS_TRAINING[device]
+
+
+# Guard for when Torch is not available
+if is_torch_available():
+ # Update device function dict mapping
+ def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str):
+ try:
+ # Try to import the function directly
+ spec_fn = getattr(device_spec_module, attribute_name)
+ device_fn_dict[torch_device] = spec_fn
+ except AttributeError as e:
+ # If the function doesn't exist, and there is no default, throw an error
+ if "default" not in device_fn_dict:
+ raise AttributeError(
+ f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found."
+ ) from e
+
+ if "DIFFUSERS_TEST_DEVICE_SPEC" in os.environ:
+ device_spec_path = os.environ["DIFFUSERS_TEST_DEVICE_SPEC"]
+ if not Path(device_spec_path).is_file():
+ raise ValueError(f"Specified path to device specification file is not found. Received {device_spec_path}")
+
+ try:
+ import_name = device_spec_path[: device_spec_path.index(".py")]
+ except ValueError as e:
+ raise ValueError(f"Provided device spec file is not a Python file! Received {device_spec_path}") from e
+
+ device_spec_module = importlib.import_module(import_name)
+
+ try:
+ device_name = device_spec_module.DEVICE_NAME
+ except AttributeError:
+ raise AttributeError("Device spec file did not contain `DEVICE_NAME`")
+
+ if "DIFFUSERS_TEST_DEVICE" in os.environ and torch_device != device_name:
+ msg = f"Mismatch between environment variable `DIFFUSERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n"
+ msg += "Either unset `DIFFUSERS_TEST_DEVICE` or ensure it matches device spec name."
+ raise ValueError(msg)
+
+ torch_device = device_name
+
+ # Add one entry here for each `BACKEND_*` dictionary.
+ update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
+ update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
+ update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
+ update_mapping_from_spec(BACKEND_SUPPORTS_TRAINING, "SUPPORTS_TRAINING")
+ update_mapping_from_spec(BACKEND_RESET_PEAK_MEMORY_STATS, "RESET_PEAK_MEMORY_STATS_FN")
+ update_mapping_from_spec(BACKEND_RESET_MAX_MEMORY_ALLOCATED, "RESET_MAX_MEMORY_ALLOCATED_FN")
+ update_mapping_from_spec(BACKEND_MAX_MEMORY_ALLOCATED, "MAX_MEMORY_ALLOCATED_FN")
+
+
+# Modified from https://github.com/huggingface/transformers/blob/cdfb018d0300fef3b07d9220f3efe9c2a9974662/src/transformers..testing_utils.py#L3090
+
+# Type definition of key used in `Expectations` class.
+DeviceProperties = Tuple[Union[str, None], Union[int, None]]
+
+
+@functools.lru_cache
+def get_device_properties() -> DeviceProperties:
+ """
+ Get environment device properties.
+ """
+ if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+ import torch
+
+ major, _ = torch.cuda.get_device_capability()
+ if IS_ROCM_SYSTEM:
+ return ("rocm", major)
+ else:
+ return ("cuda", major)
+ elif IS_XPU_SYSTEM:
+ import torch
+
+ # To get more info of the architecture meaning and bit allocation, refer to https://github.com/intel/llvm/blob/sycl/sycl/include/sycl/ext/oneapi/experimental/device_architecture.def
+ arch = torch.xpu.get_device_capability()["architecture"]
+ gen_mask = 0x000000FF00000000
+ gen = (arch & gen_mask) >> 32
+ return ("xpu", gen)
+ else:
+ return (torch_device, None)
+
+
+if TYPE_CHECKING:
+ DevicePropertiesUserDict = UserDict[DeviceProperties, Any]
+else:
+ DevicePropertiesUserDict = UserDict
+
+if is_torch_available():
+ from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
+ from diffusers.hooks.group_offloading import (
+ _GROUP_ID_LAZY_LEAF,
+ _compute_group_hash,
+ _find_parent_module_in_module_dict,
+ _gather_buffers_with_no_group_offloading_parent,
+ _gather_parameters_with_no_group_offloading_parent,
+ )
+
+ def _get_expected_safetensors_files(
+ module: torch.nn.Module,
+ offload_to_disk_path: str,
+ offload_type: str,
+ num_blocks_per_group: Optional[int] = None,
+ block_modules: Optional[List[str]] = None,
+ module_prefix: str = "",
+ ) -> Set[str]:
+ expected_files = set()
+
+ def get_hashed_filename(group_id: str) -> str:
+ short_hash = _compute_group_hash(group_id)
+ return os.path.join(offload_to_disk_path, f"group_{short_hash}.safetensors")
+
+ if offload_type == "block_level":
+ if num_blocks_per_group is None:
+ raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.")
+
+ block_modules_set = set(block_modules) if block_modules is not None else set()
+
+ modules_with_group_offloading = set()
+ unmatched_modules = []
+ for name, submodule in module.named_children():
+ if name in block_modules_set:
+ new_prefix = f"{module_prefix}{name}." if module_prefix else f"{name}."
+ submodule_files = _get_expected_safetensors_files(
+ submodule, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules, new_prefix
+ )
+ expected_files.update(submodule_files)
+ modules_with_group_offloading.add(name)
+
+ elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+ for i in range(0, len(submodule), num_blocks_per_group):
+ current_modules = submodule[i : i + num_blocks_per_group]
+ if not current_modules:
+ continue
+ group_id = f"{module_prefix}{name}_{i}_{i + len(current_modules) - 1}"
+ expected_files.add(get_hashed_filename(group_id))
+ for j in range(i, i + len(current_modules)):
+ modules_with_group_offloading.add(f"{name}.{j}")
+ else:
+ unmatched_modules.append(submodule)
+
+ parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
+ buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
+
+ if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
+ expected_files.add(get_hashed_filename(f"{module_prefix}{module.__class__.__name__}_unmatched_group"))
+
+ elif offload_type == "leaf_level":
+ # Handle leaf-level module groups
+ for name, submodule in module.named_modules():
+ if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
+ # These groups will always have parameters, so a file is expected
+ expected_files.add(get_hashed_filename(name))
+
+ # Handle groups for non-leaf parameters/buffers
+ modules_with_group_offloading = {
+ name for name, sm in module.named_modules() if isinstance(sm, _GO_LC_SUPPORTED_PYTORCH_LAYERS)
+ }
+ parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
+ buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
+
+ all_orphans = parameters + buffers
+ if all_orphans:
+ parent_to_tensors = {}
+ module_dict = dict(module.named_modules())
+ for tensor_name, _ in all_orphans:
+ parent_name = _find_parent_module_in_module_dict(tensor_name, module_dict)
+ if parent_name not in parent_to_tensors:
+ parent_to_tensors[parent_name] = []
+ parent_to_tensors[parent_name].append(tensor_name)
+
+ for parent_name in parent_to_tensors:
+ # A file is expected for each parent that gathers orphaned tensors
+ expected_files.add(get_hashed_filename(parent_name))
+ expected_files.add(get_hashed_filename(_GROUP_ID_LAZY_LEAF))
+
+ else:
+ raise ValueError(f"Unsupported offload_type: {offload_type}")
+
+ return expected_files
+
+ def _check_safetensors_serialization(
+ module: torch.nn.Module,
+ offload_to_disk_path: str,
+ offload_type: str,
+ num_blocks_per_group: Optional[int] = None,
+ block_modules: Optional[List[str]] = None,
+ ) -> bool:
+ if not os.path.isdir(offload_to_disk_path):
+ return False, None, None
+
+ expected_files = _get_expected_safetensors_files(
+ module, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules
+ )
+ actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors")))
+ missing_files = expected_files - actual_files
+ extra_files = actual_files - expected_files
+
+ is_correct = not missing_files and not extra_files
+ return is_correct, extra_files, missing_files
+
+
+class Expectations(DevicePropertiesUserDict):
+ def get_expectation(self) -> Any:
+ """
+ Find best matching expectation based on environment device properties.
+ """
+ return self.find_expectation(get_device_properties())
+
+ @staticmethod
+ def is_default(key: DeviceProperties) -> bool:
+ return all(p is None for p in key)
+
+ @staticmethod
+ def score(key: DeviceProperties, other: DeviceProperties) -> int:
+ """
+ Returns score indicating how similar two instances of the `Properties` tuple are. Points are calculated using
+ bits, but documented as int. Rules are as follows:
+ * Matching `type` gives 8 points.
+ * Semi-matching `type`, for example cuda and rocm, gives 4 points.
+ * Matching `major` (compute capability major version) gives 2 points.
+ * Default expectation (if present) gives 1 points.
+ """
+ (device_type, major) = key
+ (other_device_type, other_major) = other
+
+ score = 0b0
+ if device_type == other_device_type:
+ score |= 0b1000
+ elif device_type in ["cuda", "rocm"] and other_device_type in ["cuda", "rocm"]:
+ score |= 0b100
+
+ if major == other_major and other_major is not None:
+ score |= 0b10
+
+ if Expectations.is_default(other):
+ score |= 0b1
+
+ return int(score)
+
+ def find_expectation(self, key: DeviceProperties = (None, None)) -> Any:
+ """
+ Find best matching expectation based on provided device properties.
+ """
+ (result_key, result) = max(self.data.items(), key=lambda x: Expectations.score(key, x[0]))
+
+ if Expectations.score(key, result_key) == 0:
+ raise ValueError(f"No matching expectation found for {key}")
+
+ return result
+
+ def __repr__(self):
+ return f"{self.data}"
diff --git a/utils/check_doc_toc.py b/utils/check_doc_toc.py
index d7c9cee82fcb..050b093991e6 100644
--- a/utils/check_doc_toc.py
+++ b/utils/check_doc_toc.py
@@ -21,20 +21,23 @@
PATH_TO_TOC = "docs/source/en/_toctree.yml"
+# Titles that should maintain their position and not be sorted alphabetically
+FIXED_POSITION_TITLES = {"overview", "autopipeline"}
+
def clean_doc_toc(doc_list):
"""
Cleans the table of content of the model documentation by removing duplicates and sorting models alphabetically.
"""
counts = defaultdict(int)
- overview_doc = []
+ fixed_position_docs = []
new_doc_list = []
for doc in doc_list:
if "local" in doc:
counts[doc["local"]] += 1
- if doc["title"].lower() == "overview":
- overview_doc.append({"local": doc["local"], "title": doc["title"]})
+ if doc["title"].lower() in FIXED_POSITION_TITLES:
+ fixed_position_docs.append({"local": doc["local"], "title": doc["title"]})
else:
new_doc_list.append(doc)
@@ -57,14 +60,13 @@ def clean_doc_toc(doc_list):
new_doc.extend([doc for doc in doc_list if "local" not in counts or counts[doc["local"]] == 1])
new_doc = sorted(new_doc, key=lambda s: s["title"].lower())
- # "overview" gets special treatment and is always first
- if len(overview_doc) > 1:
- raise ValueError("{doc_list} has two 'overview' docs which is not allowed.")
-
- overview_doc.extend(new_doc)
+ # Fixed-position titles maintain their original order
+ result = []
+ for doc in fixed_position_docs:
+ result.append(doc)
- # Sort
- return overview_doc
+ result.extend(new_doc)
+ return result
def check_scheduler_doc(overwrite=False):
@@ -123,11 +125,13 @@ def check_pipeline_doc(overwrite=False):
# sort sub pipeline docs
for pipeline_doc in pipeline_docs:
- if "section" in pipeline_doc:
- sub_pipeline_doc = pipeline_doc["section"]
+ if "sections" in pipeline_doc:
+ sub_pipeline_doc = pipeline_doc["sections"]
new_sub_pipeline_doc = clean_doc_toc(sub_pipeline_doc)
- if overwrite:
- pipeline_doc["section"] = new_sub_pipeline_doc
+ if new_sub_pipeline_doc != sub_pipeline_doc:
+ diff = True
+ if overwrite:
+ pipeline_doc["sections"] = new_sub_pipeline_doc
new_pipeline_docs.append(pipeline_doc)
# sort overall pipeline doc
@@ -149,6 +153,55 @@ def check_pipeline_doc(overwrite=False):
)
+def check_model_doc(overwrite=False):
+ with open(PATH_TO_TOC, encoding="utf-8") as f:
+ content = yaml.safe_load(f.read())
+
+ # Get to the API doc
+ api_idx = 0
+ while content[api_idx]["title"] != "API":
+ api_idx += 1
+ api_doc = content[api_idx]["sections"]
+
+ # Then to the model doc
+ model_idx = 0
+ while api_doc[model_idx]["title"] != "Models":
+ model_idx += 1
+
+ diff = False
+ model_docs = api_doc[model_idx]["sections"]
+ new_model_docs = []
+
+ # sort sub model docs
+ for model_doc in model_docs:
+ if "sections" in model_doc:
+ sub_model_doc = model_doc["sections"]
+ new_sub_model_doc = clean_doc_toc(sub_model_doc)
+ if new_sub_model_doc != sub_model_doc:
+ diff = True
+ if overwrite:
+ model_doc["sections"] = new_sub_model_doc
+ new_model_docs.append(model_doc)
+
+ # sort overall model doc
+ new_model_docs = clean_doc_toc(new_model_docs)
+
+ if new_model_docs != model_docs:
+ diff = True
+ if overwrite:
+ api_doc[model_idx]["sections"] = new_model_docs
+
+ if diff:
+ if overwrite:
+ content[api_idx]["sections"] = api_doc
+ with open(PATH_TO_TOC, "w", encoding="utf-8") as f:
+ f.write(yaml.dump(content, allow_unicode=True))
+ else:
+ raise ValueError(
+ "The model doc part of the table of content is not properly sorted, run `make style` to fix this."
+ )
+
+
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
@@ -156,3 +209,4 @@ def check_pipeline_doc(overwrite=False):
check_scheduler_doc(args.fix_and_overwrite)
check_pipeline_doc(args.fix_and_overwrite)
+ check_model_doc(args.fix_and_overwrite)
diff --git a/utils/check_support_list.py b/utils/check_support_list.py
index 89cfce62de0b..ade9df3b64fa 100644
--- a/utils/check_support_list.py
+++ b/utils/check_support_list.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team.
+# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -100,7 +100,7 @@ def check_documentation(doc_path, src_path, doc_regex, src_regex, exclude_condit
"doc_path": "docs/source/en/api/loaders/lora.md",
"src_path": "src/diffusers/loaders/lora_pipeline.py",
"doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
- "src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
+ "src_regex": r"class\s+(\w+LoraLoaderMixin(?:\d*_?\d*))[:(]",
},
}
diff --git a/utils/consolidated_test_report.py b/utils/consolidated_test_report.py
new file mode 100644
index 000000000000..134fecf721e4
--- /dev/null
+++ b/utils/consolidated_test_report.py
@@ -0,0 +1,789 @@
+#!/usr/bin/env python
+import argparse
+import glob
+import os
+import re
+from datetime import date, datetime
+
+from slack_sdk import WebClient
+from tabulate import tabulate
+
+
+MAX_LEN_MESSAGE = 3001 # slack endpoint has a limit of 3001 characters
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--slack_channel_name", default="diffusers-ci-nightly")
+parser.add_argument(
+ "--reports_dir",
+ default="reports",
+ help="Directory containing test reports (will search recursively in all subdirectories)",
+)
+parser.add_argument("--output_file", default=None, help="Path to save the consolidated report (markdown format)")
+
+
+def parse_stats_file(file_path):
+ """Parse a stats file to extract test statistics."""
+ try:
+ with open(file_path, "r") as f:
+ content = f.read()
+
+ # Extract the numbers using regex
+ tests_pattern = r"collected (\d+) items"
+ passed_pattern = r"(\d+) passed"
+ failed_pattern = r"(\d+) failed"
+ skipped_pattern = r"(\d+) skipped"
+ xpassed_pattern = r"(\d+) xpassed"
+
+ tests_match = re.search(tests_pattern, content)
+ passed_match = re.search(passed_pattern, content)
+ failed_match = re.search(failed_pattern, content)
+ skipped_match = re.search(skipped_pattern, content)
+ xpassed_match = re.search(xpassed_pattern, content)
+
+ passed = int(passed_match.group(1)) if passed_match else 0
+ failed = int(failed_match.group(1)) if failed_match else 0
+ skipped = int(skipped_match.group(1)) if skipped_match else 0
+ xpassed = int(xpassed_match.group(1)) if xpassed_match else 0
+
+ # If tests_match exists, use it, otherwise calculate from passed/failed/skipped
+ if tests_match:
+ tests = int(tests_match.group(1))
+ else:
+ tests = passed + failed + skipped + xpassed
+
+ # Extract timing information if available
+ timing_pattern = r"slowest \d+ test durations[\s\S]*?\n([\s\S]*?)={70}"
+ timing_match = re.search(timing_pattern, content, re.MULTILINE)
+ slowest_tests = []
+
+ if timing_match:
+ timing_text = timing_match.group(1).strip()
+ test_timing_lines = timing_text.split("\n")
+ for line in test_timing_lines:
+ if line.strip():
+ # Format is typically: 10.37s call tests/path/to/test.py::TestClass::test_method
+ parts = line.strip().split()
+ if len(parts) >= 3:
+ time_str = parts[0]
+ test_path = " ".join(parts[2:])
+
+ # Skip entries with "< 0.05 secs were omitted" or similar
+ if "secs were omitted" in test_path:
+ continue
+
+ try:
+ time_seconds = float(time_str.rstrip("s"))
+ slowest_tests.append({"test": test_path, "duration": time_seconds})
+ except ValueError:
+ pass
+
+ return {
+ "tests": tests,
+ "passed": passed,
+ "failed": failed,
+ "skipped": skipped,
+ "slowest_tests": slowest_tests,
+ }
+ except Exception as e:
+ print(f"Error parsing {file_path}: {e}")
+ return {"tests": 0, "passed": 0, "failed": 0, "skipped": 0, "slowest_tests": []}
+
+
+def parse_durations_file(file_path):
+ """Parse a durations file to extract test timing information."""
+ slowest_tests = []
+ try:
+ durations_file = file_path.replace("_stats.txt", "_durations.txt")
+ if os.path.exists(durations_file):
+ with open(durations_file, "r") as f:
+ content = f.read()
+
+ # Skip the header line
+ for line in content.split("\n")[1:]:
+ if line.strip():
+ # Format is typically: 10.37s call tests/path/to/test.py::TestClass::test_method
+ parts = line.strip().split()
+ if len(parts) >= 3:
+ time_str = parts[0]
+ test_path = " ".join(parts[2:])
+
+ # Skip entries with "< 0.05 secs were omitted" or similar
+ if "secs were omitted" in test_path:
+ continue
+
+ try:
+ time_seconds = float(time_str.rstrip("s"))
+ slowest_tests.append({"test": test_path, "duration": time_seconds})
+ except ValueError:
+ # If time_str is not a valid float, it might be a different format
+ # For example, some pytest formats show "< 0.05s" or similar
+ if test_path.startswith("<") and "secs were omitted" in test_path:
+ # Extract the time value from test_path if it's in the format "< 0.05 secs were omitted"
+ try:
+ # This handles entries where the time is in the test_path itself
+ dur_match = re.search(r"(\d+(?:\.\d+)?)", test_path)
+ if dur_match:
+ time_seconds = float(dur_match.group(1))
+ slowest_tests.append({"test": test_path, "duration": time_seconds})
+ except ValueError:
+ pass
+ except Exception as e:
+ print(f"Error parsing durations file {file_path.replace('_stats.txt', '_durations.txt')}: {e}")
+
+ return slowest_tests
+
+
+def parse_failures_file(file_path):
+ """Parse a failures file to extract failed test details."""
+ failures = []
+ try:
+ with open(file_path, "r") as f:
+ content = f.read()
+
+ # We don't need the base file name anymore as we're getting test paths from summary
+
+ # Check if it's a short stack format
+ if "============================= FAILURES SHORT STACK =============================" in content:
+ # First, look for pytest-style failure headers with underscores and clean them up
+ test_headers = re.findall(r"_{5,}\s+([^_\n]+?)\s+_{5,}", content)
+
+ for test_name in test_headers:
+ test_name = test_name.strip()
+ # Make sure it's a valid test name (contains a dot and doesn't look like a number)
+ if "." in test_name and not test_name.replace(".", "").isdigit():
+ # For test names missing the full path, check if we can reconstruct it from failures_line.txt
+ # This is a best effort - we won't always have the line file available
+ if not test_name.endswith(".py") and "::" not in test_name and "/" not in test_name:
+ # Try to look for a corresponding line file
+ line_file = file_path.replace("_failures_short.txt", "_failures_line.txt")
+ if os.path.exists(line_file):
+ try:
+ with open(line_file, "r") as lf:
+ line_content = lf.read()
+ # Look for test name in line file which might have the full path
+ path_match = re.search(
+ r"(tests/[\w/]+\.py::[^:]+::" + test_name.split(".")[-1] + ")",
+ line_content,
+ )
+ if path_match:
+ test_name = path_match.group(1)
+ except Exception:
+ pass # If we can't read the line file, just use what we have
+
+ failures.append(
+ {
+ "test": test_name,
+ "error": "Error occurred",
+ "original_test_name": test_name, # Keep original for reference
+ }
+ )
+
+ # If we didn't find any pytest-style headers, try other formats
+ if not failures:
+ # Look for test names at the beginning of the file (in first few lines)
+ first_lines = content.split("\n")[:20] # Look at first 20 lines
+ for line in first_lines:
+ # Look for test names in various formats
+ # Format: tests/file.py::TestClass::test_method
+ path_match = re.search(r"(tests/[\w/]+\.py::[\w\.]+::\w+)", line)
+ # Format: TestClass.test_method
+ class_match = re.search(r"([A-Za-z][A-Za-z0-9_]+\.[A-Za-z][A-Za-z0-9_]+)", line)
+
+ if path_match:
+ test_name = path_match.group(1)
+ failures.append(
+ {"test": test_name, "error": "Error occurred", "original_test_name": test_name}
+ )
+ break # Found a full path, stop looking
+ elif class_match and "test" in line.lower():
+ test_name = class_match.group(1)
+ # Make sure it's likely a test name (contains test in method name)
+ if "test" in test_name.lower():
+ failures.append(
+ {"test": test_name, "error": "Error occurred", "original_test_name": test_name}
+ )
+ else:
+ # Standard format - try to extract from standard pytest output
+ failure_blocks = re.split(r"={70}", content)
+
+ for block in failure_blocks:
+ if not block.strip():
+ continue
+
+ # Look for test paths in the format: path/to/test.py::TestClass::test_method
+ path_matches = re.findall(r"([\w/]+\.py::[\w\.]+::\w+)", block)
+ if path_matches:
+ for test_name in path_matches:
+ failures.append(
+ {"test": test_name, "error": "Error occurred", "original_test_name": test_name}
+ )
+ else:
+ # Try alternative format: TestClass.test_method
+ class_matches = re.findall(r"([A-Za-z][A-Za-z0-9_]+\.[A-Za-z][A-Za-z0-9_]+)", block)
+ for test_name in class_matches:
+ # Filter out things that don't look like test names
+ if (
+ not test_name.startswith(("e.g", "i.e", "etc."))
+ and not test_name.isdigit()
+ and "test" in test_name.lower()
+ ):
+ failures.append(
+ {"test": test_name, "error": "Error occurred", "original_test_name": test_name}
+ )
+
+ except Exception as e:
+ print(f"Error parsing failures in {file_path}: {e}")
+
+ return failures
+
+
+def consolidate_reports(reports_dir):
+ """Consolidate test reports from multiple test runs, including from subdirectories."""
+ # Get all stats files, including those in subdirectories
+ stats_files = glob.glob(f"{reports_dir}/**/*_stats.txt", recursive=True)
+
+ results = {}
+ total_stats = {"tests": 0, "passed": 0, "failed": 0, "skipped": 0}
+
+ # Collect all slow tests across all test suites
+ all_slow_tests = []
+
+ # Process each stats file and its corresponding failures file
+ for stats_file in stats_files:
+ # Extract test suite name from filename (e.g., tests_pipeline_allegro_cuda_stats.txt -> pipeline_allegro_cuda)
+ base_name = os.path.basename(stats_file).replace("_stats.txt", "")
+
+ # Include parent directory in suite name if it's in a subdirectory
+ rel_path = os.path.relpath(os.path.dirname(stats_file), reports_dir)
+ if rel_path and rel_path != ".":
+ # Remove 'test_reports' suffix from directory name if present
+ dir_name = os.path.basename(rel_path)
+ if dir_name.endswith("_test_reports"):
+ dir_name = dir_name[:-13] # Remove '_test_reports' suffix
+ base_name = f"{dir_name}/{base_name}"
+
+ # Parse stats
+ stats = parse_stats_file(stats_file)
+
+ # If no slowest tests found in stats file, try the durations file directly
+ if not stats.get("slowest_tests"):
+ stats["slowest_tests"] = parse_durations_file(stats_file)
+
+ # Update total stats
+ for key in ["tests", "passed", "failed", "skipped"]:
+ total_stats[key] += stats[key]
+
+ # Collect slowest tests with their suite name
+ for slow_test in stats.get("slowest_tests", []):
+ all_slow_tests.append({"test": slow_test["test"], "duration": slow_test["duration"], "suite": base_name})
+
+ # Parse failures if there are any
+ failures = []
+ if stats["failed"] > 0:
+ # First try to get test paths from summary_short.txt which has the best format
+ summary_file = stats_file.replace("_stats.txt", "_summary_short.txt")
+ if os.path.exists(summary_file):
+ try:
+ with open(summary_file, "r") as f:
+ content = f.read()
+ # Look for full lines with test path and error message: "FAILED test_path - error_msg"
+ failed_test_lines = re.findall(
+ r"FAILED\s+(tests/[\w/]+\.py::[A-Za-z0-9_\.]+::[A-Za-z0-9_]+)(?:\s+-\s+(.+))?", content
+ )
+
+ if failed_test_lines:
+ for match in failed_test_lines:
+ test_path = match[0]
+ error_msg = match[1] if len(match) > 1 and match[1] else "No error message"
+
+ failures.append({"test": test_path, "error": error_msg})
+ except Exception as e:
+ print(f"Error parsing summary file: {e}")
+
+ # If no failures found in summary, try other failure files
+ if not failures:
+ failure_patterns = ["_failures_short.txt", "_failures.txt", "_failures_line.txt", "_failures_long.txt"]
+
+ for pattern in failure_patterns:
+ failures_file = stats_file.replace("_stats.txt", pattern)
+ if os.path.exists(failures_file):
+ failures = parse_failures_file(failures_file)
+ if failures:
+ break
+
+ # No debug output needed
+
+ # Store results for this test suite
+ results[base_name] = {"stats": stats, "failures": failures}
+
+ # Filter out entries with "secs were omitted"
+ filtered_slow_tests = [test for test in all_slow_tests if "secs were omitted" not in test["test"]]
+
+ # Sort all slow tests by duration (descending)
+ filtered_slow_tests.sort(key=lambda x: x["duration"], reverse=True)
+
+ # Get the number of slowest tests to show from environment variable or default to 10
+ num_slowest_tests = int(os.environ.get("SHOW_SLOWEST_TESTS", "10"))
+ top_slowest_tests = filtered_slow_tests[:num_slowest_tests] if filtered_slow_tests else []
+
+ # Calculate additional duration statistics
+ total_duration = sum(test["duration"] for test in all_slow_tests)
+
+ # Calculate duration per suite
+ suite_durations = {}
+ for test in all_slow_tests:
+ suite_name = test["suite"]
+ if suite_name not in suite_durations:
+ suite_durations[suite_name] = 0
+ suite_durations[suite_name] += test["duration"]
+
+ # Removed duration categories
+
+ return {
+ "total_stats": total_stats,
+ "test_suites": results,
+ "slowest_tests": top_slowest_tests,
+ "duration_stats": {"total_duration": total_duration, "suite_durations": suite_durations},
+ }
+
+
+def generate_report(consolidated_data):
+ """Generate a comprehensive markdown report from consolidated data."""
+ report = []
+
+ # Add report header
+ report.append("# Diffusers Nightly Test Report")
+ report.append(f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
+
+ # Removed comparison section
+
+ # Add summary section
+ total = consolidated_data["total_stats"]
+ report.append("## Summary")
+
+ # Get duration stats if available
+ duration_stats = consolidated_data.get("duration_stats", {})
+ total_duration = duration_stats.get("total_duration", 0)
+
+ summary_table = [
+ ["Total Tests", total["tests"]],
+ ["Passed", total["passed"]],
+ ["Failed", total["failed"]],
+ ["Skipped", total["skipped"]],
+ ["Success Rate", f"{(total['passed'] / total['tests'] * 100):.2f}%" if total["tests"] > 0 else "N/A"],
+ ["Total Duration", f"{total_duration:.2f}s" if total_duration else "N/A"],
+ ]
+
+ report.append(tabulate(summary_table, tablefmt="pipe"))
+ report.append("")
+
+ # Removed duration distribution section
+
+ # Add test suites summary
+ report.append("## Test Suites")
+
+ # Include duration in test suites table if available
+ suite_durations = consolidated_data.get("duration_stats", {}).get("suite_durations", {})
+
+ if suite_durations:
+ suites_table = [["Test Suite", "Tests", "Passed", "Failed", "Skipped", "Success Rate", "Duration (s)"]]
+ else:
+ suites_table = [["Test Suite", "Tests", "Passed", "Failed", "Skipped", "Success Rate"]]
+
+ # Sort test suites by success rate (ascending - least successful first)
+ sorted_suites = sorted(
+ consolidated_data["test_suites"].items(),
+ key=lambda x: (x[1]["stats"]["passed"] / x[1]["stats"]["tests"] * 100) if x[1]["stats"]["tests"] > 0 else 0,
+ reverse=False,
+ )
+
+ for suite_name, suite_data in sorted_suites:
+ stats = suite_data["stats"]
+ success_rate = f"{(stats['passed'] / stats['tests'] * 100):.2f}%" if stats["tests"] > 0 else "N/A"
+
+ if suite_durations:
+ duration = suite_durations.get(suite_name, 0)
+ suites_table.append(
+ [
+ suite_name,
+ stats["tests"],
+ stats["passed"],
+ stats["failed"],
+ stats["skipped"],
+ success_rate,
+ f"{duration:.2f}",
+ ]
+ )
+ else:
+ suites_table.append(
+ [suite_name, stats["tests"], stats["passed"], stats["failed"], stats["skipped"], success_rate]
+ )
+
+ report.append(tabulate(suites_table, headers="firstrow", tablefmt="pipe"))
+ report.append("")
+
+ # Add slowest tests section
+ slowest_tests = consolidated_data.get("slowest_tests", [])
+ if slowest_tests:
+ report.append("## Slowest Tests")
+
+ slowest_table = [["Rank", "Test", "Duration (s)", "Test Suite"]]
+ for i, test in enumerate(slowest_tests, 1):
+ # Skip entries that don't contain actual test names
+ if "< 0.05 secs were omitted" in test["test"]:
+ continue
+ slowest_table.append([i, test["test"], f"{test['duration']:.2f}", test["suite"]])
+
+ report.append(tabulate(slowest_table, headers="firstrow", tablefmt="pipe"))
+ report.append("")
+
+ # Add failures section if there are any
+ failed_suites = [s for s in sorted_suites if s[1]["stats"]["failed"] > 0]
+
+ if failed_suites:
+ report.append("## Failures")
+
+ # Group failures by module for cleaner organization
+ failures_by_module = {}
+
+ for suite_name, suite_data in failed_suites:
+ # Extract failures data for this suite
+ for failure in suite_data.get("failures", []):
+ test_name = failure["test"]
+
+ # If test name doesn't look like a full path, try to reconstruct it
+ if not ("/" in test_name or "::" in test_name) and "." in test_name:
+ # For simple 'TestClass.test_method' format, try to get full path from suite name
+ # Form: tests__cuda -> tests//test_.py::TestClass::test_method
+ if suite_name.startswith("tests_") and "_cuda" in suite_name:
+ # Extract component name from suite
+ component = suite_name.replace("tests_", "").replace("_cuda", "")
+ if "." in test_name:
+ class_name, method_name = test_name.split(".", 1)
+ possible_path = f"tests/{component}/test_{component}.py::{class_name}::{method_name}"
+ # Use this constructed path if it seems reasonable
+ if "test_" in method_name:
+ test_name = possible_path
+
+ # Extract module name from test name
+ if "::" in test_name:
+ # For path/file.py::TestClass::test_method format
+ parts = test_name.split("::")
+ module_name = parts[-2] if len(parts) >= 2 else "Other" # TestClass
+ elif "." in test_name:
+ # For TestClass.test_method format
+ parts = test_name.split(".")
+ module_name = parts[0] # TestClass
+ else:
+ module_name = "Other"
+
+ # Skip module names that don't look like class/module names
+ if (
+ module_name.startswith(("e.g", "i.e", "etc"))
+ or module_name.replace(".", "").isdigit()
+ or len(module_name) < 3
+ ):
+ module_name = "Other"
+
+ # Add to the module group
+ if module_name not in failures_by_module:
+ failures_by_module[module_name] = []
+
+ # Prepend the suite name if the test name doesn't already have a full path
+ if "/" not in test_name and suite_name not in test_name:
+ full_test_name = f"{suite_name}::{test_name}"
+ else:
+ full_test_name = test_name
+
+ # Add this failure to the module group
+ failures_by_module[module_name].append(
+ {"test": full_test_name, "original_test": test_name, "error": failure["error"]}
+ )
+
+ # Create a list of failing tests for each module
+ if failures_by_module:
+ for module_name, failures in sorted(failures_by_module.items()):
+ report.append(f"### {module_name}")
+
+ # Put all failed tests in a single code block
+ report.append("```")
+ for failure in failures:
+ # Show test path and error message if available
+ if failure.get("error") and failure["error"] != "No error message":
+ report.append(f"{failure['test']} - {failure['error']}")
+ else:
+ report.append(failure["test"])
+ report.append("```")
+
+ report.append("") # Add space between modules
+ else:
+ report.append("*No detailed failure information available*")
+ report.append("")
+
+ return "\n".join(report)
+
+
+def create_test_groups_table(test_groups, total_tests, total_success_rate):
+ """Create a table-like format for test groups showing total tests and success rate."""
+ if not test_groups:
+ return None
+
+ # Sort by total test count (descending)
+ sorted_groups = sorted(test_groups.items(), key=lambda x: x[1]["total"], reverse=True)
+
+ # Create table lines
+ table_lines = ["```"]
+ table_lines.append("Test Results Summary")
+ table_lines.append("-------------------")
+ table_lines.append(f"Total Tests: {total_tests:,}")
+ table_lines.append(f"Success Rate: {total_success_rate}")
+ table_lines.append("")
+ table_lines.append("Category | Total Tests | Failed | Success Rate")
+ table_lines.append("------------------- | ----------- | ------ | ------------")
+
+ # Add rows
+ for category, stats in sorted_groups:
+ # Pad category name to fixed width (19 chars)
+ padded_cat = category[:19].ljust(19) # Truncate if too long
+ # Right-align counts
+ padded_total = str(stats["total"]).rjust(11)
+ padded_failed = str(stats["failed"]).rjust(6)
+ # Calculate and format success rate
+ if stats["total"] > 0:
+ cat_success_rate = f"{((stats['total'] - stats['failed']) / stats['total'] * 100):.1f}%"
+ else:
+ cat_success_rate = "N/A"
+ padded_rate = cat_success_rate.rjust(12)
+ table_lines.append(f"{padded_cat} | {padded_total} | {padded_failed} | {padded_rate}")
+
+ table_lines.append("```")
+
+ total_failures = sum(stats["failed"] for stats in test_groups.values())
+ return (
+ f"*Test Groups Summary ({total_failures} {'failure' if total_failures == 1 else 'failures'}):*\n"
+ + "\n".join(table_lines)
+ )
+
+
+def create_slack_payload(consolidated_data):
+ """Create a concise Slack message payload from consolidated data."""
+ total = consolidated_data["total_stats"]
+ success_rate = f"{(total['passed'] / total['tests'] * 100):.2f}%" if total["tests"] > 0 else "N/A"
+
+ # Determine emoji based on success rate
+ if total["failed"] == 0:
+ emoji = "✅"
+ elif total["failed"] / total["tests"] < 0.1:
+ emoji = "⚠️"
+ else:
+ emoji = "❌"
+
+ # Create a more compact summary section
+ summary = f"{emoji} *Diffusers Nightly Tests:* {success_rate} success ({total['passed']}/{total['tests']} tests"
+ if total["skipped"] > 0:
+ summary += f", {total['skipped']} skipped"
+ summary += ")"
+
+ # Create the test suites table in markdown format
+ # Build the markdown table with proper alignment
+ table_lines = []
+ table_lines.append("```")
+
+ # Sort test suites by success rate (ascending - least successful first)
+ sorted_suites = sorted(
+ consolidated_data["test_suites"].items(),
+ key=lambda x: (x[1]["stats"]["passed"] / x[1]["stats"]["tests"] * 100) if x[1]["stats"]["tests"] > 0 else 0,
+ reverse=False,
+ )
+
+ # Calculate max widths for proper alignment
+ max_suite_name_len = max(len(suite_name) for suite_name, _ in sorted_suites) if sorted_suites else 10
+ max_suite_name_len = max(max_suite_name_len, len("Test Suite")) # Ensure header fits
+
+ # Create header with proper spacing (only Tests, Failed, Success Rate)
+ header = f"| {'Test Suite'.ljust(max_suite_name_len)} | {'Tests'.rjust(6)} | {'Failed'.rjust(6)} | {'Success Rate'.ljust(12)} |"
+ separator = f"|:{'-' * max_suite_name_len}|{'-' * 7}:|{'-' * 7}:|:{'-' * 11}|"
+
+ table_lines.append(header)
+ table_lines.append(separator)
+
+ # Add data rows with proper alignment
+ for suite_name, suite_data in sorted_suites:
+ stats = suite_data["stats"]
+ suite_success_rate = f"{(stats['passed'] / stats['tests'] * 100):.2f}%" if stats["tests"] > 0 else "N/A"
+
+ row = f"| {suite_name.ljust(max_suite_name_len)} | {str(stats['tests']).rjust(6)} | {str(stats['failed']).rjust(6)} | {suite_success_rate.ljust(12)} |"
+
+ table_lines.append(row)
+
+ table_lines.append("```")
+
+ # Create the Slack payload with character limit enforcement
+ payload = [
+ {"type": "section", "text": {"type": "mrkdwn", "text": summary}},
+ {"type": "section", "text": {"type": "mrkdwn", "text": "\n".join(table_lines)}},
+ ]
+
+ # Add action button
+ if os.environ.get("GITHUB_RUN_ID"):
+ run_id = os.environ["GITHUB_RUN_ID"]
+ payload.append(
+ {
+ "type": "section",
+ "text": {
+ "type": "mrkdwn",
+ "text": f"**",
+ },
+ }
+ )
+
+ # Add date in more compact form
+ payload.append(
+ {
+ "type": "context",
+ "elements": [
+ {
+ "type": "plain_text",
+ "text": f"Results for {date.today()}",
+ },
+ ],
+ }
+ )
+
+ # Enforce 3001 character limit
+ payload_text = str(payload)
+ if len(payload_text) > MAX_LEN_MESSAGE:
+ # Truncate table if payload is too long
+ # Remove rows from the bottom until under limit
+ original_table_lines = table_lines[:]
+ while len(str(payload)) > MAX_LEN_MESSAGE and len(table_lines) > 3: # Keep at least header and separator
+ # Remove the last data row (but keep ``` at the end)
+ table_lines.pop(-2) # Remove second to last (last is the closing ```)
+
+ # Recreate payload with truncated table
+ payload[1] = {"type": "section", "text": {"type": "mrkdwn", "text": "\n".join(table_lines)}}
+
+ # Add note if we had to truncate
+ if len(table_lines) < len(original_table_lines):
+ truncated_count = len(original_table_lines) - len(table_lines)
+ table_lines.insert(-1, f"... {truncated_count} more test suites (truncated due to message limit)")
+ payload[1] = {"type": "section", "text": {"type": "mrkdwn", "text": "\n".join(table_lines)}}
+
+ return payload
+
+
+def create_failed_tests_by_suite_ordered(consolidated_data):
+ """Group failed tests by test suite, ordered by success rate (ascending)."""
+ # Sort test suites by success rate (ascending - least successful first)
+ sorted_suites = sorted(
+ consolidated_data["test_suites"].items(),
+ key=lambda x: (x[1]["stats"]["passed"] / x[1]["stats"]["tests"] * 100) if x[1]["stats"]["tests"] > 0 else 0,
+ reverse=False,
+ )
+
+ failed_suite_tests = []
+
+ # Process suites in order of success rate
+ for suite_name, suite_data in sorted_suites:
+ if suite_data["stats"]["failed"] > 0:
+ suite_failures = []
+
+ for failure in suite_data.get("failures", []):
+ test_name = failure["test"]
+
+ # Try to reconstruct full path if partial
+ if "::" in test_name and "/" in test_name:
+ full_test_name = test_name
+ elif "::" in test_name or "." in test_name:
+ if "/" not in test_name and suite_name not in test_name:
+ full_test_name = f"{suite_name}::{test_name}"
+ else:
+ full_test_name = test_name
+ else:
+ full_test_name = f"{suite_name}::{test_name}"
+
+ suite_failures.append(full_test_name)
+
+ # Sort and deduplicate tests within the suite
+ suite_failures = sorted(set(suite_failures))
+
+ if suite_failures:
+ failed_suite_tests.append(
+ {
+ "suite_name": suite_name,
+ "tests": suite_failures,
+ "success_rate": (suite_data["stats"]["passed"] / suite_data["stats"]["tests"] * 100)
+ if suite_data["stats"]["tests"] > 0
+ else 0,
+ }
+ )
+
+ return failed_suite_tests
+
+
+def main(args):
+ # Make sure reports directory exists
+ if not os.path.isdir(args.reports_dir):
+ print(f"Error: Reports directory '{args.reports_dir}' does not exist.")
+ return
+
+ # Consolidate reports
+ consolidated_data = consolidate_reports(args.reports_dir)
+
+ # Check if we found any test results
+ if consolidated_data["total_stats"]["tests"] == 0:
+ print(f"Warning: No test results found in '{args.reports_dir}' or its subdirectories.")
+
+ # Generate markdown report
+ report = generate_report(consolidated_data)
+
+ # Save report to file if specified
+ if args.output_file:
+ # Create parent directories if they don't exist
+ output_dir = os.path.dirname(args.output_file)
+ if output_dir and not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+
+ with open(args.output_file, "w") as f:
+ f.write(report)
+
+ # Only print the report when saving to file
+ print(report)
+
+ # Send to Slack if token is available (optional, can be disabled)
+ slack_token = os.environ.get("SLACK_API_TOKEN")
+ if slack_token and args.slack_channel_name:
+ payload = create_slack_payload(consolidated_data)
+
+ try:
+ client = WebClient(token=slack_token)
+ # Send main message
+ response = client.chat_postMessage(channel=f"#{args.slack_channel_name}", blocks=payload)
+ print(f"Report sent to Slack channel: {args.slack_channel_name}")
+
+ # Send failed tests as separate threaded replies grouped by test suite (ordered by success rate)
+ total = consolidated_data["total_stats"]
+ if total["failed"] > 0:
+ failed_suites = create_failed_tests_by_suite_ordered(consolidated_data)
+ for suite_info in failed_suites:
+ suite_name = suite_info["suite_name"]
+ suite_tests = suite_info["tests"]
+ success_rate = suite_info["success_rate"]
+ message_text = (
+ f"**{suite_name}** (Success Rate: {success_rate:.2f}%)\n```\n"
+ + "\n".join(suite_tests)
+ + "\n```"
+ )
+ client.chat_postMessage(
+ channel=f"#{args.slack_channel_name}",
+ thread_ts=response["ts"], # Reply in thread
+ text=message_text, # Use text instead of blocks for markdown
+ )
+ print(f"Failed tests details sent as {len(failed_suites)} thread replies")
+ except Exception as e:
+ print(f"Error sending report to Slack: {e}")
+
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+ main(args)
diff --git a/utils/custom_init_isort.py b/utils/custom_init_isort.py
index 791df0e78694..cc3bccb9bd63 100644
--- a/utils/custom_init_isort.py
+++ b/utils/custom_init_isort.py
@@ -252,7 +252,7 @@ def sort_imports(file: str, check_only: bool = True):
code, start_prompt="_import_structure = {", end_prompt="if TYPE_CHECKING:"
)
- # We ignore block 0 (everything untils start_prompt) and the last block (everything after end_prompt).
+ # We ignore block 0 (everything until start_prompt) and the last block (everything after end_prompt).
for block_idx in range(1, len(main_blocks) - 1):
# Check if the block contains some `_import_structure`s thingy to sort.
block = main_blocks[block_idx]
diff --git a/utils/fetch_latest_release_branch.py b/utils/fetch_latest_release_branch.py
index 9bf578a5f58e..5b0be6253e1b 100644
--- a/utils/fetch_latest_release_branch.py
+++ b/utils/fetch_latest_release_branch.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+
import requests
from packaging.version import parse
@@ -27,7 +28,11 @@ def fetch_all_branches(user, repo):
page = 1 # Start from first page
while True:
# Make a request to the GitHub API for the branches
- response = requests.get(f"https://api.github.com/repos/{user}/{repo}/branches", params={"page": page})
+ response = requests.get(
+ f"https://api.github.com/repos/{user}/{repo}/branches",
+ params={"page": page},
+ timeout=60,
+ )
# Check if the request was successful
if response.status_code == 200:
diff --git a/utils/log_reports.py b/utils/log_reports.py
index dd1b258519d7..5575c9ba8415 100644
--- a/utils/log_reports.py
+++ b/utils/log_reports.py
@@ -35,7 +35,7 @@ def main(slack_channel_name=None):
if line.get("nodeid", "") != "":
test = line["nodeid"]
if line.get("duration", None) is not None:
- duration = f'{line["duration"]:.4f}'
+ duration = f"{line['duration']:.4f}"
if line.get("outcome", "") == "failed":
section_num_failed += 1
failed.append([test, duration, log.name.split("_")[0]])
diff --git a/utils/notify_benchmarking_status.py b/utils/notify_benchmarking_status.py
index c9c6ab485f59..8a426a15b5ed 100644
--- a/utils/notify_benchmarking_status.py
+++ b/utils/notify_benchmarking_status.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/utils/notify_community_pipelines_mirror.py b/utils/notify_community_pipelines_mirror.py
index a7d3a31c988e..2981f008501f 100644
--- a/utils/notify_community_pipelines_mirror.py
+++ b/utils/notify_community_pipelines_mirror.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/utils/notify_slack_about_release.py b/utils/notify_slack_about_release.py
index a67dd8bd0685..a68182f8174c 100644
--- a/utils/notify_slack_about_release.py
+++ b/utils/notify_slack_about_release.py
@@ -1,5 +1,5 @@
# coding=utf-8
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -26,7 +26,7 @@
def check_pypi_for_latest_release(library_name):
"""Check PyPI for the latest release of the library."""
- response = requests.get(f"https://pypi.org/pypi/{library_name}/json")
+ response = requests.get(f"https://pypi.org/pypi/{library_name}/json", timeout=60)
if response.status_code == 200:
data = response.json()
return data["info"]["version"]
@@ -38,7 +38,7 @@ def check_pypi_for_latest_release(library_name):
def get_github_release_info(github_repo):
"""Fetch the latest release info from GitHub."""
url = f"https://api.github.com/repos/{github_repo}/releases/latest"
- response = requests.get(url)
+ response = requests.get(url, timeout=60)
if response.status_code == 200:
data = response.json()
diff --git a/utils/print_env.py b/utils/print_env.py
index 0a1cfbef133f..2fe0777daf7d 100644
--- a/utils/print_env.py
+++ b/utils/print_env.py
@@ -28,19 +28,40 @@
print("OS platform:", platform.platform())
print("OS architecture:", platform.machine())
+try:
+ import psutil
+
+ vm = psutil.virtual_memory()
+ total_gb = vm.total / (1024**3)
+ available_gb = vm.available / (1024**3)
+ print(f"Total RAM: {total_gb:.2f} GB")
+ print(f"Available RAM: {available_gb:.2f} GB")
+except ImportError:
+ pass
try:
import torch
print("Torch version:", torch.__version__)
print("Cuda available:", torch.cuda.is_available())
- print("Cuda version:", torch.version.cuda)
- print("CuDNN version:", torch.backends.cudnn.version())
- print("Number of GPUs available:", torch.cuda.device_count())
if torch.cuda.is_available():
+ print("Cuda version:", torch.version.cuda)
+ print("CuDNN version:", torch.backends.cudnn.version())
+ print("Number of GPUs available:", torch.cuda.device_count())
device_properties = torch.cuda.get_device_properties(0)
total_memory = device_properties.total_memory / (1024**3)
print(f"CUDA memory: {total_memory} GB")
+
+ print("XPU available:", hasattr(torch, "xpu") and torch.xpu.is_available())
+ if hasattr(torch, "xpu") and torch.xpu.is_available():
+ print("XPU model:", torch.xpu.get_device_properties(0).name)
+ print("XPU compiler version:", torch.version.xpu)
+ print("Number of XPUs available:", torch.xpu.device_count())
+ device_properties = torch.xpu.get_device_properties(0)
+ total_memory = device_properties.total_memory / (1024**3)
+ print(f"XPU memory: {total_memory} GB")
+
+
except ImportError:
print("Torch version:", None)
diff --git a/utils/stale.py b/utils/stale.py
index 20cb6cabeb91..b92fb83ceb4c 100644
--- a/utils/stale.py
+++ b/utils/stale.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team, the AllenNLP library authors. All rights reserved.
+# Copyright 2025 The HuggingFace Team, the AllenNLP library authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/utils/update_metadata.py b/utils/update_metadata.py
index a97e65801c5f..4fde581d4170 100644
--- a/utils/update_metadata.py
+++ b/utils/update_metadata.py
@@ -104,8 +104,7 @@ def update_metadata(commit_sha: str):
if commit_sha is not None:
commit_message = (
- f"Update with commit {commit_sha}\n\nSee: "
- f"https://github.com/huggingface/diffusers/commit/{commit_sha}"
+ f"Update with commit {commit_sha}\n\nSee: https://github.com/huggingface/diffusers/commit/{commit_sha}"
)
else:
commit_message = "Update"